Compare commits

..

1 Commits

Author SHA1 Message Date
pablodanswer
14f57d6475 remove endpoint 2024-10-31 12:07:47 -07:00
485 changed files with 8487 additions and 20851 deletions

View File

@@ -3,61 +3,61 @@ name: Build and Push Backend Image on Tag
on:
push:
tags:
- "*"
- '*'
env:
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'danswer/danswer-backend-cloud' || 'danswer/danswer-backend' }}
REGISTRY_IMAGE: danswer/danswer-backend
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
jobs:
build-and-push:
# TODO: investigate a matrix build like the web container
# TODO: investigate a matrix build like the web container
# See https://runs-on.com/runners/linux/
runs-on: [runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}"]
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Install build-essential
run: |
sudo apt-get update
sudo apt-get install -y build-essential
- name: Install build-essential
run: |
sudo apt-get update
sudo apt-get install -y build-essential
- name: Backend Image Docker Build and Push
uses: docker/build-push-action@v5
with:
context: ./backend
file: ./backend/Dockerfile
platforms: linux/amd64,linux/arm64
push: true
tags: |
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
build-args: |
DANSWER_VERSION=${{ github.ref_name }}
- name: Backend Image Docker Build and Push
uses: docker/build-push-action@v5
with:
context: ./backend
file: ./backend/Dockerfile
platforms: linux/amd64,linux/arm64
push: true
tags: |
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
build-args: |
DANSWER_VERSION=${{ github.ref_name }}
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
env:
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
with:
# To run locally: trivy image --severity HIGH,CRITICAL danswer/danswer-backend
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
severity: "CRITICAL,HIGH"
trivyignores: ./backend/.trivyignore
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
# To run locally: trivy image --severity HIGH,CRITICAL danswer/danswer-backend
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
severity: 'CRITICAL,HIGH'
trivyignores: ./backend/.trivyignore

View File

@@ -4,12 +4,12 @@ name: Build and Push Cloud Web Image on Tag
on:
push:
tags:
- "*"
- '*'
env:
REGISTRY_IMAGE: danswer/danswer-web-server-cloud
REGISTRY_IMAGE: danswer/danswer-cloud-web-server
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
jobs:
build:
runs-on:
@@ -28,11 +28,11 @@ jobs:
- name: Prepare
run: |
platform=${{ matrix.platform }}
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
- name: Checkout
uses: actions/checkout@v4
- name: Docker meta
id: meta
uses: docker/metadata-action@v5
@@ -41,16 +41,16 @@ jobs:
tags: |
type=raw,value=${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
type=raw,value=${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Build and push by digest
id: build
uses: docker/build-push-action@v5
@@ -65,17 +65,17 @@ jobs:
NEXT_PUBLIC_POSTHOG_KEY=${{ secrets.POSTHOG_KEY }}
NEXT_PUBLIC_POSTHOG_HOST=${{ secrets.POSTHOG_HOST }}
NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
# needed due to weird interactions with the builds for different platforms
# needed due to weird interactions with the builds for different platforms
no-cache: true
labels: ${{ steps.meta.outputs.labels }}
outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
- name: Export digest
run: |
mkdir -p /tmp/digests
digest="${{ steps.build.outputs.digest }}"
touch "/tmp/digests/${digest#sha256:}"
touch "/tmp/digests/${digest#sha256:}"
- name: Upload digest
uses: actions/upload-artifact@v4
with:
@@ -95,42 +95,42 @@ jobs:
path: /tmp/digests
pattern: digests-*
merge-multiple: true
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Docker meta
id: meta
uses: docker/metadata-action@v5
with:
images: ${{ env.REGISTRY_IMAGE }}
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Create manifest list and push
working-directory: /tmp/digests
run: |
docker buildx imagetools create $(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \
$(printf '${{ env.REGISTRY_IMAGE }}@sha256:%s ' *)
$(printf '${{ env.REGISTRY_IMAGE }}@sha256:%s ' *)
- name: Inspect image
run: |
docker buildx imagetools inspect ${{ env.REGISTRY_IMAGE }}:${{ steps.meta.outputs.version }}
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
env:
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
severity: "CRITICAL,HIGH"
severity: 'CRITICAL,HIGH'

View File

@@ -3,53 +3,53 @@ name: Build and Push Model Server Image on Tag
on:
push:
tags:
- "*"
- '*'
env:
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'danswer/danswer-model-server-cloud' || 'danswer/danswer-model-server' }}
REGISTRY_IMAGE: danswer/danswer-model-server
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
jobs:
build-and-push:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}"]
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Model Server Image Docker Build and Push
uses: docker/build-push-action@v5
with:
context: ./backend
file: ./backend/Dockerfile.model_server
platforms: linux/amd64,linux/arm64
push: true
tags: |
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
build-args: |
DANSWER_VERSION=${{ github.ref_name }}
- name: Model Server Image Docker Build and Push
uses: docker/build-push-action@v5
with:
context: ./backend
file: ./backend/Dockerfile.model_server
platforms: linux/amd64,linux/arm64
push: true
tags: |
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
build-args: |
DANSWER_VERSION=${{ github.ref_name }}
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
env:
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
with:
image-ref: docker.io/danswer/danswer-model-server:${{ github.ref_name }}
severity: "CRITICAL,HIGH"
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
image-ref: docker.io/danswer/danswer-model-server:${{ github.ref_name }}
severity: 'CRITICAL,HIGH'

View File

@@ -1,76 +0,0 @@
# Scan for problematic software licenses
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
name: 'Nightly - Scan licenses'
on:
# schedule:
# - cron: '0 14 * * *' # Runs every day at 6 AM PST / 7 AM PDT / 2 PM UTC
workflow_dispatch: # Allows manual triggering
permissions:
actions: read
contents: read
security-events: write
jobs:
scan-licenses:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}"]
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.11'
cache: 'pip'
cache-dependency-path: |
backend/requirements/default.txt
backend/requirements/dev.txt
backend/requirements/model_server.txt
- name: Get explicit and transitive dependencies
run: |
python -m pip install --upgrade pip
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
pip freeze > requirements-all.txt
- name: Check python
id: license_check_report
uses: pilosus/action-pip-license-checker@v2
with:
requirements: 'requirements-all.txt'
fail: 'Copyleft'
exclude: '(?i)^(pylint|aio[-_]*).*'
- name: Print report
if: ${{ always() }}
run: echo "${{ steps.license_check_report.outputs.report }}"
- name: Install npm dependencies
working-directory: ./web
run: npm ci
- name: Run Trivy vulnerability scanner in repo mode
uses: aquasecurity/trivy-action@0.28.0
with:
scan-type: fs
scanners: license
format: table
# format: sarif
# output: trivy-results.sarif
severity: HIGH,CRITICAL
# - name: Upload Trivy scan results to GitHub Security tab
# uses: github/codeql-action/upload-sarif@v3
# with:
# sarif_file: trivy-results.sarif

View File

@@ -210,18 +210,17 @@ jobs:
echo "All integration tests passed successfully."
fi
# save before stopping the containers so the logs can be captured
- name: Stop Docker containers
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
- name: Save Docker logs
if: success() || failure()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack logs > docker-compose.log
mv docker-compose.log ${{ github.workspace }}/docker-compose.log
- name: Stop Docker containers
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
- name: Upload logs
if: success() || failure()

View File

@@ -1,20 +1,24 @@
# This workflow is intentionally disabled while we're still working on it
# It's close to ready, but a race condition needs to be fixed with
# API server and Vespa startup, and it needs to have a way to build/test against
# local containers
name: Helm - Lint and Test Charts
on:
merge_group:
pull_request:
branches: [ main ]
workflow_dispatch: # Allows manual triggering
jobs:
helm-chart-check:
lint-test:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,hdd=256,"run-id=${{ github.run_id }}"]
# fetch-depth 0 is required for helm/chart-testing-action
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@v3
with:
fetch-depth: 0
@@ -24,7 +28,7 @@ jobs:
version: v3.14.4
- name: Set up Python
uses: actions/setup-python@v5
uses: actions/setup-python@v4
with:
python-version: '3.11'
cache: 'pip'
@@ -41,31 +45,24 @@ jobs:
- name: Set up chart-testing
uses: helm/chart-testing-action@v2.6.1
# even though we specify chart-dirs in ct.yaml, it isn't used by ct for the list-changed command...
- name: Run chart-testing (list-changed)
id: list-changed
run: |
echo "default_branch: ${{ github.event.repository.default_branch }}"
changed=$(ct list-changed --remote origin --target-branch ${{ github.event.repository.default_branch }} --chart-dirs deployment/helm/charts)
echo "list-changed output: $changed"
changed=$(ct list-changed --target-branch ${{ github.event.repository.default_branch }})
if [[ -n "$changed" ]]; then
echo "changed=true" >> "$GITHUB_OUTPUT"
fi
# lint all charts if any changes were detected
- name: Run chart-testing (lint)
if: steps.list-changed.outputs.changed == 'true'
run: ct lint --config ct.yaml --all
# the following would lint only changed charts, but linting isn't expensive
# run: ct lint --config ct.yaml --target-branch ${{ github.event.repository.default_branch }}
# if: steps.list-changed.outputs.changed == 'true'
run: ct lint --all --config ct.yaml --target-branch ${{ github.event.repository.default_branch }}
- name: Create kind cluster
if: steps.list-changed.outputs.changed == 'true'
# if: steps.list-changed.outputs.changed == 'true'
uses: helm/kind-action@v1.10.0
- name: Run chart-testing (install)
if: steps.list-changed.outputs.changed == 'true'
run: ct install --all --helm-extra-set-args="--set=nginx.enabled=false" --debug --config ct.yaml
# the following would install only changed charts, but we only have one chart so
# don't worry about that for now
# run: ct install --target-branch ${{ github.event.repository.default_branch }}
# if: steps.list-changed.outputs.changed == 'true'
run: ct install --all --config ct.yaml
# run: ct install --target-branch ${{ github.event.repository.default_branch }}

View File

@@ -18,11 +18,6 @@ env:
# Jira
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
# Google
GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR: ${{ secrets.GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR }}
GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR }}
GOOGLE_GMAIL_SERVICE_ACCOUNT_JSON_STR: ${{ secrets.GOOGLE_GMAIL_SERVICE_ACCOUNT_JSON_STR }}
GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR }}
jobs:
connectors-check:

View File

@@ -15,7 +15,7 @@ env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
jobs:
model-check:
connectors-check:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]

View File

@@ -1,5 +1,4 @@
<!-- DANSWER_METADATA={"link": "https://github.com/danswer-ai/danswer/blob/main/README.md"} -->
<a name="readme-top"></a>
<h2 align="center">
<a href="https://www.danswer.ai/"> <img width="50%" src="https://github.com/danswer-owners/danswer/blob/1fabd9372d66cd54238847197c33f091a724803b/DanswerWithName.png?raw=true)" /></a>
@@ -128,19 +127,3 @@ To try the Danswer Enterprise Edition:
## 💡 Contributing
Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md) for more details.
## ⭐Star History
[![Star History Chart](https://api.star-history.com/svg?repos=danswer-ai/danswer&type=Date)](https://star-history.com/#danswer-ai/danswer&Date)
## ✨Contributors
<a href="https://github.com/aryn-ai/sycamore/graphs/contributors">
<img alt="contributors" src="https://contrib.rocks/image?repo=danswer-ai/danswer"/>
</a>
<p align="right" style="font-size: 14px; color: #555; margin-top: 20px;">
<a href="#readme-top" style="text-decoration: none; color: #007bff; font-weight: bold;">
↑ Back to Top ↑
</a>
</p>

View File

@@ -12,6 +12,7 @@ ARG DANSWER_VERSION=0.8-dev
ENV DANSWER_VERSION=${DANSWER_VERSION} \
DANSWER_RUNNING_IN_DOCKER="true"
ARG CA_CERT_CONTENT=""
RUN echo "DANSWER_VERSION: ${DANSWER_VERSION}"
# Install system dependencies
@@ -38,6 +39,15 @@ RUN apt-get update && \
apt-get clean
# Conditionally write the CA certificate and update certificates
RUN if [ -n "$CA_CERT_CONTENT" ]; then \
echo "Adding custom CA certificate"; \
echo "$CA_CERT_CONTENT" > /usr/local/share/ca-certificates/my-ca.crt && \
chmod 644 /usr/local/share/ca-certificates/my-ca.crt && \
update-ca-certificates; \
else \
echo "No custom CA certificate provided"; \
fi
# Install Python dependencies
# Remove py which is pulled in by retry, py is not needed and is a CVE
@@ -77,6 +87,7 @@ RUN apt-get update && \
RUN python -c "from tokenizers import Tokenizer; \
Tokenizer.from_pretrained('nomic-ai/nomic-embed-text-v1')"
# Pre-downloading NLTK for setups with limited egress
RUN python -c "import nltk; \
nltk.download('stopwords', quiet=True); \

View File

@@ -1,50 +0,0 @@
"""single tool call per message
Revision ID: 33cb72ea4d80
Revises: 5b29123cd710
Create Date: 2024-11-01 12:51:01.535003
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "33cb72ea4d80"
down_revision = "5b29123cd710"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Step 1: Delete extraneous ToolCall entries
# Keep only the ToolCall with the smallest 'id' for each 'message_id'
op.execute(
sa.text(
"""
DELETE FROM tool_call
WHERE id NOT IN (
SELECT MIN(id)
FROM tool_call
WHERE message_id IS NOT NULL
GROUP BY message_id
);
"""
)
)
# Step 2: Add a unique constraint on message_id
op.create_unique_constraint(
constraint_name="uq_tool_call_message_id",
table_name="tool_call",
columns=["message_id"],
)
def downgrade() -> None:
# Step 1: Drop the unique constraint on message_id
op.drop_constraint(
constraint_name="uq_tool_call_message_id",
table_name="tool_call",
type_="unique",
)

View File

@@ -1,70 +0,0 @@
"""nullable search settings for historic index attempts
Revision ID: 5b29123cd710
Revises: 949b4a92a401
Create Date: 2024-10-30 19:37:59.630704
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "5b29123cd710"
down_revision = "949b4a92a401"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Drop the existing foreign key constraint
op.drop_constraint(
"fk_index_attempt_search_settings", "index_attempt", type_="foreignkey"
)
# Modify the column to be nullable
op.alter_column(
"index_attempt", "search_settings_id", existing_type=sa.INTEGER(), nullable=True
)
# Add back the foreign key with ON DELETE SET NULL
op.create_foreign_key(
"fk_index_attempt_search_settings",
"index_attempt",
"search_settings",
["search_settings_id"],
["id"],
ondelete="SET NULL",
)
def downgrade() -> None:
# Warning: This will delete all index attempts that don't have search settings
op.execute(
"""
DELETE FROM index_attempt
WHERE search_settings_id IS NULL
"""
)
# Drop foreign key constraint
op.drop_constraint(
"fk_index_attempt_search_settings", "index_attempt", type_="foreignkey"
)
# Modify the column to be not nullable
op.alter_column(
"index_attempt",
"search_settings_id",
existing_type=sa.INTEGER(),
nullable=False,
)
# Add back the foreign key without ON DELETE SET NULL
op.create_foreign_key(
"fk_index_attempt_search_settings",
"index_attempt",
"search_settings",
["search_settings_id"],
["id"],
)

View File

@@ -288,15 +288,6 @@ def upgrade() -> None:
def downgrade() -> None:
# NOTE: you will lose all chat history. This is to satisfy the non-nullable constraints
# below
op.execute("DELETE FROM chat_feedback")
op.execute("DELETE FROM chat_message__search_doc")
op.execute("DELETE FROM document_retrieval_feedback")
op.execute("DELETE FROM document_retrieval_feedback")
op.execute("DELETE FROM chat_message")
op.execute("DELETE FROM chat_session")
op.drop_constraint(
"chat_feedback__chat_message_fk", "chat_feedback", type_="foreignkey"
)

View File

@@ -1,48 +0,0 @@
"""remove description from starter messages
Revision ID: b72ed7a5db0e
Revises: 33cb72ea4d80
Create Date: 2024-11-03 15:55:28.944408
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "b72ed7a5db0e"
down_revision = "33cb72ea4d80"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.execute(
sa.text(
"""
UPDATE persona
SET starter_messages = (
SELECT jsonb_agg(elem - 'description')
FROM jsonb_array_elements(starter_messages) elem
)
WHERE starter_messages IS NOT NULL
AND jsonb_typeof(starter_messages) = 'array'
"""
)
)
def downgrade() -> None:
op.execute(
sa.text(
"""
UPDATE persona
SET starter_messages = (
SELECT jsonb_agg(elem || '{"description": ""}')
FROM jsonb_array_elements(starter_messages) elem
)
WHERE starter_messages IS NOT NULL
AND jsonb_typeof(starter_messages) = 'array'
"""
)
)

View File

@@ -1,29 +0,0 @@
"""add recent assistants
Revision ID: c0fd6e4da83a
Revises: b72ed7a5db0e
Create Date: 2024-11-03 17:28:54.916618
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "c0fd6e4da83a"
down_revision = "b72ed7a5db0e"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"user",
sa.Column(
"recent_assistants", postgresql.JSONB(), server_default="[]", nullable=False
),
)
def downgrade() -> None:
op.drop_column("user", "recent_assistants")

View File

@@ -23,56 +23,6 @@ def upgrade() -> None:
def downgrade() -> None:
# Delete chat messages and feedback first since they reference chat sessions
# Get chat messages from sessions with null persona_id
chat_messages_query = """
SELECT id
FROM chat_message
WHERE chat_session_id IN (
SELECT id
FROM chat_session
WHERE persona_id IS NULL
)
"""
# Delete dependent records first
op.execute(
f"""
DELETE FROM document_retrieval_feedback
WHERE chat_message_id IN (
{chat_messages_query}
)
"""
)
op.execute(
f"""
DELETE FROM chat_message__search_doc
WHERE chat_message_id IN (
{chat_messages_query}
)
"""
)
# Delete chat messages
op.execute(
"""
DELETE FROM chat_message
WHERE chat_session_id IN (
SELECT id
FROM chat_session
WHERE persona_id IS NULL
)
"""
)
# Now we can safely delete the chat sessions
op.execute(
"""
DELETE FROM chat_session
WHERE persona_id IS NULL
"""
)
op.alter_column(
"chat_session",
"persona_id",

View File

@@ -48,11 +48,11 @@ from httpx_oauth.integrations.fastapi import OAuth2AuthorizeCallback
from httpx_oauth.oauth2 import BaseOAuth2
from httpx_oauth.oauth2 import OAuth2Token
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy import text
from sqlalchemy.orm import attributes
from sqlalchemy.orm import Session
from danswer.auth.api_key import get_hashed_api_key_from_request
from danswer.auth.invited_users import get_invited_users
from danswer.auth.schemas import UserCreate
from danswer.auth.schemas import UserRole
@@ -75,7 +75,6 @@ from danswer.configs.constants import AuthType
from danswer.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
from danswer.configs.constants import DANSWER_API_KEY_PREFIX
from danswer.configs.constants import UNNAMED_KEY_PLACEHOLDER
from danswer.db.api_key import fetch_user_for_api_key
from danswer.db.auth import get_access_token_db
from danswer.db.auth import get_default_admin_user_emails
from danswer.db.auth import get_user_count
@@ -84,27 +83,24 @@ from danswer.db.auth import SQLAlchemyUserAdminDB
from danswer.db.engine import get_async_session_with_tenant
from danswer.db.engine import get_session
from danswer.db.engine import get_session_with_tenant
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import AccessToken
from danswer.db.models import OAuthAccount
from danswer.db.models import User
from danswer.db.models import UserTenantMapping
from danswer.db.users import get_user_by_email
from danswer.utils.logger import setup_logger
from danswer.utils.telemetry import optional_telemetry
from danswer.utils.telemetry import RecordType
from danswer.utils.variable_functionality import fetch_ee_implementation_or_noop
from danswer.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import async_return_default_schema
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
logger = setup_logger()
class BasicAuthenticationError(HTTPException):
def __init__(self, detail: str):
super().__init__(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
def is_user_admin(user: User | None) -> bool:
if AUTH_TYPE == AuthType.DISABLED:
return True
@@ -194,6 +190,20 @@ def verify_email_domain(email: str) -> None:
)
def get_tenant_id_for_email(email: str) -> str:
if not MULTI_TENANT:
return POSTGRES_DEFAULT_SCHEMA
# Implement logic to get tenant_id from the mapping table
with Session(get_sqlalchemy_engine()) as db_session:
result = db_session.execute(
select(UserTenantMapping.tenant_id).where(UserTenantMapping.email == email)
)
tenant_id = result.scalar_one_or_none()
if tenant_id is None:
raise exceptions.UserNotExists()
return tenant_id
def send_user_verification_email(
user_email: str,
token: str,
@@ -228,13 +238,19 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
safe: bool = False,
request: Optional[Request] = None,
) -> User:
tenant_id = await fetch_ee_implementation_or_noop(
"danswer.server.tenants.provisioning",
"get_or_create_tenant_id",
async_return_default_schema,
)(
email=user_create.email,
)
try:
tenant_id = (
get_tenant_id_for_email(user_create.email)
if MULTI_TENANT
else POSTGRES_DEFAULT_SCHEMA
)
except exceptions.UserNotExists:
raise HTTPException(status_code=401, detail="User not found")
if not tenant_id:
raise HTTPException(
status_code=401, detail="User does not belong to an organization"
)
async with get_async_session_with_tenant(tenant_id) as db_session:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
@@ -255,7 +271,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
user_create.role = UserRole.ADMIN
else:
user_create.role = UserRole.BASIC
user = None
try:
user = await super().create(user_create, safe=safe, request=request) # type: ignore
except exceptions.UserAlreadyExists:
@@ -276,9 +292,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
else:
raise exceptions.UserAlreadyExists()
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
return user
async def oauth_callback(
@@ -294,18 +308,19 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
associate_by_email: bool = False,
is_verified_by_default: bool = False,
) -> models.UOAP:
tenant_id = await fetch_ee_implementation_or_noop(
"danswer.server.tenants.provisioning",
"get_or_create_tenant_id",
async_return_default_schema,
)(
email=account_email,
)
# Get tenant_id from mapping table
try:
tenant_id = (
get_tenant_id_for_email(account_email)
if MULTI_TENANT
else POSTGRES_DEFAULT_SCHEMA
)
except exceptions.UserNotExists:
raise HTTPException(status_code=401, detail="User not found")
if not tenant_id:
raise HTTPException(status_code=401, detail="User not found")
# Proceed with the tenant context
token = None
async with get_async_session_with_tenant(tenant_id) as db_session:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
@@ -356,9 +371,9 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
# Explicitly set the Postgres schema for this session to ensure
# OAuth account creation happens in the correct tenant schema
await db_session.execute(text(f'SET search_path = "{tenant_id}"'))
# Add OAuth account
await self.user_db.add_oauth_account(user, oauth_account_dict)
user = await self.user_db.add_oauth_account(
user, oauth_account_dict
)
await self.on_after_register(user, request)
else:
@@ -438,13 +453,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
email = credentials.username
# Get tenant_id from mapping table
tenant_id = await fetch_ee_implementation_or_noop(
"danswer.server.tenants.provisioning",
"get_or_create_tenant_id",
async_return_default_schema,
)(
email=email,
)
tenant_id = get_tenant_id_for_email(email)
if not tenant_id:
# User not found in mapping
self.password_helper.hash(credentials.password)
@@ -468,7 +477,8 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
has_web_login = attributes.get_attribute(user, "has_web_login")
if not has_web_login:
raise BasicAuthenticationError(
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="NO_WEB_LOGIN_AND_HAS_NO_PASSWORD",
)
@@ -500,30 +510,19 @@ cookie_transport = CookieTransport(
# This strategy is used to add tenant_id to the JWT token
class TenantAwareJWTStrategy(JWTStrategy):
async def _create_token_data(self, user: User, impersonate: bool = False) -> dict:
tenant_id = await fetch_ee_implementation_or_noop(
"danswer.server.tenants.provisioning",
"get_or_create_tenant_id",
async_return_default_schema,
)(
email=user.email,
)
async def write_token(self, user: User) -> str:
tenant_id = get_tenant_id_for_email(user.email)
data = {
"sub": str(user.id),
"aud": self.token_audience,
"tenant_id": tenant_id,
}
return data
async def write_token(self, user: User) -> str:
data = await self._create_token_data(user)
return generate_jwt(
data, self.encode_key, self.lifetime_seconds, algorithm=self.algorithm
)
def get_jwt_strategy() -> TenantAwareJWTStrategy:
def get_jwt_strategy() -> JWTStrategy:
return TenantAwareJWTStrategy(
secret=USER_AUTH_SECRET,
lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS,
@@ -625,12 +624,14 @@ async def double_check_user(
return None
if user is None:
raise BasicAuthenticationError(
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied. User is not authenticated.",
)
if user_needs_to_be_verified() and not user.is_verified:
raise BasicAuthenticationError(
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied. User is not verified.",
)
@@ -639,7 +640,8 @@ async def double_check_user(
and user.oidc_expiry < datetime.now(timezone.utc)
and not include_expired
):
raise BasicAuthenticationError(
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied. User's OIDC token has expired.",
)
@@ -665,13 +667,15 @@ async def current_curator_or_admin_user(
return None
if not user or not hasattr(user, "role"):
raise BasicAuthenticationError(
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied. User is not authenticated or lacks role information.",
)
allowed_roles = {UserRole.GLOBAL_CURATOR, UserRole.CURATOR, UserRole.ADMIN}
if user.role not in allowed_roles:
raise BasicAuthenticationError(
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied. User is not a curator or admin.",
)
@@ -683,7 +687,8 @@ async def current_admin_user(user: User | None = Depends(current_user)) -> User
return None
if not user or not hasattr(user, "role") or user.role != UserRole.ADMIN:
raise BasicAuthenticationError(
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied. User must be an admin to perform this action.",
)
@@ -876,22 +881,3 @@ def get_oauth_router(
return redirect_response
return router
def api_key_dep(
request: Request, db_session: Session = Depends(get_session)
) -> User | None:
if AUTH_TYPE == AuthType.DISABLED:
return None
hashed_api_key = get_hashed_api_key_from_request(request)
if not hashed_api_key:
raise HTTPException(status_code=401, detail="Missing API key")
if hashed_api_key:
user = fetch_user_for_api_key(hashed_api_key, db_session)
if user is None:
raise HTTPException(status_code=401, detail="Invalid API key")
return user

View File

@@ -3,7 +3,6 @@ import multiprocessing
import time
from typing import Any
import requests
import sentry_sdk
from celery import Task
from celery.app import trace
@@ -12,22 +11,18 @@ from celery.states import READY_STATES
from celery.utils.log import get_task_logger
from celery.worker import strategy # type: ignore
from sentry_sdk.integrations.celery import CeleryIntegration
from sqlalchemy import text
from sqlalchemy.orm import Session
from danswer.background.celery.apps.task_formatters import CeleryTaskColoredFormatter
from danswer.background.celery.apps.task_formatters import CeleryTaskPlainFormatter
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_redis import RedisDocumentSet
from danswer.background.celery.celery_redis import RedisUserGroup
from danswer.background.celery.celery_utils import celery_is_worker_primary
from danswer.configs.constants import DanswerRedisLocks
from danswer.db.engine import get_sqlalchemy_engine
from danswer.document_index.vespa_constants import VESPA_CONFIG_SERVER_URL
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
from danswer.redis.redis_connector_delete import RedisConnectorDelete
from danswer.redis.redis_connector_prune import RedisConnectorPrune
from danswer.redis.redis_document_set import RedisDocumentSet
from danswer.db.engine import get_all_tenant_ids
from danswer.redis.redis_pool import get_redis_client
from danswer.redis.redis_usergroup import RedisUserGroup
from danswer.utils.logger import ColoredFormatter
from danswer.utils.logger import PlainFormatter
from danswer.utils.logger import setup_logger
@@ -113,27 +108,29 @@ def on_task_postrun(
if task_id.startswith(RedisDocumentSet.PREFIX):
document_set_id = RedisDocumentSet.get_id_from_task_id(task_id)
if document_set_id is not None:
rds = RedisDocumentSet(tenant_id, int(document_set_id))
rds = RedisDocumentSet(int(document_set_id))
r.srem(rds.taskset_key, task_id)
return
if task_id.startswith(RedisUserGroup.PREFIX):
usergroup_id = RedisUserGroup.get_id_from_task_id(task_id)
if usergroup_id is not None:
rug = RedisUserGroup(tenant_id, int(usergroup_id))
rug = RedisUserGroup(int(usergroup_id))
r.srem(rug.taskset_key, task_id)
return
if task_id.startswith(RedisConnectorDelete.PREFIX):
cc_pair_id = RedisConnector.get_id_from_task_id(task_id)
if task_id.startswith(RedisConnectorDeletion.PREFIX):
cc_pair_id = RedisConnectorDeletion.get_id_from_task_id(task_id)
if cc_pair_id is not None:
RedisConnectorDelete.remove_from_taskset(int(cc_pair_id), task_id, r)
rcd = RedisConnectorDeletion(int(cc_pair_id))
r.srem(rcd.taskset_key, task_id)
return
if task_id.startswith(RedisConnectorPrune.SUBTASK_PREFIX):
cc_pair_id = RedisConnector.get_id_from_task_id(task_id)
if task_id.startswith(RedisConnectorPruning.SUBTASK_PREFIX):
cc_pair_id = RedisConnectorPruning.get_id_from_task_id(task_id)
if cc_pair_id is not None:
RedisConnectorPrune.remove_from_taskset(int(cc_pair_id), task_id, r)
rcp = RedisConnectorPruning(int(cc_pair_id))
r.srem(rcp.taskset_key, task_id)
return
@@ -143,154 +140,27 @@ def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None
def wait_for_redis(sender: Any, **kwargs: Any) -> None:
"""Waits for redis to become ready subject to a hardcoded timeout.
Will raise WorkerShutdown to kill the celery worker if the timeout is reached."""
r = get_redis_client(tenant_id=None)
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
ready = False
time_start = time.monotonic()
logger.info("Redis: Readiness probe starting.")
logger.info("Redis: Readiness check starting.")
while True:
try:
if r.ping():
ready = True
break
except Exception:
pass
time_elapsed = time.monotonic() - time_start
if time_elapsed > WAIT_LIMIT:
break
logger.info(
f"Redis: Readiness probe ongoing. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
)
time.sleep(WAIT_INTERVAL)
if not ready:
msg = (
f"Redis: Readiness probe did not succeed within the timeout "
f"({WAIT_LIMIT} seconds). Exiting..."
)
logger.error(msg)
raise WorkerShutdown(msg)
logger.info("Redis: Readiness probe succeeded. Continuing...")
return
def wait_for_db(sender: Any, **kwargs: Any) -> None:
"""Waits for the db to become ready subject to a hardcoded timeout.
Will raise WorkerShutdown to kill the celery worker if the timeout is reached."""
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
ready = False
time_start = time.monotonic()
logger.info("Database: Readiness probe starting.")
while True:
try:
with Session(get_sqlalchemy_engine()) as db_session:
result = db_session.execute(text("SELECT NOW()")).scalar()
if result:
ready = True
break
except Exception:
pass
time_elapsed = time.monotonic() - time_start
if time_elapsed > WAIT_LIMIT:
break
logger.info(
f"Database: Readiness probe ongoing. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
)
time.sleep(WAIT_INTERVAL)
if not ready:
msg = (
f"Database: Readiness probe did not succeed within the timeout "
f"({WAIT_LIMIT} seconds). Exiting..."
)
logger.error(msg)
raise WorkerShutdown(msg)
logger.info("Database: Readiness probe succeeded. Continuing...")
return
def wait_for_vespa(sender: Any, **kwargs: Any) -> None:
"""Waits for Vespa to become ready subject to a hardcoded timeout.
Will raise WorkerShutdown to kill the celery worker if the timeout is reached."""
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
ready = False
time_start = time.monotonic()
logger.info("Vespa: Readiness probe starting.")
while True:
try:
response = requests.get(f"{VESPA_CONFIG_SERVER_URL}/state/v1/health")
response.raise_for_status()
response_dict = response.json()
if response_dict["status"]["code"] == "up":
ready = True
break
except Exception:
pass
time_elapsed = time.monotonic() - time_start
if time_elapsed > WAIT_LIMIT:
break
logger.info(
f"Vespa: Readiness probe ongoing. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
)
time.sleep(WAIT_INTERVAL)
if not ready:
msg = (
f"Vespa: Readiness probe did not succeed within the timeout "
f"({WAIT_LIMIT} seconds). Exiting..."
)
logger.error(msg)
raise WorkerShutdown(msg)
logger.info("Vespa: Readiness probe succeeded. Continuing...")
return
def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
logger.info("Running as a secondary celery worker.")
# Set up variables for waiting on primary worker
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
r = get_redis_client(tenant_id=None)
time_start = time.monotonic()
logger.info("Waiting for primary worker to be ready...")
while True:
if r.exists(DanswerRedisLocks.PRIMARY_WORKER):
break
time_elapsed = time.monotonic() - time_start
logger.info(
f"Primary worker is not ready yet. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
f"Redis: Ping failed. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
)
if time_elapsed > WAIT_LIMIT:
msg = (
f"Primary worker was not ready within the timeout. "
f"Redis: Readiness check did not succeed within the timeout "
f"({WAIT_LIMIT} seconds). Exiting..."
)
logger.error(msg)
@@ -298,7 +168,57 @@ def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
time.sleep(WAIT_INTERVAL)
logger.info("Wait for primary worker completed successfully. Continuing...")
logger.info("Redis: Readiness check succeeded. Continuing...")
return
def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
logger.info("Running as a secondary celery worker.")
logger.info("Waiting for all tenant primary workers to be ready...")
time_start = time.monotonic()
while True:
tenant_ids = get_all_tenant_ids()
# Check if we have a primary worker lock for each tenant
all_tenants_ready = all(
get_redis_client(tenant_id=tenant_id).exists(
DanswerRedisLocks.PRIMARY_WORKER
)
for tenant_id in tenant_ids
)
if all_tenants_ready:
break
time_elapsed = time.monotonic() - time_start
ready_tenants = sum(
1
for tenant_id in tenant_ids
if get_redis_client(tenant_id=tenant_id).exists(
DanswerRedisLocks.PRIMARY_WORKER
)
)
logger.info(
f"Not all tenant primary workers are ready yet. "
f"Ready tenants: {ready_tenants}/{len(tenant_ids)} "
f"elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
)
if time_elapsed > WAIT_LIMIT:
msg = (
f"Not all tenant primary workers were ready within the timeout "
f"({WAIT_LIMIT} seconds). Exiting..."
)
logger.error(msg)
raise WorkerShutdown(msg)
time.sleep(WAIT_INTERVAL)
logger.info("All tenant primary workers are ready. Continuing...")
return
@@ -310,20 +230,26 @@ def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
if not celery_is_worker_primary(sender):
return
if not sender.primary_worker_lock:
if not hasattr(sender, "primary_worker_locks"):
return
logger.info("Releasing primary worker lock.")
lock = sender.primary_worker_lock
try:
if lock.owned():
try:
lock.release()
sender.primary_worker_lock = None
except Exception as e:
logger.error(f"Failed to release primary worker lock: {e}")
except Exception as e:
logger.error(f"Failed to check if primary worker lock is owned: {e}")
for tenant_id, lock in sender.primary_worker_locks.items():
try:
if lock and lock.owned():
logger.debug(f"Attempting to release lock for tenant {tenant_id}")
try:
lock.release()
logger.debug(f"Successfully released lock for tenant {tenant_id}")
except Exception as e:
logger.error(
f"Failed to release lock for tenant {tenant_id}. Error: {str(e)}"
)
finally:
sender.primary_worker_locks[tenant_id] = None
except Exception as e:
logger.error(
f"Error checking lock status for tenant {tenant_id}. Error: {str(e)}"
)
def on_setup_logging(

View File

@@ -3,152 +3,28 @@ from typing import Any
from celery import Celery
from celery import signals
from celery.beat import PersistentScheduler # type: ignore
from celery.signals import beat_init
import danswer.background.celery.apps.app_base as app_base
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME
from danswer.db.engine import get_all_tenant_ids
from danswer.db.engine import SqlEngine
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import MULTI_TENANT
logger = setup_logger(__name__)
logger = setup_logger()
celery_app = Celery(__name__)
celery_app.config_from_object("danswer.background.celery.configs.beat")
class DynamicTenantScheduler(PersistentScheduler):
def __init__(self, *args: Any, **kwargs: Any) -> None:
logger.info("Initializing DynamicTenantScheduler")
super().__init__(*args, **kwargs)
self._reload_interval = timedelta(minutes=2)
self._last_reload = self.app.now() - self._reload_interval
# Let the parent class handle store initialization
self.setup_schedule()
self._update_tenant_tasks()
logger.info(f"Set reload interval to {self._reload_interval}")
def setup_schedule(self) -> None:
logger.info("Setting up initial schedule")
super().setup_schedule()
logger.info("Initial schedule setup complete")
def tick(self) -> float:
retval = super().tick()
now = self.app.now()
if (
self._last_reload is None
or (now - self._last_reload) > self._reload_interval
):
logger.info("Reload interval reached, initiating tenant task update")
self._update_tenant_tasks()
self._last_reload = now
logger.info("Tenant task update completed, reset reload timer")
return retval
def _update_tenant_tasks(self) -> None:
logger.info("Starting tenant task update process")
try:
logger.info("Fetching all tenant IDs")
tenant_ids = get_all_tenant_ids()
logger.info(f"Found {len(tenant_ids)} tenants")
logger.info("Fetching tasks to schedule")
tasks_to_schedule = fetch_versioned_implementation(
"danswer.background.celery.tasks.beat_schedule", "get_tasks_to_schedule"
)
new_beat_schedule: dict[str, dict[str, Any]] = {}
current_schedule = self.schedule.items()
existing_tenants = set()
for task_name, _ in current_schedule:
if "-" in task_name:
existing_tenants.add(task_name.split("-")[-1])
logger.info(f"Found {len(existing_tenants)} existing tenants in schedule")
for tenant_id in tenant_ids:
if tenant_id not in existing_tenants:
logger.info(f"Processing new tenant: {tenant_id}")
for task in tasks_to_schedule():
task_name = f"{task['name']}-{tenant_id}"
logger.debug(f"Creating task configuration for {task_name}")
new_task = {
"task": task["task"],
"schedule": task["schedule"],
"kwargs": {"tenant_id": tenant_id},
}
if options := task.get("options"):
logger.debug(f"Adding options to task {task_name}: {options}")
new_task["options"] = options
new_beat_schedule[task_name] = new_task
if self._should_update_schedule(current_schedule, new_beat_schedule):
logger.info(
"Schedule update required",
extra={
"new_tasks": len(new_beat_schedule),
"current_tasks": len(current_schedule),
},
)
# Create schedule entries
entries = {}
for name, entry in new_beat_schedule.items():
entries[name] = self.Entry(
name=name,
app=self.app,
task=entry["task"],
schedule=entry["schedule"],
options=entry.get("options", {}),
kwargs=entry.get("kwargs", {}),
)
# Update the schedule using the scheduler's methods
self.schedule.clear()
self.schedule.update(entries)
# Ensure changes are persisted
self.sync()
logger.info("Schedule update completed successfully")
else:
logger.info("Schedule is up to date, no changes needed")
except (AttributeError, KeyError):
logger.exception("Failed to process task configuration")
except Exception:
logger.exception("Unexpected error updating tenant tasks")
def _should_update_schedule(
self, current_schedule: dict, new_schedule: dict
) -> bool:
"""Compare schedules to determine if an update is needed."""
logger.debug("Comparing current and new schedules")
current_tasks = set(name for name, _ in current_schedule)
new_tasks = set(new_schedule.keys())
needs_update = current_tasks != new_tasks
logger.debug(f"Schedule update needed: {needs_update}")
return needs_update
@beat_init.connect
def on_beat_init(sender: Any, **kwargs: Any) -> None:
logger.info("beat_init signal received.")
# Celery beat shouldn't touch the db at all. But just setting a low minimum here.
# celery beat shouldn't touch the db at all. But just setting a low minimum here.
SqlEngine.set_app_name(POSTGRES_CELERY_BEAT_APP_NAME)
SqlEngine.init_engine(pool_size=2, max_overflow=0)
# Startup checks are not needed in multi-tenant case
if MULTI_TENANT:
return
app_base.wait_for_redis(sender, **kwargs)
@@ -159,4 +35,68 @@ def on_setup_logging(
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
celery_app.conf.beat_scheduler = DynamicTenantScheduler
#####
# Celery Beat (Periodic Tasks) Settings
#####
tenant_ids = get_all_tenant_ids()
tasks_to_schedule = [
{
"name": "check-for-vespa-sync",
"task": "check_for_vespa_sync_task",
"schedule": timedelta(seconds=5),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-connector-deletion",
"task": "check_for_connector_deletion_task",
"schedule": timedelta(seconds=60),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-indexing",
"task": "check_for_indexing",
"schedule": timedelta(seconds=10),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-prune",
"task": "check_for_pruning",
"schedule": timedelta(seconds=10),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "kombu-message-cleanup",
"task": "kombu_message_cleanup_task",
"schedule": timedelta(seconds=3600),
"options": {"priority": DanswerCeleryPriority.LOWEST},
},
{
"name": "monitor-vespa-sync",
"task": "monitor_vespa_sync",
"schedule": timedelta(seconds=5),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
]
# Build the celery beat schedule dynamically
beat_schedule = {}
for tenant_id in tenant_ids:
for task in tasks_to_schedule:
task_name = f"{task['name']}-{tenant_id}" # Unique name for each scheduled task
beat_schedule[task_name] = {
"task": task["task"],
"schedule": task["schedule"],
"options": task["options"],
"kwargs": {"tenant_id": tenant_id}, # Must pass tenant_id as an argument
}
# Include any existing beat schedules
existing_beat_schedule = celery_app.conf.beat_schedule or {}
beat_schedule.update(existing_beat_schedule)
# Update the Celery app configuration once
celery_app.conf.beat_schedule = beat_schedule

View File

@@ -13,7 +13,6 @@ import danswer.background.celery.apps.app_base as app_base
from danswer.configs.constants import POSTGRES_CELERY_WORKER_HEAVY_APP_NAME
from danswer.db.engine import SqlEngine
from danswer.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
@@ -61,13 +60,7 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_HEAVY_APP_NAME)
SqlEngine.init_engine(pool_size=4, max_overflow=12)
# Startup checks are not needed in multi-tenant case
if MULTI_TENANT:
return
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa(sender, **kwargs)
app_base.on_secondary_worker_init(sender, **kwargs)

View File

@@ -13,7 +13,6 @@ import danswer.background.celery.apps.app_base as app_base
from danswer.configs.constants import POSTGRES_CELERY_WORKER_INDEXING_APP_NAME
from danswer.db.engine import SqlEngine
from danswer.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
@@ -61,13 +60,7 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME)
SqlEngine.init_engine(pool_size=8, max_overflow=0)
# Startup checks are not needed in multi-tenant case
if MULTI_TENANT:
return
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa(sender, **kwargs)
app_base.on_secondary_worker_init(sender, **kwargs)

View File

@@ -13,7 +13,6 @@ import danswer.background.celery.apps.app_base as app_base
from danswer.configs.constants import POSTGRES_CELERY_WORKER_LIGHT_APP_NAME
from danswer.db.engine import SqlEngine
from danswer.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
@@ -60,13 +59,8 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_LIGHT_APP_NAME)
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8)
# Startup checks are not needed in multi-tenant case
if MULTI_TENANT:
return
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa(sender, **kwargs)
app_base.on_secondary_worker_init(sender, **kwargs)
@@ -91,6 +85,5 @@ celery_app.autodiscover_tasks(
[
"danswer.background.celery.tasks.shared",
"danswer.background.celery.tasks.vespa",
"danswer.background.celery.tasks.connector_deletion",
]
)

View File

@@ -13,21 +13,21 @@ from celery.signals import worker_shutdown
import danswer.background.celery.apps.app_base as app_base
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_redis import RedisConnectorIndexing
from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_redis import RedisConnectorStop
from danswer.background.celery.celery_redis import RedisDocumentSet
from danswer.background.celery.celery_redis import RedisUserGroup
from danswer.background.celery.celery_utils import celery_is_worker_primary
from danswer.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
from danswer.configs.constants import DanswerRedisLocks
from danswer.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME
from danswer.db.engine import get_all_tenant_ids
from danswer.db.engine import SqlEngine
from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
from danswer.redis.redis_connector_delete import RedisConnectorDelete
from danswer.redis.redis_connector_index import RedisConnectorIndex
from danswer.redis.redis_connector_prune import RedisConnectorPrune
from danswer.redis.redis_connector_stop import RedisConnectorStop
from danswer.redis.redis_document_set import RedisDocumentSet
from danswer.redis.redis_pool import get_redis_client
from danswer.redis.redis_usergroup import RedisUserGroup
from danswer.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
@@ -75,64 +75,95 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME)
SqlEngine.init_engine(pool_size=8, max_overflow=0)
# Startup checks are not needed in multi-tenant case
if MULTI_TENANT:
return
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa(sender, **kwargs)
logger.info("Running as the primary celery worker.")
sender.primary_worker_locks = {}
# This is singleton work that should be done on startup exactly once
# by the primary worker. This is unnecessary in the multi tenant scenario
r = get_redis_client(tenant_id=None)
# by the primary worker
tenant_ids = get_all_tenant_ids()
for tenant_id in tenant_ids:
r = get_redis_client(tenant_id=tenant_id)
# For the moment, we're assuming that we are the only primary worker
# that should be running.
# TODO: maybe check for or clean up another zombie primary worker if we detect it
r.delete(DanswerRedisLocks.PRIMARY_WORKER)
# For the moment, we're assuming that we are the only primary worker
# that should be running.
# TODO: maybe check for or clean up another zombie primary worker if we detect it
r.delete(DanswerRedisLocks.PRIMARY_WORKER)
# this process wide lock is taken to help other workers start up in order.
# it is planned to use this lock to enforce singleton behavior on the primary
# worker, since the primary worker does redis cleanup on startup, but this isn't
# implemented yet.
lock = r.lock(
DanswerRedisLocks.PRIMARY_WORKER,
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
)
# this process wide lock is taken to help other workers start up in order.
# it is planned to use this lock to enforce singleton behavior on the primary
# worker, since the primary worker does redis cleanup on startup, but this isn't
# implemented yet.
lock = r.lock(
DanswerRedisLocks.PRIMARY_WORKER,
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
)
logger.info("Primary worker lock: Acquire starting.")
acquired = lock.acquire(blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2)
if acquired:
logger.info("Primary worker lock: Acquire succeeded.")
else:
logger.error("Primary worker lock: Acquire failed!")
raise WorkerShutdown("Primary worker lock could not be acquired!")
logger.info("Primary worker lock: Acquire starting.")
acquired = lock.acquire(blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2)
if acquired:
logger.info("Primary worker lock: Acquire succeeded.")
else:
logger.error("Primary worker lock: Acquire failed!")
raise WorkerShutdown("Primary worker lock could not be acquired!")
# tacking on our own user data to the sender
sender.primary_worker_lock = lock
# tacking on our own user data to the sender
sender.primary_worker_locks[tenant_id] = lock
# As currently designed, when this worker starts as "primary", we reinitialize redis
# to a clean state (for our purposes, anyway)
r.delete(DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK)
r.delete(DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
# As currently designed, when this worker starts as "primary", we reinitialize redis
# to a clean state (for our purposes, anyway)
r.delete(DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK)
r.delete(DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
r.delete(RedisConnectorCredentialPair.get_taskset_key())
r.delete(RedisConnectorCredentialPair.get_fence_key())
r.delete(RedisConnectorCredentialPair.get_taskset_key())
r.delete(RedisConnectorCredentialPair.get_fence_key())
RedisDocumentSet.reset_all(r)
for key in r.scan_iter(RedisDocumentSet.TASKSET_PREFIX + "*"):
r.delete(key)
RedisUserGroup.reset_all(r)
for key in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
r.delete(key)
RedisConnectorDelete.reset_all(r)
for key in r.scan_iter(RedisUserGroup.TASKSET_PREFIX + "*"):
r.delete(key)
RedisConnectorPrune.reset_all(r)
for key in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
r.delete(key)
RedisConnectorIndex.reset_all(r)
for key in r.scan_iter(RedisConnectorDeletion.TASKSET_PREFIX + "*"):
r.delete(key)
RedisConnectorStop.reset_all(r)
for key in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.GENERATOR_COMPLETE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.GENERATOR_PROGRESS_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndexing.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_COMPLETE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_PROGRESS_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndexing.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorStop.FENCE_PREFIX + "*"):
r.delete(key)
@worker_ready.connect
@@ -185,36 +216,52 @@ class HubPeriodicTask(bootsteps.StartStopStep):
if not celery_is_worker_primary(worker):
return
if not hasattr(worker, "primary_worker_lock"):
if not hasattr(worker, "primary_worker_locks"):
return
lock = worker.primary_worker_lock
# Retrieve all tenant IDs
tenant_ids = get_all_tenant_ids()
r = get_redis_client(tenant_id=None)
for tenant_id in tenant_ids:
lock = worker.primary_worker_locks.get(tenant_id)
if not lock:
continue # Skip if no lock for this tenant
if lock.owned():
task_logger.debug("Reacquiring primary worker lock.")
lock.reacquire()
else:
task_logger.warning(
"Full acquisition of primary worker lock. "
"Reasons could be worker restart or lock expiration."
)
lock = r.lock(
DanswerRedisLocks.PRIMARY_WORKER,
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
)
r = get_redis_client(tenant_id=tenant_id)
task_logger.info("Primary worker lock: Acquire starting.")
acquired = lock.acquire(
blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2
)
if acquired:
task_logger.info("Primary worker lock: Acquire succeeded.")
worker.primary_worker_lock = lock
if lock.owned():
task_logger.debug(
f"Reacquiring primary worker lock for tenant {tenant_id}."
)
lock.reacquire()
else:
task_logger.error("Primary worker lock: Acquire failed!")
raise TimeoutError("Primary worker lock could not be acquired!")
task_logger.warning(
f"Full acquisition of primary worker lock for tenant {tenant_id}. "
"Reasons could be worker restart or lock expiration."
)
lock = r.lock(
DanswerRedisLocks.PRIMARY_WORKER,
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
)
task_logger.info(
f"Primary worker lock for tenant {tenant_id}: Acquire starting."
)
acquired = lock.acquire(
blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2
)
if acquired:
task_logger.info(
f"Primary worker lock for tenant {tenant_id}: Acquire succeeded."
)
worker.primary_worker_locks[tenant_id] = lock
else:
task_logger.error(
f"Primary worker lock for tenant {tenant_id}: Acquire failed!"
)
raise TimeoutError(
f"Primary worker lock for tenant {tenant_id} could not be acquired!"
)
except Exception:
task_logger.exception("Periodic task failed.")

View File

@@ -1,96 +0,0 @@
from datetime import timedelta
from typing import Any
from celery.beat import PersistentScheduler # type: ignore
from celery.utils.log import get_task_logger
from danswer.db.engine import get_all_tenant_ids
from danswer.utils.variable_functionality import fetch_versioned_implementation
logger = get_task_logger(__name__)
class DynamicTenantScheduler(PersistentScheduler):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self._reload_interval = timedelta(minutes=1)
self._last_reload = self.app.now() - self._reload_interval
def setup_schedule(self) -> None:
super().setup_schedule()
def tick(self) -> float:
retval = super().tick()
now = self.app.now()
if (
self._last_reload is None
or (now - self._last_reload) > self._reload_interval
):
logger.info("Reloading schedule to check for new tenants...")
self._update_tenant_tasks()
self._last_reload = now
return retval
def _update_tenant_tasks(self) -> None:
logger.info("Checking for tenant task updates...")
try:
tenant_ids = get_all_tenant_ids()
tasks_to_schedule = fetch_versioned_implementation(
"danswer.background.celery.tasks.beat_schedule", "get_tasks_to_schedule"
)
new_beat_schedule: dict[str, dict[str, Any]] = {}
current_schedule = getattr(self, "_store", {"entries": {}}).get(
"entries", {}
)
existing_tenants = set()
for task_name in current_schedule.keys():
if "-" in task_name:
existing_tenants.add(task_name.split("-")[-1])
for tenant_id in tenant_ids:
if tenant_id not in existing_tenants:
logger.info(f"Found new tenant: {tenant_id}")
for task in tasks_to_schedule():
task_name = f"{task['name']}-{tenant_id}"
new_task = {
"task": task["task"],
"schedule": task["schedule"],
"kwargs": {"tenant_id": tenant_id},
}
if options := task.get("options"):
new_task["options"] = options
new_beat_schedule[task_name] = new_task
if self._should_update_schedule(current_schedule, new_beat_schedule):
logger.info(
"Updating schedule",
extra={
"new_tasks": len(new_beat_schedule),
"current_tasks": len(current_schedule),
},
)
if not hasattr(self, "_store"):
self._store: dict[str, dict] = {"entries": {}}
self.update_from_dict(new_beat_schedule)
logger.info(f"New schedule: {new_beat_schedule}")
logger.info("Tenant tasks updated successfully")
else:
logger.debug("No schedule updates needed")
except (AttributeError, KeyError):
logger.exception("Failed to process task configuration")
except Exception:
logger.exception("Unexpected error updating tenant tasks")
def _should_update_schedule(
self, current_schedule: dict, new_schedule: dict
) -> bool:
"""Compare schedules to determine if an update is needed."""
current_tasks = set(current_schedule.keys())
new_tasks = set(new_schedule.keys())
return current_tasks != new_tasks

View File

@@ -1,10 +1,568 @@
# These are helper objects for tracking the keys we need to write in redis
import time
from abc import ABC
from abc import abstractmethod
from typing import cast
from uuid import uuid4
import redis
from celery import Celery
from redis import Redis
from sqlalchemy.orm import Session
from danswer.background.celery.configs.base import CELERY_SEPARATOR
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.document import construct_document_select_for_connector_credential_pair
from danswer.db.document import (
construct_document_select_for_connector_credential_pair_by_needs_sync,
)
from danswer.db.document_set import construct_document_select_by_docset
from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import global_version
class RedisObjectHelper(ABC):
PREFIX = "base"
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def __init__(self, id: str):
self._id: str = id
@property
def task_id_prefix(self) -> str:
return f"{self.PREFIX}_{self._id}"
@property
def fence_key(self) -> str:
# example: documentset_fence_1
return f"{self.FENCE_PREFIX}_{self._id}"
@property
def taskset_key(self) -> str:
# example: documentset_taskset_1
return f"{self.TASKSET_PREFIX}_{self._id}"
@staticmethod
def get_id_from_fence_key(key: str) -> str | None:
"""
Extracts the object ID from a fence key in the format `PREFIX_fence_X`.
Args:
key (str): The fence key string.
Returns:
Optional[int]: The extracted ID if the key is in the correct format, otherwise None.
"""
parts = key.split("_")
if len(parts) != 3:
return None
object_id = parts[2]
return object_id
@staticmethod
def get_id_from_task_id(task_id: str) -> str | None:
"""
Extracts the object ID from a task ID string.
This method assumes the task ID is formatted as `prefix_objectid_suffix`, where:
- `prefix` is an arbitrary string (e.g., the name of the task or entity),
- `objectid` is the ID you want to extract,
- `suffix` is another arbitrary string (e.g., a UUID).
Example:
If the input `task_id` is `documentset_1_cbfdc96a-80ca-4312-a242-0bb68da3c1dc`,
this method will return the string `"1"`.
Args:
task_id (str): The task ID string from which to extract the object ID.
Returns:
str | None: The extracted object ID if the task ID is in the correct format, otherwise None.
"""
# example: task_id=documentset_1_cbfdc96a-80ca-4312-a242-0bb68da3c1dc
parts = task_id.split("_")
if len(parts) != 3:
return None
object_id = parts[1]
return object_id
@abstractmethod
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
pass
class RedisDocumentSet(RedisObjectHelper):
PREFIX = "documentset"
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def __init__(self, id: int) -> None:
super().__init__(str(id))
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
last_lock_time = time.monotonic()
async_results = []
stmt = construct_document_select_by_docset(int(self._id), current_only=False)
for doc in db_session.scalars(stmt).yield_per(1):
current_time = time.monotonic()
if current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
):
lock.reacquire()
last_lock_time = current_time
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
# add to the set BEFORE creating the task.
redis_client.sadd(self.taskset_key, custom_task_id)
result = celery_app.send_task(
"vespa_metadata_sync_task",
kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
task_id=custom_task_id,
priority=DanswerCeleryPriority.LOW,
)
async_results.append(result)
return len(async_results)
class RedisUserGroup(RedisObjectHelper):
PREFIX = "usergroup"
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def __init__(self, id: int) -> None:
super().__init__(str(id))
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
last_lock_time = time.monotonic()
async_results = []
if not global_version.is_ee_version():
return 0
try:
construct_document_select_by_usergroup = fetch_versioned_implementation(
"danswer.db.user_group",
"construct_document_select_by_usergroup",
)
except ModuleNotFoundError:
return 0
stmt = construct_document_select_by_usergroup(int(self._id))
for doc in db_session.scalars(stmt).yield_per(1):
current_time = time.monotonic()
if current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
):
lock.reacquire()
last_lock_time = current_time
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
# add to the set BEFORE creating the task.
redis_client.sadd(self.taskset_key, custom_task_id)
result = celery_app.send_task(
"vespa_metadata_sync_task",
kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
task_id=custom_task_id,
priority=DanswerCeleryPriority.LOW,
)
async_results.append(result)
return len(async_results)
class RedisConnectorCredentialPair(RedisObjectHelper):
"""This class is used to scan documents by cc_pair in the db and collect them into
a unified set for syncing.
It differs from the other redis helpers in that the taskset used spans
all connectors and is not per connector."""
PREFIX = "connectorsync"
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def __init__(self, id: int) -> None:
super().__init__(str(id))
@classmethod
def get_fence_key(cls) -> str:
return RedisConnectorCredentialPair.FENCE_PREFIX
@classmethod
def get_taskset_key(cls) -> str:
return RedisConnectorCredentialPair.TASKSET_PREFIX
@property
def taskset_key(self) -> str:
"""Notice that this is intentionally reusing the same taskset for all
connector syncs"""
# example: connector_taskset
return f"{self.TASKSET_PREFIX}"
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
last_lock_time = time.monotonic()
async_results = []
cc_pair = get_connector_credential_pair_from_id(int(self._id), db_session)
if not cc_pair:
return None
stmt = construct_document_select_for_connector_credential_pair_by_needs_sync(
cc_pair.connector_id, cc_pair.credential_id
)
for doc in db_session.scalars(stmt).yield_per(1):
current_time = time.monotonic()
if current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
):
lock.reacquire()
last_lock_time = current_time
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
# add to the tracking taskset in redis BEFORE creating the celery task.
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
redis_client.sadd(
RedisConnectorCredentialPair.get_taskset_key(), custom_task_id
)
# Priority on sync's triggered by new indexing should be medium
result = celery_app.send_task(
"vespa_metadata_sync_task",
kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
task_id=custom_task_id,
priority=DanswerCeleryPriority.MEDIUM,
)
async_results.append(result)
return len(async_results)
class RedisConnectorDeletion(RedisObjectHelper):
PREFIX = "connectordeletion"
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def __init__(self, id: int) -> None:
super().__init__(str(id))
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
"""Returns None if the cc_pair doesn't exist.
Otherwise, returns an int with the number of generated tasks."""
last_lock_time = time.monotonic()
async_results = []
cc_pair = get_connector_credential_pair_from_id(int(self._id), db_session)
if not cc_pair:
return None
stmt = construct_document_select_for_connector_credential_pair(
cc_pair.connector_id, cc_pair.credential_id
)
for doc in db_session.scalars(stmt).yield_per(1):
current_time = time.monotonic()
if current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
):
lock.reacquire()
last_lock_time = current_time
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
# add to the tracking taskset in redis BEFORE creating the celery task.
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
redis_client.sadd(self.taskset_key, custom_task_id)
# Priority on sync's triggered by new indexing should be medium
result = celery_app.send_task(
"document_by_cc_pair_cleanup_task",
kwargs=dict(
document_id=doc.id,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
tenant_id=tenant_id,
),
queue=DanswerCeleryQueues.CONNECTOR_DELETION,
task_id=custom_task_id,
priority=DanswerCeleryPriority.MEDIUM,
)
async_results.append(result)
return len(async_results)
class RedisConnectorPruning(RedisObjectHelper):
"""Celery will kick off a long running generator task to crawl the connector and
find any missing docs, which will each then get a new cleanup task. The progress of
those tasks will then be monitored to completion.
Example rough happy path order:
Check connectorpruning_fence_1
Send generator task with id connectorpruning+generator_1_{uuid}
generator runs connector with callbacks that increment connectorpruning_generator_progress_1
generator creates many subtasks with id connectorpruning+sub_1_{uuid}
in taskset connectorpruning_taskset_1
on completion, generator sets connectorpruning_generator_complete_1
celery postrun removes subtasks from taskset
monitor beat task cleans up when taskset reaches 0 items
"""
PREFIX = "connectorpruning"
FENCE_PREFIX = PREFIX + "_fence" # a fence for the entire pruning process
GENERATOR_TASK_PREFIX = PREFIX + "+generator"
TASKSET_PREFIX = PREFIX + "_taskset" # stores a list of prune tasks id's
SUBTASK_PREFIX = PREFIX + "+sub"
GENERATOR_PROGRESS_PREFIX = (
PREFIX + "_generator_progress"
) # a signal that contains generator progress
GENERATOR_COMPLETE_PREFIX = (
PREFIX + "_generator_complete"
) # a signal that the generator has finished
def __init__(self, id: int) -> None:
super().__init__(str(id))
self.documents_to_prune: set[str] = set()
@property
def generator_task_id_prefix(self) -> str:
return f"{self.GENERATOR_TASK_PREFIX}_{self._id}"
@property
def generator_progress_key(self) -> str:
# example: connectorpruning_generator_progress_1
return f"{self.GENERATOR_PROGRESS_PREFIX}_{self._id}"
@property
def generator_complete_key(self) -> str:
# example: connectorpruning_generator_complete_1
return f"{self.GENERATOR_COMPLETE_PREFIX}_{self._id}"
@property
def subtask_id_prefix(self) -> str:
return f"{self.SUBTASK_PREFIX}_{self._id}"
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock | None,
tenant_id: str | None,
) -> int | None:
last_lock_time = time.monotonic()
async_results = []
cc_pair = get_connector_credential_pair_from_id(int(self._id), db_session)
if not cc_pair:
return None
for doc_id in self.documents_to_prune:
current_time = time.monotonic()
if lock and current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
):
lock.reacquire()
last_lock_time = current_time
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
custom_task_id = f"{self.subtask_id_prefix}_{uuid4()}"
# add to the tracking taskset in redis BEFORE creating the celery task.
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
redis_client.sadd(self.taskset_key, custom_task_id)
# Priority on sync's triggered by new indexing should be medium
result = celery_app.send_task(
"document_by_cc_pair_cleanup_task",
kwargs=dict(
document_id=doc_id,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
tenant_id=tenant_id,
),
queue=DanswerCeleryQueues.CONNECTOR_DELETION,
task_id=custom_task_id,
priority=DanswerCeleryPriority.MEDIUM,
)
async_results.append(result)
return len(async_results)
def is_pruning(self, redis_client: Redis) -> bool:
"""A single example of a helper method being refactored into the redis helper"""
if redis_client.exists(self.fence_key):
return True
return False
class RedisConnectorIndexing(RedisObjectHelper):
"""Celery will kick off a long running indexing task to crawl the connector and
find any new or updated docs docs, which will each then get a new sync task or be
indexed inline.
ID should be a concatenation of cc_pair_id and search_setting_id, delimited by "/".
e.g. "2/5"
"""
PREFIX = "connectorindexing"
FENCE_PREFIX = PREFIX + "_fence" # a fence for the entire indexing process
GENERATOR_TASK_PREFIX = PREFIX + "+generator"
TASKSET_PREFIX = PREFIX + "_taskset" # stores a list of prune tasks id's
SUBTASK_PREFIX = PREFIX + "+sub"
GENERATOR_LOCK_PREFIX = "da_lock:indexing"
GENERATOR_PROGRESS_PREFIX = (
PREFIX + "_generator_progress"
) # a signal that contains generator progress
GENERATOR_COMPLETE_PREFIX = (
PREFIX + "_generator_complete"
) # a signal that the generator has finished
def __init__(self, cc_pair_id: int, search_settings_id: int) -> None:
super().__init__(f"{cc_pair_id}/{search_settings_id}")
@property
def generator_lock_key(self) -> str:
return f"{self.GENERATOR_LOCK_PREFIX}_{self._id}"
@property
def generator_task_id_prefix(self) -> str:
return f"{self.GENERATOR_TASK_PREFIX}_{self._id}"
@property
def generator_progress_key(self) -> str:
# example: connectorpruning_generator_progress_1
return f"{self.GENERATOR_PROGRESS_PREFIX}_{self._id}"
@property
def generator_complete_key(self) -> str:
# example: connectorpruning_generator_complete_1
return f"{self.GENERATOR_COMPLETE_PREFIX}_{self._id}"
@property
def subtask_id_prefix(self) -> str:
return f"{self.SUBTASK_PREFIX}_{self._id}"
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock | None,
tenant_id: str | None,
) -> int | None:
return None
def is_indexing(self, redis_client: Redis) -> bool:
"""A single example of a helper method being refactored into the redis helper"""
if redis_client.exists(self.fence_key):
return True
return False
class RedisConnectorStop(RedisObjectHelper):
"""Used to signal any running tasks for a connector to stop. We should refactor
connector related redis helpers into a single class.
"""
PREFIX = "connectorstop"
FENCE_PREFIX = PREFIX + "_fence" # a fence for the entire indexing process
TASKSET_PREFIX = PREFIX + "_taskset" # stores a list of prune tasks id's
def __init__(self, id: int) -> None:
super().__init__(str(id))
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock | None,
tenant_id: str | None,
) -> int | None:
return None
def celery_get_queue_length(queue: str, r: Redis) -> int:

View File

@@ -4,6 +4,7 @@ from typing import Any
from sqlalchemy.orm import Session
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.indexing.run_indexing import RunIndexingCallbackInterface
from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
@@ -17,7 +18,7 @@ from danswer.connectors.models import Document
from danswer.db.connector_credential_pair import get_connector_credential_pair
from danswer.db.enums import TaskStatus
from danswer.db.models import TaskQueueState
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_pool import get_redis_client
from danswer.server.documents.models import DeletionAttemptSnapshot
from danswer.utils.logger import setup_logger
@@ -40,14 +41,14 @@ def _get_deletion_status(
if not cc_pair:
return None
redis_connector = RedisConnector(tenant_id, cc_pair.id)
if not redis_connector.delete.fenced:
rcd = RedisConnectorDeletion(cc_pair.id)
r = get_redis_client(tenant_id=tenant_id)
if not r.exists(rcd.fence_key):
return None
return TaskQueueState(
task_id="",
task_name=redis_connector.delete.fence_key,
status=TaskStatus.STARTED,
task_id="", task_name=rcd.fence_key, status=TaskStatus.STARTED
)

View File

@@ -1,48 +0,0 @@
from datetime import timedelta
from typing import Any
from danswer.configs.constants import DanswerCeleryPriority
tasks_to_schedule = [
{
"name": "check-for-vespa-sync",
"task": "check_for_vespa_sync_task",
"schedule": timedelta(seconds=5),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-connector-deletion",
"task": "check_for_connector_deletion_task",
"schedule": timedelta(seconds=20),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-indexing",
"task": "check_for_indexing",
"schedule": timedelta(seconds=10),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-prune",
"task": "check_for_pruning",
"schedule": timedelta(seconds=10),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "kombu-message-cleanup",
"task": "kombu_message_cleanup_task",
"schedule": timedelta(seconds=3600),
"options": {"priority": DanswerCeleryPriority.LOWEST},
},
{
"name": "monitor-vespa-sync",
"task": "monitor_vespa_sync",
"schedule": timedelta(seconds=5),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
]
def get_tasks_to_schedule() -> list[dict[str, Any]]:
return tasks_to_schedule

View File

@@ -10,6 +10,13 @@ from redis import Redis
from sqlalchemy.orm import Session
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_redis import RedisConnectorIndexing
from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_redis import RedisConnectorStop
from danswer.background.celery.tasks.shared.RedisConnectorDeletionFenceData import (
RedisConnectorDeletionFenceData,
)
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerRedisLocks
@@ -18,8 +25,6 @@ from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.search_settings import get_all_search_settings
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_delete import RedisConnectorDeletionFenceData
from danswer.redis.redis_pool import get_redis_client
@@ -57,7 +62,7 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> N
# try running cleanup on the cc_pair_ids
for cc_pair_id in cc_pair_ids:
with get_session_with_tenant(tenant_id) as db_session:
redis_connector = RedisConnector(tenant_id, cc_pair_id)
rcs = RedisConnectorStop(cc_pair_id)
try:
try_generate_document_cc_pair_cleanup_tasks(
self.app, cc_pair_id, db_session, r, lock_beat, tenant_id
@@ -66,10 +71,10 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> N
# this means we wanted to start deleting but dependent tasks were running
# Leave a stop signal to clear indexing and pruning tasks more quickly
task_logger.info(str(e))
redis_connector.stop.set_fence(True)
r.set(rcs.fence_key, cc_pair_id)
else:
# clear the stop signal if it exists ... no longer needed
redis_connector.stop.set_fence(False)
r.delete(rcs.fence_key)
except SoftTimeLimitExceeded:
task_logger.info(
@@ -101,10 +106,10 @@ def try_generate_document_cc_pair_cleanup_tasks(
lock_beat.reacquire()
redis_connector = RedisConnector(tenant_id, cc_pair_id)
rcd = RedisConnectorDeletion(cc_pair_id)
# don't generate sync tasks if tasks are still pending
if redis_connector.delete.fenced:
if r.exists(rcd.fence_key):
return None
# we need to load the state of the object inside the fence
@@ -118,49 +123,47 @@ def try_generate_document_cc_pair_cleanup_tasks(
return None
# set a basic fence to start
fence_payload = RedisConnectorDeletionFenceData(
fence_value = RedisConnectorDeletionFenceData(
num_tasks=None,
submitted=datetime.now(timezone.utc),
)
redis_connector.delete.set_fence(fence_payload)
r.set(rcd.fence_key, fence_value.model_dump_json())
try:
# do not proceed if connector indexing or connector pruning are running
search_settings_list = get_all_search_settings(db_session)
for search_settings in search_settings_list:
redis_connector_index = redis_connector.new_index(search_settings.id)
if redis_connector_index.fenced:
rci = RedisConnectorIndexing(cc_pair_id, search_settings.id)
if r.get(rci.fence_key):
raise TaskDependencyError(
f"Connector deletion - Delayed (indexing in progress): "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings.id}"
)
if redis_connector.prune.fenced:
rcp = RedisConnectorPruning(cc_pair_id)
if r.get(rcp.fence_key):
raise TaskDependencyError(
f"Connector deletion - Delayed (pruning in progress): "
f"cc_pair={cc_pair_id}"
)
# add tasks to celery and build up the task set to monitor in redis
redis_connector.delete.taskset_clear()
r.delete(rcd.taskset_key)
# Add all documents that need to be updated into the queue
task_logger.info(
f"RedisConnectorDeletion.generate_tasks starting. cc_pair={cc_pair_id}"
)
tasks_generated = redis_connector.delete.generate_tasks(
app, db_session, lock_beat
)
tasks_generated = rcd.generate_tasks(app, db_session, r, lock_beat, tenant_id)
if tasks_generated is None:
raise ValueError("RedisConnectorDeletion.generate_tasks returned None")
except TaskDependencyError:
redis_connector.delete.set_fence(None)
r.delete(rcd.fence_key)
raise
except Exception:
task_logger.exception("Unexpected exception")
redis_connector.delete.set_fence(None)
r.delete(rcd.fence_key)
return None
else:
# Currently we are allowing the sync to proceed with 0 tasks.
@@ -175,7 +178,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
)
# set this only after all tasks have been added
fence_payload.num_tasks = tasks_generated
redis_connector.delete.set_fence(fence_payload)
fence_value.num_tasks = tasks_generated
r.set(rcd.fence_key, fence_value.model_dump_json())
return tasks_generated

View File

@@ -2,9 +2,10 @@ from datetime import datetime
from datetime import timezone
from http import HTTPStatus
from time import sleep
from typing import cast
from uuid import uuid4
import redis
import sentry_sdk
from celery import Celery
from celery import shared_task
from celery import Task
@@ -13,6 +14,12 @@ from redis import Redis
from sqlalchemy.orm import Session
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_redis import RedisConnectorIndexing
from danswer.background.celery.celery_redis import RedisConnectorStop
from danswer.background.celery.tasks.shared.RedisConnectorIndexingFenceData import (
RedisConnectorIndexingFenceData,
)
from danswer.background.indexing.job_client import SimpleJobClient
from danswer.background.indexing.run_indexing import run_indexing_entrypoint
from danswer.background.indexing.run_indexing import RunIndexingCallbackInterface
@@ -43,15 +50,12 @@ from danswer.db.search_settings import get_secondary_search_settings
from danswer.db.swap_index import check_index_swap
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_index import RedisConnectorIndexingFenceData
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import global_version
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
from shared_configs.configs import INDEXING_MODEL_SERVER_PORT
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import SENTRY_DSN
logger = setup_logger()
@@ -101,22 +105,19 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
return None
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
old_search_settings = check_index_swap(db_session=db_session)
check_index_swap(db_session=db_session)
current_search_settings = get_current_search_settings(db_session)
# So that the first time users aren't surprised by really slow speed of first
# batch of documents indexed
if current_search_settings.provider_type is None and not MULTI_TENANT:
if old_search_settings:
embedding_model = EmbeddingModel.from_db_model(
search_settings=current_search_settings,
server_host=INDEXING_MODEL_SERVER_HOST,
server_port=INDEXING_MODEL_SERVER_PORT,
)
# only warm up if search settings were changed
warm_up_bi_encoder(
embedding_model=embedding_model,
)
embedding_model = EmbeddingModel.from_db_model(
search_settings=current_search_settings,
server_host=INDEXING_MODEL_SERVER_HOST,
server_port=INDEXING_MODEL_SERVER_PORT,
)
warm_up_bi_encoder(
embedding_model=embedding_model,
)
cc_pair_ids: list[int] = []
with get_session_with_tenant(tenant_id) as db_session:
@@ -125,7 +126,6 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
cc_pair_ids.append(cc_pair_entry.id)
for cc_pair_id in cc_pair_ids:
redis_connector = RedisConnector(tenant_id, cc_pair_id)
with get_session_with_tenant(tenant_id) as db_session:
# Get the primary search settings
primary_search_settings = get_current_search_settings(db_session)
@@ -138,10 +138,10 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
search_settings.append(secondary_search_settings)
for search_settings_instance in search_settings:
redis_connector_index = redis_connector.new_index(
search_settings_instance.id
rci = RedisConnectorIndexing(
cc_pair_id, search_settings_instance.id
)
if redis_connector_index.fenced:
if r.exists(rci.fence_key):
continue
cc_pair = get_connector_credential_pair_from_id(
@@ -175,9 +175,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
)
if attempt_id:
task_logger.info(
f"Indexing queued: index_attempt={attempt_id} "
f"cc_pair={cc_pair.id} "
f"search_settings={search_settings_instance.id} "
f"Indexing queued: cc_pair={cc_pair.id} index_attempt={attempt_id}"
)
tasks_created += 1
except SoftTimeLimitExceeded:
@@ -306,15 +304,15 @@ def try_creating_indexing_task(
return None
try:
redis_connector = RedisConnector(tenant_id, cc_pair.id)
redis_connector_index = redis_connector.new_index(search_settings.id)
rci = RedisConnectorIndexing(cc_pair.id, search_settings.id)
# skip if already indexing
if redis_connector_index.fenced:
if r.exists(rci.fence_key):
return None
# skip indexing if the cc_pair is deleting
if redis_connector.delete.fenced:
rcd = RedisConnectorDeletion(cc_pair.id)
if r.exists(rcd.fence_key):
return None
db_session.refresh(cc_pair)
@@ -322,17 +320,19 @@ def try_creating_indexing_task(
return None
# add a long running generator task to the queue
redis_connector_index.generator_clear()
r.delete(rci.generator_complete_key)
r.delete(rci.taskset_key)
custom_task_id = f"{rci.generator_task_id_prefix}_{uuid4()}"
# set a basic fence to start
payload = RedisConnectorIndexingFenceData(
fence_value = RedisConnectorIndexingFenceData(
index_attempt_id=None,
started=None,
submitted=datetime.now(timezone.utc),
celery_task_id=None,
)
redis_connector_index.set_fence(payload)
r.set(rci.fence_key, fence_value.model_dump_json())
# create the index attempt for tracking purposes
# code elsewhere checks for index attempts without an associated redis key
@@ -345,8 +345,6 @@ def try_creating_indexing_task(
db_session=db_session,
)
custom_task_id = redis_connector_index.generate_generator_task_id()
result = celery_app.send_task(
"connector_indexing_proxy_task",
kwargs=dict(
@@ -363,12 +361,11 @@ def try_creating_indexing_task(
raise RuntimeError("send_task for connector_indexing_proxy_task failed.")
# now fill out the fence with the rest of the data
payload.index_attempt_id = index_attempt_id
payload.celery_task_id = result.id
redis_connector_index.set_fence(payload)
fence_value.index_attempt_id = index_attempt_id
fence_value.celery_task_id = result.id
r.set(rci.fence_key, fence_value.model_dump_json())
except Exception:
redis_connector_index.set_fence(payload)
r.delete(rci.fence_key)
task_logger.exception(
f"Unexpected exception: "
f"tenant={tenant_id} "
@@ -391,12 +388,7 @@ def connector_indexing_proxy_task(
tenant_id: str | None,
) -> None:
"""celery tasks are forked, but forking is unstable. This proxies work to a spawned task."""
task_logger.info(
f"Indexing proxy - starting: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
client = SimpleJobClient()
job = client.submit(
@@ -410,56 +402,29 @@ def connector_indexing_proxy_task(
)
if not job:
task_logger.info(
f"Indexing proxy - spawn failed: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
return
task_logger.info(
f"Indexing proxy - spawn succeeded: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
while True:
sleep(10)
with get_session_with_tenant(tenant_id) as db_session:
index_attempt = get_index_attempt(
db_session=db_session, index_attempt_id=index_attempt_id
)
# do nothing for ongoing jobs that haven't been stopped
if not job.done():
with get_session_with_tenant(tenant_id) as db_session:
index_attempt = get_index_attempt(
db_session=db_session, index_attempt_id=index_attempt_id
)
# do nothing for ongoing jobs that haven't been stopped
if not job.done():
if not index_attempt:
continue
if not index_attempt.is_finished():
continue
if job.status == "error":
task_logger.error(
f"Indexing proxy - spawned task exceptioned: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"error={job.exception()}"
)
if job.status == "error":
logger.error(job.exception())
job.release()
break
job.release()
break
task_logger.info(
f"Indexing proxy - finished: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
return
@@ -481,97 +446,78 @@ def connector_indexing_task(
Returns None if the task did not run (possibly due to a conflict).
Otherwise, returns an int >= 0 representing the number of indexed docs.
NOTE: if an exception is raised out of this task, the primary worker will detect
that the task transitioned to a "READY" state but the generator_complete_key doesn't exist.
This will cause the primary worker to abort the indexing attempt and clean up.
"""
# Since connector_indexing_proxy_task spawns a new process using this function as
# the entrypoint, we init Sentry here.
if SENTRY_DSN:
sentry_sdk.init(
dsn=SENTRY_DSN,
traces_sample_rate=0.1,
)
logger.info("Sentry initialized")
else:
logger.debug("Sentry DSN not provided, skipping Sentry initialization")
logger.info(
f"Indexing spawned task starting: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
attempt_found = False
n_final_progress: int | None = None
redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector_index = redis_connector.new_index(search_settings_id)
attempt = None
n_final_progress = 0
r = get_redis_client(tenant_id=tenant_id)
if redis_connector.delete.fenced:
rcd = RedisConnectorDeletion(cc_pair_id)
if r.exists(rcd.fence_key):
raise RuntimeError(
f"Indexing will not start because connector deletion is in progress: "
f"cc_pair={cc_pair_id} "
f"fence={redis_connector.delete.fence_key}"
f"fence={rcd.fence_key}"
)
if redis_connector.stop.fenced:
rcs = RedisConnectorStop(cc_pair_id)
if r.exists(rcs.fence_key):
raise RuntimeError(
f"Indexing will not start because a connector stop signal was detected: "
f"cc_pair={cc_pair_id} "
f"fence={redis_connector.stop.fence_key}"
f"fence={rcs.fence_key}"
)
rci = RedisConnectorIndexing(cc_pair_id, search_settings_id)
while True:
# wait for the fence to come up
if not redis_connector_index.fenced:
# read related data and evaluate/print task progress
fence_value = cast(bytes, r.get(rci.fence_key))
if fence_value is None:
raise ValueError(
f"connector_indexing_task - fence not found: fence={redis_connector_index.fence_key}"
f"connector_indexing_task: fence_value not found: fence={rci.fence_key}"
)
payload = redis_connector_index.payload
if not payload:
raise ValueError("connector_indexing_task: payload invalid or not found")
try:
fence_json = fence_value.decode("utf-8")
fence_data = RedisConnectorIndexingFenceData.model_validate_json(
cast(str, fence_json)
)
except ValueError:
task_logger.exception(
f"connector_indexing_task: fence_data not decodeable: fence={rci.fence_key}"
)
raise
if payload.index_attempt_id is None or payload.celery_task_id is None:
logger.info(
f"connector_indexing_task - Waiting for fence: fence={redis_connector_index.fence_key}"
if fence_data.index_attempt_id is None or fence_data.celery_task_id is None:
task_logger.info(
f"connector_indexing_task - Waiting for fence: fence={rci.fence_key}"
)
sleep(1)
continue
if payload.index_attempt_id != index_attempt_id:
raise ValueError(
f"connector_indexing_task - id mismatch. Task may be left over from previous run.: "
f"task_index_attempt={index_attempt_id} "
f"payload_index_attempt={payload.index_attempt_id}"
)
logger.info(
f"connector_indexing_task - Fence found, continuing...: fence={redis_connector_index.fence_key}"
task_logger.info(
f"connector_indexing_task - Fence found, continuing...: fence={rci.fence_key}"
)
break
lock = r.lock(
redis_connector_index.generator_lock_key,
rci.generator_lock_key,
timeout=CELERY_INDEXING_LOCK_TIMEOUT,
)
acquired = lock.acquire(blocking=False)
if not acquired:
logger.warning(
task_logger.warning(
f"Indexing task already running, exiting...: "
f"cc_pair={cc_pair_id} search_settings={search_settings_id}"
)
# r.set(rci.generator_complete_key, HTTPStatus.CONFLICT.value)
return None
payload.started = datetime.now(timezone.utc)
redis_connector_index.set_fence(payload)
fence_data.started = datetime.now(timezone.utc)
r.set(rci.fence_key, fence_data.model_dump_json())
try:
with get_session_with_tenant(tenant_id) as db_session:
@@ -580,7 +526,6 @@ def connector_indexing_task(
raise ValueError(
f"Index attempt not found: index_attempt={index_attempt_id}"
)
attempt_found = True
cc_pair = get_connector_credential_pair_from_id(
cc_pair_id=cc_pair_id,
@@ -600,52 +545,43 @@ def connector_indexing_task(
f"Credential not found: cc_pair={cc_pair_id} credential={cc_pair.credential_id}"
)
# define a callback class
callback = RunIndexingCallback(
redis_connector.stop.fence_key,
redis_connector_index.generator_progress_key,
lock,
r,
)
rci = RedisConnectorIndexing(cc_pair_id, search_settings_id)
logger.info(
f"Indexing spawned task running entrypoint: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
# define a callback class
callback = RunIndexingCallback(
rcs.fence_key, rci.generator_progress_key, lock, r
)
run_indexing_entrypoint(
index_attempt_id,
tenant_id,
cc_pair_id,
is_ee,
callback=callback,
)
run_indexing_entrypoint(
index_attempt_id,
tenant_id,
cc_pair_id,
is_ee,
callback=callback,
)
# get back the total number of indexed docs and return it
n_final_progress = redis_connector_index.get_progress()
redis_connector_index.set_generator_complete(HTTPStatus.OK.value)
# get back the total number of indexed docs and return it
generator_progress_value = r.get(rci.generator_progress_key)
if generator_progress_value is not None:
try:
n_final_progress = int(cast(int, generator_progress_value))
except ValueError:
pass
r.set(rci.generator_complete_key, HTTPStatus.OK.value)
except Exception as e:
logger.exception(
f"Indexing spawned task failed: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
if attempt_found:
task_logger.exception(f"Indexing failed: cc_pair={cc_pair_id}")
if attempt:
with get_session_with_tenant(tenant_id) as db_session:
mark_attempt_failed(index_attempt_id, db_session, failure_reason=str(e))
mark_attempt_failed(attempt, db_session, failure_reason=str(e))
r.delete(rci.generator_lock_key)
r.delete(rci.generator_progress_key)
r.delete(rci.taskset_key)
r.delete(rci.fence_key)
raise e
finally:
if lock.owned():
lock.release()
logger.info(
f"Indexing spawned task finished: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
return n_final_progress

View File

@@ -11,6 +11,9 @@ from redis import Redis
from sqlalchemy.orm import Session
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_redis import RedisConnectorStop
from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector
from danswer.background.celery.tasks.indexing.tasks import RunIndexingCallback
from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
@@ -30,7 +33,6 @@ from danswer.db.document import get_documents_for_connector_credential_pair
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import ConnectorCredentialPair
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import pruning_ctx
from danswer.utils.logger import setup_logger
@@ -145,11 +147,8 @@ def try_creating_prune_generator_task(
is used to trigger prunes immediately, e.g. via the web ui.
"""
redis_connector = RedisConnector(tenant_id, cc_pair.id)
if not ALLOW_SIMULTANEOUS_PRUNING:
count = redis_connector.prune.get_active_task_count()
if count > 0:
for key in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
return None
LOCK_TIMEOUT = 30
@@ -166,10 +165,15 @@ def try_creating_prune_generator_task(
return None
try:
if redis_connector.prune.fenced: # skip pruning if already pruning
rcp = RedisConnectorPruning(cc_pair.id)
# skip pruning if already pruning
if r.exists(rcp.fence_key):
return None
if redis_connector.delete.fenced: # skip pruning if the cc_pair is deleting
# skip pruning if the cc_pair is deleting
rcd = RedisConnectorDeletion(cc_pair.id)
if r.exists(rcd.fence_key):
return None
db_session.refresh(cc_pair)
@@ -177,10 +181,10 @@ def try_creating_prune_generator_task(
return None
# add a long running generator task to the queue
redis_connector.prune.generator_clear()
redis_connector.prune.taskset_clear()
r.delete(rcp.generator_complete_key)
r.delete(rcp.taskset_key)
custom_task_id = f"{redis_connector.prune.generator_task_key}_{uuid4()}"
custom_task_id = f"{rcp.generator_task_id_prefix}_{uuid4()}"
celery_app.send_task(
"connector_pruning_generator_task",
@@ -196,7 +200,7 @@ def try_creating_prune_generator_task(
)
# set this only after all tasks have been added
redis_connector.prune.set_fence(True)
r.set(rcp.fence_key, 1)
except Exception:
task_logger.exception(f"Unexpected exception: cc_pair={cc_pair.id}")
return None
@@ -231,12 +235,12 @@ def connector_pruning_generator_task(
pruning_ctx_dict["request_id"] = self.request.id
pruning_ctx.set(pruning_ctx_dict)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
rcp = RedisConnectorPruning(cc_pair_id)
r = get_redis_client(tenant_id=tenant_id)
lock = r.lock(
DanswerRedisLocks.PRUNING_LOCK_PREFIX + f"_{redis_connector.id}",
DanswerRedisLocks.PRUNING_LOCK_PREFIX + f"_{rcp._id}",
timeout=CELERY_PRUNING_LOCK_TIMEOUT,
)
@@ -269,11 +273,10 @@ def connector_pruning_generator_task(
cc_pair.credential,
)
rcs = RedisConnectorStop(cc_pair_id)
callback = RunIndexingCallback(
redis_connector.stop.fence_key,
redis_connector.prune.generator_progress_key,
lock,
r,
rcs.fence_key, rcp.generator_progress_key, lock, r
)
# a list of docs in the source
all_connector_doc_ids: set[str] = extract_ids_from_runnable_connector(
@@ -300,29 +303,31 @@ def connector_pruning_generator_task(
f"doc_source={cc_pair.connector.source}"
)
rcp.documents_to_prune = set(doc_ids_to_remove)
task_logger.info(
f"RedisConnector.prune.generate_tasks starting. cc_pair={cc_pair_id}"
f"RedisConnectorPruning.generate_tasks starting. cc_pair={cc_pair.id}"
)
tasks_generated = redis_connector.prune.generate_tasks(
set(doc_ids_to_remove), self.app, db_session, None
tasks_generated = rcp.generate_tasks(
self.app, db_session, r, None, tenant_id
)
if tasks_generated is None:
return None
task_logger.info(
f"RedisConnector.prune.generate_tasks finished. "
f"cc_pair={cc_pair_id} tasks_generated={tasks_generated}"
f"RedisConnectorPruning.generate_tasks finished. "
f"cc_pair={cc_pair.id} tasks_generated={tasks_generated}"
)
redis_connector.prune.generator_complete = tasks_generated
r.set(rcp.generator_complete_key, tasks_generated)
except Exception as e:
task_logger.exception(
f"Failed to run pruning: cc_pair={cc_pair_id} connector={connector_id}"
)
redis_connector.prune.generator_clear()
redis_connector.prune.taskset_clear()
redis_connector.prune.set_fence(False)
r.delete(rcp.generator_progress_key)
r.delete(rcp.taskset_key)
r.delete(rcp.fence_key)
raise e
finally:
if lock.owned():

View File

@@ -0,0 +1,8 @@
from datetime import datetime
from pydantic import BaseModel
class RedisConnectorDeletionFenceData(BaseModel):
num_tasks: int | None
submitted: datetime

View File

@@ -0,0 +1,10 @@
from datetime import datetime
from pydantic import BaseModel
class RedisConnectorIndexingFenceData(BaseModel):
index_attempt_id: int | None
started: datetime | None
submitted: datetime
celery_task_id: str | None

View File

@@ -19,6 +19,18 @@ from tenacity import RetryError
from danswer.access.access import get_access_for_document
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.celery.celery_redis import celery_get_queue_length
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_redis import RedisConnectorIndexing
from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_redis import RedisDocumentSet
from danswer.background.celery.celery_redis import RedisUserGroup
from danswer.background.celery.tasks.shared.RedisConnectorDeletionFenceData import (
RedisConnectorDeletionFenceData,
)
from danswer.background.celery.tasks.shared.RedisConnectorIndexingFenceData import (
RedisConnectorIndexingFenceData,
)
from danswer.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
from danswer.background.celery.tasks.shared.tasks import LIGHT_SOFT_TIME_LIMIT
from danswer.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT
@@ -55,14 +67,7 @@ from danswer.db.models import IndexAttempt
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import VespaDocumentFields
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
from danswer.redis.redis_connector_delete import RedisConnectorDelete
from danswer.redis.redis_connector_index import RedisConnectorIndex
from danswer.redis.redis_connector_prune import RedisConnectorPrune
from danswer.redis.redis_document_set import RedisDocumentSet
from danswer.redis.redis_pool import get_redis_client
from danswer.redis.redis_usergroup import RedisUserGroup
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import (
@@ -187,7 +192,7 @@ def try_generate_stale_document_sync_tasks(
total_tasks_generated = 0
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in cc_pairs:
rc = RedisConnectorCredentialPair(tenant_id, cc_pair.id)
rc = RedisConnectorCredentialPair(cc_pair.id)
tasks_generated = rc.generate_tasks(
celery_app, db_session, r, lock_beat, tenant_id
)
@@ -223,10 +228,10 @@ def try_generate_document_set_sync_tasks(
) -> int | None:
lock_beat.reacquire()
rds = RedisDocumentSet(tenant_id, document_set_id)
rds = RedisDocumentSet(document_set_id)
# don't generate document set sync tasks if tasks are still pending
if rds.fenced:
if r.exists(rds.fence_key):
return None
# don't generate sync tasks if we're up to date
@@ -264,7 +269,7 @@ def try_generate_document_set_sync_tasks(
)
# set this only after all tasks have been added
rds.set_fence(tasks_generated)
r.set(rds.fence_key, tasks_generated)
return tasks_generated
@@ -278,9 +283,10 @@ def try_generate_user_group_sync_tasks(
) -> int | None:
lock_beat.reacquire()
rug = RedisUserGroup(tenant_id, usergroup_id)
if rug.fenced:
# don't generate sync tasks if tasks are still pending
rug = RedisUserGroup(usergroup_id)
# don't generate sync tasks if tasks are still pending
if r.exists(rug.fence_key):
return None
# race condition with the monitor/cleanup function if we use a cached result!
@@ -320,7 +326,7 @@ def try_generate_user_group_sync_tasks(
)
# set this only after all tasks have been added
rug.set_fence(tasks_generated)
r.set(rug.fence_key, tasks_generated)
return tasks_generated
@@ -346,7 +352,7 @@ def monitor_connector_taskset(r: Redis) -> None:
def monitor_document_set_taskset(
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
key_bytes: bytes, r: Redis, db_session: Session
) -> None:
fence_key = key_bytes.decode("utf-8")
document_set_id_str = RedisDocumentSet.get_id_from_fence_key(fence_key)
@@ -356,12 +362,16 @@ def monitor_document_set_taskset(
document_set_id = int(document_set_id_str)
rds = RedisDocumentSet(tenant_id, document_set_id)
if not rds.fenced:
rds = RedisDocumentSet(document_set_id)
fence_value = r.get(rds.fence_key)
if fence_value is None:
return
initial_count = rds.payload
if initial_count is None:
try:
initial_count = int(cast(int, fence_value))
except ValueError:
task_logger.error("The value is not an integer.")
return
count = cast(int, r.scard(rds.taskset_key))
@@ -389,38 +399,48 @@ def monitor_document_set_taskset(
f"Successfully synced document set: document_set={document_set_id}"
)
rds.reset()
r.delete(rds.taskset_key)
r.delete(rds.fence_key)
def monitor_connector_deletion_taskset(
tenant_id: str | None, key_bytes: bytes, r: Redis
key_bytes: bytes, r: Redis, tenant_id: str | None
) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
cc_pair_id_str = RedisConnectorDeletion.get_id_from_fence_key(fence_key)
if cc_pair_id_str is None:
task_logger.warning(f"could not parse cc_pair_id from {fence_key}")
return
cc_pair_id = int(cc_pair_id_str)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
rcd = RedisConnectorDeletion(cc_pair_id)
fence_data = redis_connector.delete.payload
if not fence_data:
task_logger.warning(
f"Connector deletion - fence payload invalid: cc_pair={cc_pair_id}"
# read related data and evaluate/print task progress
fence_value = cast(bytes, r.get(rcd.fence_key))
if fence_value is None:
return
try:
fence_json = fence_value.decode("utf-8")
fence_data = RedisConnectorDeletionFenceData.model_validate_json(
cast(str, fence_json)
)
return
except ValueError:
task_logger.exception(
"monitor_ccpair_indexing_taskset: fence_data not decodeable."
)
raise
# the fence is setting up but isn't ready yet
if fence_data.num_tasks is None:
# the fence is setting up but isn't ready yet
return
remaining = redis_connector.delete.get_remaining()
count = cast(int, r.scard(rcd.taskset_key))
task_logger.info(
f"Connector deletion progress: cc_pair={cc_pair_id} remaining={remaining} initial={fence_data.num_tasks}"
f"Connector deletion progress: cc_pair={cc_pair_id} remaining={count} initial={fence_data.num_tasks}"
)
if remaining > 0:
if count > 0:
return
with get_session_with_tenant(tenant_id) as db_session:
@@ -504,15 +524,15 @@ def monitor_connector_deletion_taskset(
f"docs_deleted={fence_data.num_tasks}"
)
redis_connector.delete.taskset_clear()
redis_connector.delete.set_fence(None)
r.delete(rcd.taskset_key)
r.delete(rcd.fence_key)
def monitor_ccpair_pruning_taskset(
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
key_bytes: bytes, r: Redis, db_session: Session
) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
cc_pair_id_str = RedisConnectorPruning.get_id_from_fence_key(fence_key)
if cc_pair_id_str is None:
task_logger.warning(
f"monitor_ccpair_pruning_taskset: could not parse cc_pair_id from {fence_key}"
@@ -521,37 +541,46 @@ def monitor_ccpair_pruning_taskset(
cc_pair_id = int(cc_pair_id_str)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
if not redis_connector.prune.fenced:
rcp = RedisConnectorPruning(cc_pair_id)
fence_value = r.get(rcp.fence_key)
if fence_value is None:
return
initial = redis_connector.prune.generator_complete
if initial is None:
generator_value = r.get(rcp.generator_complete_key)
if generator_value is None:
return
remaining = redis_connector.prune.get_remaining()
try:
initial_count = int(cast(int, generator_value))
except ValueError:
task_logger.error("The value is not an integer.")
return
count = cast(int, r.scard(rcp.taskset_key))
task_logger.info(
f"Connector pruning progress: cc_pair={cc_pair_id} remaining={remaining} initial={initial}"
f"Connector pruning progress: cc_pair_id={cc_pair_id} remaining={count} initial={initial_count}"
)
if remaining > 0:
if count > 0:
return
mark_ccpair_as_pruned(int(cc_pair_id), db_session)
task_logger.info(
f"Successfully pruned connector credential pair. cc_pair={cc_pair_id}"
f"Successfully pruned connector credential pair. cc_pair_id={cc_pair_id}"
)
redis_connector.prune.taskset_clear()
redis_connector.prune.generator_clear()
redis_connector.prune.set_fence(False)
r.delete(rcp.taskset_key)
r.delete(rcp.generator_progress_key)
r.delete(rcp.generator_complete_key)
r.delete(rcp.fence_key)
def monitor_ccpair_indexing_taskset(
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
key_bytes: bytes, r: Redis, db_session: Session
) -> None:
# if the fence doesn't exist, there's nothing to do
fence_key = key_bytes.decode("utf-8")
composite_id = RedisConnector.get_id_from_fence_key(fence_key)
composite_id = RedisConnectorIndexing.get_id_from_fence_key(fence_key)
if composite_id is None:
task_logger.warning(
f"monitor_ccpair_indexing_taskset: could not parse composite_id from {fence_key}"
@@ -566,37 +595,53 @@ def monitor_ccpair_indexing_taskset(
cc_pair_id = int(parts[0])
search_settings_id = int(parts[1])
redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector_index = redis_connector.new_index(search_settings_id)
if not redis_connector_index.fenced:
rci = RedisConnectorIndexing(cc_pair_id, search_settings_id)
# read related data and evaluate/print task progress
fence_value = cast(bytes, r.get(rci.fence_key))
if fence_value is None:
return
payload = redis_connector_index.payload
if not payload:
return
elapsed_submitted = datetime.now(timezone.utc) - payload.submitted
progress = redis_connector_index.get_progress()
if progress is not None:
task_logger.info(
f"Connector indexing progress: cc_pair_id={cc_pair_id} "
f"search_settings_id={search_settings_id} "
f"progress={progress} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
try:
fence_json = fence_value.decode("utf-8")
fence_data = RedisConnectorIndexingFenceData.model_validate_json(
cast(str, fence_json)
)
except ValueError:
task_logger.exception(
"monitor_ccpair_indexing_taskset: fence_data not decodeable."
)
raise
if payload.index_attempt_id is None or payload.celery_task_id is None:
elapsed_submitted = datetime.now(timezone.utc) - fence_data.submitted
generator_progress_value = r.get(rci.generator_progress_key)
if generator_progress_value is not None:
try:
progress_count = int(cast(int, generator_progress_value))
task_logger.info(
f"Connector indexing progress: cc_pair_id={cc_pair_id} "
f"search_settings_id={search_settings_id} "
f"progress={progress_count} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
)
except ValueError:
task_logger.error(
"monitor_ccpair_indexing_taskset: generator_progress_value is not an integer."
)
if fence_data.index_attempt_id is None or fence_data.celery_task_id is None:
# the task is still setting up
return
# Read result state BEFORE generator_complete_key to avoid a race condition
# never use any blocking methods on the result from inside a task!
result: AsyncResult = AsyncResult(payload.celery_task_id)
result: AsyncResult = AsyncResult(fence_data.celery_task_id)
result_state = result.state
status_int = redis_connector_index.get_completion()
if status_int is None:
generator_complete_value = r.get(rci.generator_complete_key)
if generator_complete_value is None:
if result_state in READY_STATES:
# IF the task state is READY, THEN generator_complete should be set
# if it isn't, then the worker crashed
@@ -607,18 +652,30 @@ def monitor_ccpair_indexing_taskset(
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
)
index_attempt = get_index_attempt(db_session, payload.index_attempt_id)
index_attempt = get_index_attempt(db_session, fence_data.index_attempt_id)
if index_attempt:
mark_attempt_failed(
index_attempt_id=payload.index_attempt_id,
index_attempt=index_attempt,
db_session=db_session,
failure_reason="Connector indexing aborted or exceptioned.",
)
redis_connector_index.reset()
r.delete(rci.generator_lock_key)
r.delete(rci.taskset_key)
r.delete(rci.generator_progress_key)
r.delete(rci.generator_complete_key)
r.delete(rci.fence_key)
return
status_enum = HTTPStatus(status_int)
status_enum = HTTPStatus.INTERNAL_SERVER_ERROR
try:
status_value = int(cast(int, generator_complete_value))
status_enum = HTTPStatus(status_value)
except ValueError:
task_logger.error(
f"monitor_ccpair_indexing_taskset: "
f"generator_complete_value=f{generator_complete_value} could not be parsed."
)
task_logger.info(
f"Connector indexing finished: cc_pair_id={cc_pair_id} "
@@ -627,7 +684,11 @@ def monitor_ccpair_indexing_taskset(
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
)
redis_connector_index.reset()
r.delete(rci.generator_lock_key)
r.delete(rci.taskset_key)
r.delete(rci.generator_progress_key)
r.delete(rci.generator_complete_key)
r.delete(rci.fence_key)
@shared_task(name="monitor_vespa_sync", soft_time_limit=300, bind=True)
@@ -639,7 +700,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
This task lock timeout is CELERY_METADATA_SYNC_BEAT_LOCK_TIMEOUT seconds, so don't
do anything too expensive in this function!
Returns True if the task actually did work, False if it exited early to prevent overlap
Returns True if the task actually did work, False
"""
r = get_redis_client(tenant_id=tenant_id)
@@ -690,33 +751,27 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
for a in attempts:
# if attempts exist in the db but we don't detect them in redis, mark them as failed
fence_key = RedisConnectorIndex.fence_key_with_ids(
rci = RedisConnectorIndexing(
a.connector_credential_pair_id, a.search_settings_id
)
if not r.exists(fence_key):
failure_reason = (
f"Unknown index attempt. Might be left over from a process restart: "
f"index_attempt={a.id} "
f"cc_pair={a.connector_credential_pair_id} "
f"search_settings={a.search_settings_id}"
)
task_logger.warning(failure_reason)
mark_attempt_failed(a.id, db_session, failure_reason=failure_reason)
failure_reason = f"Unknown index attempt {a.id}. Might be left over from a process restart."
if not r.exists(rci.fence_key):
mark_attempt_failed(a, db_session, failure_reason=failure_reason)
lock_beat.reacquire()
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
monitor_connector_taskset(r)
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisConnectorDelete.FENCE_PREFIX + "*"):
for key_bytes in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
lock_beat.reacquire()
monitor_connector_deletion_taskset(tenant_id, key_bytes, r)
monitor_connector_deletion_taskset(key_bytes, r, tenant_id)
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
lock_beat.reacquire()
with get_session_with_tenant(tenant_id) as db_session:
monitor_document_set_taskset(tenant_id, key_bytes, r, db_session)
monitor_document_set_taskset(key_bytes, r, db_session)
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
@@ -727,19 +782,19 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
noop_fallback,
)
with get_session_with_tenant(tenant_id) as db_session:
monitor_usergroup_taskset(tenant_id, key_bytes, r, db_session)
monitor_usergroup_taskset(key_bytes, r, db_session)
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisConnectorPrune.FENCE_PREFIX + "*"):
for key_bytes in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
lock_beat.reacquire()
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_pruning_taskset(tenant_id, key_bytes, r, db_session)
monitor_ccpair_pruning_taskset(key_bytes, r, db_session)
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisConnectorIndex.FENCE_PREFIX + "*"):
for key_bytes in r.scan_iter(RedisConnectorIndexing.FENCE_PREFIX + "*"):
lock_beat.reacquire()
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_indexing_taskset(tenant_id, key_bytes, r, db_session)
monitor_ccpair_indexing_taskset(key_bytes, r, db_session)
# uncomment for debugging if needed
# r_celery = celery_app.broker_connection().channel().client

View File

@@ -1,6 +1,8 @@
"""Factory stub for running celery worker / celery beat."""
from danswer.background.celery.apps.beat import celery_app
from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
set_is_ee_based_on_env_variable()
app = celery_app
app = fetch_versioned_implementation(
"danswer.background.celery.apps.beat", "celery_app"
)

View File

@@ -118,13 +118,7 @@ def _run_indexing(
"""
start_time = time.time()
if index_attempt.search_settings is None:
raise ValueError(
"Search settings must be set for indexing. This should not be possible."
)
search_settings = index_attempt.search_settings
index_name = search_settings.index_name
# Only update cc-pair status for primary index jobs
@@ -337,7 +331,7 @@ def _run_indexing(
or index_attempt.status != IndexingStatus.IN_PROGRESS
):
mark_attempt_failed(
index_attempt.id,
index_attempt,
db_session,
failure_reason=str(e),
full_exception_trace=traceback.format_exc(),
@@ -372,7 +366,7 @@ def _run_indexing(
and index_attempt_md.num_exceptions >= batch_num
):
mark_attempt_failed(
index_attempt.id,
index_attempt,
db_session,
failure_reason="All batches exceptioned.",
)

View File

@@ -1,4 +0,0 @@
def name_sync_external_doc_permissions_task(
cc_pair_id: int, tenant_id: str | None = None
) -> str:
return f"sync_external_doc_permissions_task__{cc_pair_id}"

View File

@@ -10,7 +10,7 @@ from danswer.search.enums import QueryFlow
from danswer.search.enums import SearchType
from danswer.search.models import RetrievalDocs
from danswer.search.models import SearchResponse
from danswer.tools.tool_implementations.custom.base_tool_types import ToolResultType
from danswer.tools.custom.base_tool_types import ToolResultType
class LlmDoc(BaseModel):
@@ -156,7 +156,7 @@ class QAResponse(SearchResponse, DanswerAnswer):
error_msg: str | None = None
class FileChatDisplay(BaseModel):
class ImageGenerationDisplay(BaseModel):
file_ids: list[str]
@@ -170,7 +170,7 @@ AnswerQuestionPossibleReturn = (
| DanswerQuotes
| CitationInfo
| DanswerContexts
| FileChatDisplay
| ImageGenerationDisplay
| CustomToolResponse
| StreamingError
| StreamStopInfo

View File

@@ -42,14 +42,18 @@ personas:
display_priority: 1
is_visible: true
starter_messages:
- name: "Give me an overview of what's here"
message: "Sample some documents and tell me what you find."
- name: "Use AI to solve a work related problem"
message: "Ask me what problem I would like to solve, then search the knowledge base to help me find a solution."
- name: "Find updates on a topic of interest"
message: "Once I provide a topic, retrieve related documents and tell me when there was last activity on the topic if available."
- name: "Surface contradictions"
message: "Have me choose a subject. Once I have provided it, check against the knowledge base and point out any inconsistencies. For all your following responses, focus on identifying contradictions."
- name: "General Information"
description: "Ask about available information"
message: "Hello! I'm interested in learning more about the information available here. Could you give me an overview of the types of data or documents that might be accessible?"
- name: "Specific Topic Search"
description: "Search for specific information"
message: "Hi! I'd like to learn more about a specific topic. Could you help me find relevant documents and information?"
- name: "Recent Updates"
description: "Inquire about latest additions"
message: "Hello! I'm curious about any recent updates or additions to the knowledge base. Can you tell me what new information has been added lately?"
- name: "Cross-referencing Information"
description: "Connect information from different sources"
message: "Hi! I'm working on a project that requires connecting information from multiple sources. How can I effectively cross-reference data across different documents or categories?"
- id: 1
name: "General"
@@ -67,14 +71,18 @@ personas:
display_priority: 0
is_visible: true
starter_messages:
- name: "Summarize a document"
message: "If I have provided a document please summarize it for me. If not, please ask me to upload a document either by dragging it into the input bar or clicking the +file icon."
- name: "Help me with coding"
message: 'Write me a "Hello World" script in 5 random languages to show off the functionality.'
- name: "Draft a professional email"
message: "Help me craft a professional email. Let's establish the context and the anticipated outcomes of the email before proposing a draft."
- name: "Learn something new"
message: "What is the difference between a Gantt chart, a Burndown chart and a Kanban board?"
- name: "Open Discussion"
description: "Start an open-ended conversation"
message: "Hi! Can you help me write a professional email?"
- name: "Problem Solving"
description: "Get help with a challenge"
message: "Hello! I need help managing my daily tasks better. Do you have any simple tips?"
- name: "Learn Something New"
description: "Explore a new topic"
message: "Hi! Could you explain what project management is in simple terms?"
- name: "Creative Brainstorming"
description: "Generate creative ideas"
message: "Hello! I need to brainstorm some team building activities. Do you have any fun suggestions?"
- id: 2
name: "Paraphrase"
@@ -93,12 +101,16 @@ personas:
is_visible: false
starter_messages:
- name: "Document Search"
description: "Find exact information"
message: "Hi! Could you help me find information about our team structure and reporting lines from our internal documents?"
- name: "Process Verification"
description: "Find exact quotes"
message: "Hello! I need to understand our project approval process. Could you find the exact steps from our documentation?"
- name: "Technical Documentation"
description: "Search technical details"
message: "Hi there! I'm looking for information about our deployment procedures. Can you find the specific steps from our technical guides?"
- name: "Policy Reference"
description: "Check official policies"
message: "Hello! Could you help me find our official guidelines about client communication? I need the exact wording from our documentation."
- id: 3
@@ -118,11 +130,15 @@ personas:
display_priority: 3
is_visible: true
starter_messages:
- name: "Create visuals for a presentation"
message: "Generate someone presenting a graph which clearly demonstrates an upwards trajectory."
- name: "Find inspiration for a marketing campaign"
message: "Generate an image of two happy individuals sipping on a soda drink in a glass bottle."
- name: "Visualize a product design"
message: "I want to add a search bar to my Iphone app. Generate me generic examples of how other apps implement this."
- name: "Generate a humorous image response"
message: "My teammate just made a silly mistake and I want to respond with a facepalm. Can you generate me one?"
- name: "Landscape"
description: "Generate a landscape image"
message: "Create an image of a serene mountain lake at sunset, with snow-capped peaks reflected in the calm water and a small wooden cabin on the shore."
- name: "Character"
description: "Generate a character image"
message: "Generate an image of a futuristic robot with glowing blue eyes, sleek metallic body, and intricate circuitry visible through transparent panels on its chest and arms."
- name: "Abstract"
description: "Create an abstract image"
message: "Create an abstract image representing the concept of time, using swirling clock hands, fragmented hourglasses, and streaks of light to convey the passage of moments and eras."
- name: "Urban Scene"
description: "Generate an urban landscape"
message: "Generate an image of a bustling futuristic cityscape at night, with towering skyscrapers, flying vehicles, holographic advertisements, and a mix of neon and bioluminescent lighting."

View File

@@ -11,18 +11,23 @@ from danswer.chat.models import AllCitations
from danswer.chat.models import CitationInfo
from danswer.chat.models import CustomToolResponse
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import FileChatDisplay
from danswer.chat.models import FinalUsedContextDocsResponse
from danswer.chat.models import ImageGenerationDisplay
from danswer.chat.models import LLMRelevanceFilterResponse
from danswer.chat.models import MessageResponseIDInfo
from danswer.chat.models import MessageSpecificCitations
from danswer.chat.models import QADocsResponse
from danswer.chat.models import StreamingError
from danswer.chat.models import StreamStopInfo
from danswer.configs.app_configs import AZURE_DALLE_API_BASE
from danswer.configs.app_configs import AZURE_DALLE_API_KEY
from danswer.configs.app_configs import AZURE_DALLE_API_VERSION
from danswer.configs.app_configs import AZURE_DALLE_DEPLOYMENT_NAME
from danswer.configs.chat_configs import BING_API_KEY
from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from danswer.configs.constants import MessageType
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.db.chat import attach_files_to_chat_message
from danswer.db.chat import create_db_search_doc
from danswer.db.chat import create_new_chat_message
@@ -35,6 +40,7 @@ from danswer.db.chat import reserve_message_id
from danswer.db.chat import translate_db_message_to_chat_message_detail
from danswer.db.chat import translate_db_search_doc_to_server_search_doc
from danswer.db.engine import get_session_context_manager
from danswer.db.llm import fetch_existing_llm_providers
from danswer.db.models import SearchDoc as DbSearchDoc
from danswer.db.models import ToolCall
from danswer.db.models import User
@@ -54,13 +60,14 @@ from danswer.llm.answering.models import PromptConfig
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_llms_for_persona
from danswer.llm.factory import get_main_llm_from_tuple
from danswer.llm.interfaces import LLMConfig
from danswer.llm.utils import litellm_exception_to_error_msg
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.search.enums import LLMEvaluationType
from danswer.search.enums import OptionalSearchSetting
from danswer.search.enums import QueryFlow
from danswer.search.enums import SearchType
from danswer.search.models import InferenceSection
from danswer.search.models import RetrievalDetails
from danswer.search.retrieval.search_runner import inference_sections_from_ids
from danswer.search.utils import chunks_or_sections_to_search_docs
from danswer.search.utils import dedupe_documents
@@ -69,48 +76,36 @@ from danswer.search.utils import relevant_sections_to_indices
from danswer.server.query_and_chat.models import ChatMessageDetail
from danswer.server.query_and_chat.models import CreateChatMessageRequest
from danswer.server.utils import get_json_line
from danswer.tools.built_in_tools import get_built_in_tool_by_id
from danswer.tools.custom.custom_tool import (
build_custom_tools_from_openapi_schema_and_headers,
)
from danswer.tools.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID
from danswer.tools.custom.custom_tool import CustomToolCallSummary
from danswer.tools.force import ForceUseTool
from danswer.tools.models import ToolResponse
from danswer.tools.tool import Tool
from danswer.tools.tool_constructor import construct_tools
from danswer.tools.tool_constructor import CustomToolConfig
from danswer.tools.tool_constructor import ImageGenerationToolConfig
from danswer.tools.tool_constructor import InternetSearchToolConfig
from danswer.tools.tool_constructor import SearchToolConfig
from danswer.tools.tool_implementations.custom.custom_tool import (
CUSTOM_TOOL_RESPONSE_ID,
)
from danswer.tools.tool_implementations.custom.custom_tool import CustomToolCallSummary
from danswer.tools.tool_implementations.images.image_generation_tool import (
IMAGE_GENERATION_RESPONSE_ID,
)
from danswer.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationResponse,
)
from danswer.tools.tool_implementations.internet_search.internet_search_tool import (
from danswer.tools.images.image_generation_tool import IMAGE_GENERATION_RESPONSE_ID
from danswer.tools.images.image_generation_tool import ImageGenerationResponse
from danswer.tools.images.image_generation_tool import ImageGenerationTool
from danswer.tools.internet_search.internet_search_tool import (
INTERNET_SEARCH_RESPONSE_ID,
)
from danswer.tools.tool_implementations.internet_search.internet_search_tool import (
from danswer.tools.internet_search.internet_search_tool import (
internet_search_response_to_search_docs,
)
from danswer.tools.tool_implementations.internet_search.internet_search_tool import (
InternetSearchResponse,
)
from danswer.tools.tool_implementations.internet_search.internet_search_tool import (
InternetSearchTool,
)
from danswer.tools.tool_implementations.search.search_tool import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
from danswer.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
)
from danswer.tools.tool_implementations.search.search_tool import SearchResponseSummary
from danswer.tools.tool_implementations.search.search_tool import SearchTool
from danswer.tools.tool_implementations.search.search_tool import (
SECTION_RELEVANCE_LIST_ID,
)
from danswer.tools.internet_search.internet_search_tool import InternetSearchResponse
from danswer.tools.internet_search.internet_search_tool import InternetSearchTool
from danswer.tools.models import DynamicSchemaInfo
from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS_ID
from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID
from danswer.tools.search.search_tool import SearchResponseSummary
from danswer.tools.search.search_tool import SearchTool
from danswer.tools.search.search_tool import SECTION_RELEVANCE_LIST_ID
from danswer.tools.tool import Tool
from danswer.tools.tool import ToolResponse
from danswer.tools.tool_runner import ToolCallFinalResult
from danswer.tools.utils import compute_all_tool_tokens
from danswer.tools.utils import explicit_tool_calling_supported
from danswer.utils.headers import header_dict_to_header_list
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_generator_function_time
@@ -261,11 +256,10 @@ ChatPacket = (
| DanswerAnswerPiece
| AllCitations
| CitationInfo
| FileChatDisplay
| ImageGenerationDisplay
| CustomToolResponse
| MessageSpecificCitations
| MessageResponseIDInfo
| StreamStopInfo
)
ChatPacketStream = Iterator[ChatPacket]
@@ -281,6 +275,7 @@ def stream_chat_message_objects(
max_document_percentage: float = CHAT_TARGET_CHUNK_PERCENTAGE,
# if specified, uses the last user message and does not create a new user message based
# on the `new_msg_req.message`. Currently, requires a state where the last message is a
use_existing_user_message: bool = False,
litellm_additional_headers: dict[str, str] | None = None,
custom_tool_additional_headers: dict[str, str] | None = None,
is_connected: Callable[[], bool] | None = None,
@@ -292,9 +287,6 @@ def stream_chat_message_objects(
3. [always] A set of streamed LLM tokens or an error anywhere along the line if something fails
4. [always] Details on the final AI response message that is created
"""
use_existing_user_message = new_msg_req.use_existing_user_message
existing_assistant_message_id = new_msg_req.existing_assistant_message_id
# Currently surrounding context is not supported for chat
# Chat is already token heavy and harder for the model to process plus it would roll history over much faster
new_msg_req.chunks_above = 0
@@ -416,20 +408,12 @@ def stream_chat_message_objects(
final_msg, history_msgs = create_chat_chain(
chat_session_id=chat_session_id, db_session=db_session
)
if existing_assistant_message_id is None:
if final_msg.message_type != MessageType.USER:
raise RuntimeError(
"The last message was not a user message. Cannot call "
"`stream_chat_message_objects` with `is_regenerate=True` "
"when the last message is not a user message."
)
else:
if final_msg.id != existing_assistant_message_id:
raise RuntimeError(
"The last message was not the existing assistant message. "
f"Final message id: {final_msg.id}, "
f"existing assistant message id: {existing_assistant_message_id}"
)
if final_msg.message_type != MessageType.USER:
raise RuntimeError(
"The last message was not a user message. Cannot call "
"`stream_chat_message_objects` with `is_regenerate=True` "
"when the last message is not a user message."
)
# Disable Query Rephrasing for the first message
# This leads to a better first response since the LLM rephrasing the question
@@ -500,19 +484,13 @@ def stream_chat_message_objects(
),
max_window_percentage=max_document_percentage,
)
# we don't need to reserve a message id if we're using an existing assistant message
reserved_message_id = (
final_msg.id
if existing_assistant_message_id is not None
else reserve_message_id(
db_session=db_session,
chat_session_id=chat_session_id,
parent_message=user_message.id
if user_message is not None
else parent_message.id,
message_type=MessageType.ASSISTANT,
)
reserved_message_id = reserve_message_id(
db_session=db_session,
chat_session_id=chat_session_id,
parent_message=user_message.id
if user_message is not None
else parent_message.id,
message_type=MessageType.ASSISTANT,
)
yield MessageResponseIDInfo(
user_message_id=user_message.id if user_message else None,
@@ -527,13 +505,7 @@ def stream_chat_message_objects(
partial_response = partial(
create_new_chat_message,
chat_session_id=chat_session_id,
# if we're using an existing assistant message, then this will just be an
# update operation, in which case the parent should be the parent of
# the latest. If we're creating a new assistant message, then the parent
# should be the latest message (latest user message)
parent_message=(
final_msg if existing_assistant_message_id is None else parent_message
),
parent_message=final_msg,
prompt_id=prompt_id,
overridden_model=overridden_model,
# message=,
@@ -545,7 +517,6 @@ def stream_chat_message_objects(
# reference_docs=,
db_session=db_session,
commit=False,
reserved_message_id=reserved_message_id,
)
if not final_msg.prompt:
@@ -561,53 +532,148 @@ def stream_chat_message_objects(
if not persona
else PromptConfig.from_model(persona.prompts[0])
)
answer_style_config = AnswerStyleConfig(
citation_config=CitationConfig(
all_docs_useful=selected_db_search_docs is not None
),
document_pruning_config=document_pruning_config,
structured_response_format=new_msg_req.structured_response_format,
)
tool_dict = construct_tools(
persona=persona,
prompt_config=prompt_config,
db_session=db_session,
user=user,
llm=llm,
fast_llm=fast_llm,
search_tool_config=SearchToolConfig(
answer_style_config=answer_style_config,
document_pruning_config=document_pruning_config,
retrieval_options=retrieval_options or RetrievalDetails(),
selected_sections=selected_sections,
chunks_above=new_msg_req.chunks_above,
chunks_below=new_msg_req.chunks_below,
full_doc=new_msg_req.full_doc,
latest_query_files=latest_query_files,
),
internet_search_tool_config=InternetSearchToolConfig(
answer_style_config=answer_style_config,
),
image_generation_tool_config=ImageGenerationToolConfig(
additional_headers=litellm_additional_headers,
),
custom_tool_config=CustomToolConfig(
chat_session_id=chat_session_id,
message_id=user_message.id if user_message else None,
additional_headers=custom_tool_additional_headers,
),
)
# find out what tools to use
search_tool: SearchTool | None = None
tool_dict: dict[int, list[Tool]] = {} # tool_id to tool
for db_tool_model in persona.tools:
# handle in-code tools specially
if db_tool_model.in_code_tool_id:
tool_cls = get_built_in_tool_by_id(db_tool_model.id, db_session)
if tool_cls.__name__ == SearchTool.__name__ and not latest_query_files:
search_tool = SearchTool(
db_session=db_session,
user=user,
persona=persona,
retrieval_options=retrieval_options,
prompt_config=prompt_config,
llm=llm,
fast_llm=fast_llm,
pruning_config=document_pruning_config,
selected_sections=selected_sections,
chunks_above=new_msg_req.chunks_above,
chunks_below=new_msg_req.chunks_below,
full_doc=new_msg_req.full_doc,
evaluation_type=LLMEvaluationType.BASIC
if persona.llm_relevance_filter
else LLMEvaluationType.SKIP,
)
tool_dict[db_tool_model.id] = [search_tool]
elif tool_cls.__name__ == ImageGenerationTool.__name__:
img_generation_llm_config: LLMConfig | None = None
if (
llm
and llm.config.api_key
and llm.config.model_provider == "openai"
):
img_generation_llm_config = LLMConfig(
model_provider=llm.config.model_provider,
model_name="dall-e-3",
temperature=GEN_AI_TEMPERATURE,
api_key=llm.config.api_key,
api_base=llm.config.api_base,
api_version=llm.config.api_version,
)
elif (
llm.config.model_provider == "azure"
and AZURE_DALLE_API_KEY is not None
):
img_generation_llm_config = LLMConfig(
model_provider="azure",
model_name=f"azure/{AZURE_DALLE_DEPLOYMENT_NAME}",
temperature=GEN_AI_TEMPERATURE,
api_key=AZURE_DALLE_API_KEY,
api_base=AZURE_DALLE_API_BASE,
api_version=AZURE_DALLE_API_VERSION,
)
else:
llm_providers = fetch_existing_llm_providers(db_session)
openai_provider = next(
iter(
[
llm_provider
for llm_provider in llm_providers
if llm_provider.provider == "openai"
]
),
None,
)
if not openai_provider or not openai_provider.api_key:
raise ValueError(
"Image generation tool requires an OpenAI API key"
)
img_generation_llm_config = LLMConfig(
model_provider=openai_provider.provider,
model_name="dall-e-3",
temperature=GEN_AI_TEMPERATURE,
api_key=openai_provider.api_key,
api_base=openai_provider.api_base,
api_version=openai_provider.api_version,
)
tool_dict[db_tool_model.id] = [
ImageGenerationTool(
api_key=cast(str, img_generation_llm_config.api_key),
api_base=img_generation_llm_config.api_base,
api_version=img_generation_llm_config.api_version,
additional_headers=litellm_additional_headers,
model=img_generation_llm_config.model_name,
)
]
elif tool_cls.__name__ == InternetSearchTool.__name__:
bing_api_key = BING_API_KEY
if not bing_api_key:
raise ValueError(
"Internet search tool requires a Bing API key, please contact your Danswer admin to get it added!"
)
tool_dict[db_tool_model.id] = [
InternetSearchTool(api_key=bing_api_key)
]
continue
# handle all custom tools
if db_tool_model.openapi_schema:
tool_dict[db_tool_model.id] = cast(
list[Tool],
build_custom_tools_from_openapi_schema_and_headers(
db_tool_model.openapi_schema,
dynamic_schema_info=DynamicSchemaInfo(
chat_session_id=chat_session_id,
message_id=user_message.id if user_message else None,
),
custom_headers=(db_tool_model.custom_headers or [])
+ (
header_dict_to_header_list(
custom_tool_additional_headers or {}
)
),
),
)
tools: list[Tool] = []
for tool_list in tool_dict.values():
tools.extend(tool_list)
# factor in tool definition size when pruning
document_pruning_config.tool_num_tokens = compute_all_tool_tokens(
tools, llm_tokenizer
)
document_pruning_config.using_tool_message = explicit_tool_calling_supported(
llm_provider, llm_model_name
)
# LLM prompt building, response capturing, etc.
answer = Answer(
is_connected=is_connected,
question=final_msg.message,
latest_query_files=latest_query_files,
answer_style_config=answer_style_config,
answer_style_config=AnswerStyleConfig(
citation_config=CitationConfig(
all_docs_useful=selected_db_search_docs is not None
),
document_pruning_config=document_pruning_config,
structured_response_format=new_msg_req.structured_response_format,
),
prompt_config=prompt_config,
llm=(
llm
@@ -675,6 +741,7 @@ def stream_chat_message_objects(
yield LLMRelevanceFilterResponse(
llm_selected_doc_indices=llm_indices
)
elif packet.id == FINAL_CONTEXT_DOCUMENTS_ID:
yield FinalUsedContextDocsResponse(
final_context_docs=packet.response
@@ -692,7 +759,7 @@ def stream_chat_message_objects(
FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE)
for file_id in file_ids
]
yield FileChatDisplay(
yield ImageGenerationDisplay(
file_ids=[str(file_id) for file_id in file_ids]
)
elif packet.id == INTERNET_SEARCH_RESPONSE_ID:
@@ -706,32 +773,11 @@ def stream_chat_message_objects(
yield qa_docs_response
elif packet.id == CUSTOM_TOOL_RESPONSE_ID:
custom_tool_response = cast(CustomToolCallSummary, packet.response)
yield CustomToolResponse(
response=custom_tool_response.tool_result,
tool_name=custom_tool_response.tool_name,
)
if (
custom_tool_response.response_type == "image"
or custom_tool_response.response_type == "csv"
):
file_ids = custom_tool_response.tool_result.file_ids
ai_message_files = [
FileDescriptor(
id=str(file_id),
type=ChatFileType.IMAGE
if custom_tool_response.response_type == "image"
else ChatFileType.CSV,
)
for file_id in file_ids
]
yield FileChatDisplay(
file_ids=[str(file_id) for file_id in file_ids]
)
else:
yield CustomToolResponse(
response=custom_tool_response.tool_result,
tool_name=custom_tool_response.tool_name,
)
elif isinstance(packet, StreamStopInfo):
pass
else:
if isinstance(packet, ToolCallFinalResult):
tool_result = packet
@@ -761,7 +807,6 @@ def stream_chat_message_objects(
# Post-LLM answer processing
try:
logger.debug("Post-LLM answer processing")
message_specific_citations: MessageSpecificCitations | None = None
if reference_db_search_docs:
message_specific_citations = _translate_citations(
@@ -777,6 +822,7 @@ def stream_chat_message_objects(
tool_name_to_tool_id[tool.name] = tool_id
gen_ai_response_message = partial_response(
reserved_message_id=reserved_message_id,
message=answer.llm_answer,
rephrased_query=(
qa_docs_response.rephrased_query if qa_docs_response else None
@@ -784,21 +830,21 @@ def stream_chat_message_objects(
reference_docs=reference_db_search_docs,
files=ai_message_files,
token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
citations=(
message_specific_citations.citation_map
if message_specific_citations
else None
),
citations=message_specific_citations.citation_map
if message_specific_citations
else None,
error=None,
tool_call=(
ToolCall(
tool_id=tool_name_to_tool_id[tool_result.tool_name],
tool_name=tool_result.tool_name,
tool_arguments=tool_result.tool_args,
tool_result=tool_result.tool_result,
)
tool_calls=(
[
ToolCall(
tool_id=tool_name_to_tool_id[tool_result.tool_name],
tool_name=tool_result.tool_name,
tool_arguments=tool_result.tool_args,
tool_result=tool_result.tool_result,
)
]
if tool_result
else None
else []
),
)
@@ -822,6 +868,7 @@ def stream_chat_message_objects(
def stream_chat_message(
new_msg_req: CreateChatMessageRequest,
user: User | None,
use_existing_user_message: bool = False,
litellm_additional_headers: dict[str, str] | None = None,
custom_tool_additional_headers: dict[str, str] | None = None,
is_connected: Callable[[], bool] | None = None,
@@ -831,6 +878,7 @@ def stream_chat_message(
new_msg_req=new_msg_req,
user=user,
db_session=db_session,
use_existing_user_message=use_existing_user_message,
litellm_additional_headers=litellm_additional_headers,
custom_tool_additional_headers=custom_tool_additional_headers,
is_connected=is_connected,

View File

@@ -9,19 +9,19 @@ prompts:
system: >
You are a question answering system that is constantly learning and improving.
The current date is DANSWER_DATETIME_REPLACEMENT.
You can process and comprehend vast amounts of text and utilize this knowledge to provide
grounded, accurate, and concise answers to diverse queries.
You always clearly communicate ANY UNCERTAINTY in your answer.
# Task Prompt (as shown in UI)
task: >
Answer my query based on the documents provided.
The documents may not all be relevant, ignore any documents that are not directly relevant
to the most recent user query.
I have not read or seen any of the documents and do not want to read them.
If there are no relevant documents, refer to the chat history and your internal knowledge.
# Inject a statement at the end of system prompt to inform the LLM of the current date/time
# If the DANSWER_DATETIME_REPLACEMENT is set, the date/time is inserted there instead
@@ -30,21 +30,21 @@ prompts:
# Prompts the LLM to include citations in the for [1], [2] etc.
# which get parsed to match the passed in sources
include_citations: true
- name: "ImageGeneration"
description: "Generates images from user descriptions!"
description: "Generates images based on user prompts!"
system: >
You are an AI image generation assistant. Your role is to create high-quality images based on user descriptions.
For appropriate requests, you will generate an image that matches the user's requirements.
For inappropriate or unsafe requests, you will politely decline and explain why the request cannot be fulfilled.
You aim to be helpful while maintaining appropriate content standards.
You are an advanced image generation system capable of creating diverse and detailed images.
You can interpret user prompts and generate high-quality, creative images that match their descriptions.
You always strive to create safe and appropriate content, avoiding any harmful or offensive imagery.
task: >
Based on the user's description, create a high-quality image that accurately reflects their request.
Pay close attention to the specified details, styles, and desired elements.
If the request is not appropriate or cannot be fulfilled, explain why and suggest alternatives.
Generate an image based on the user's description.
Provide a detailed description of the generated image, including key elements, colors, and composition.
If the request is not possible or appropriate, explain why and suggest alternatives.
datetime_aware: true
include_citations: false
@@ -64,13 +64,14 @@ prompts:
datetime_aware: true
include_citations: true
- name: "Summarize"
description: "Summarize relevant information from retrieved context!"
system: >
You are a text summarizing assistant that highlights the most important knowledge from the
context provided, prioritizing the information that relates to the user query.
The current date is DANSWER_DATETIME_REPLACEMENT.
You ARE NOT creative and always stick to the provided documents.
If there are no documents, refer to the conversation history.
@@ -83,6 +84,7 @@ prompts:
datetime_aware: true
include_citations: true
- name: "Paraphrase"
description: "Recites information from retrieved context! Least creative but most safe!"
system: >
@@ -90,10 +92,10 @@ prompts:
The current date is DANSWER_DATETIME_REPLACEMENT.
You only provide quotes that are EXACT substrings from provided documents!
If there are no documents provided,
simply tell the user that there are no documents to reference.
You NEVER generate new text or phrases outside of the citation.
DO NOT explain your responses, only provide the quotes and NOTHING ELSE.
task: >

View File

@@ -163,17 +163,6 @@ try:
except ValueError:
POSTGRES_POOL_RECYCLE = POSTGRES_POOL_RECYCLE_DEFAULT
# Experimental setting to control idle transactions
POSTGRES_IDLE_SESSIONS_TIMEOUT_DEFAULT = 0 # milliseconds
try:
POSTGRES_IDLE_SESSIONS_TIMEOUT = int(
os.environ.get(
"POSTGRES_IDLE_SESSIONS_TIMEOUT", POSTGRES_IDLE_SESSIONS_TIMEOUT_DEFAULT
)
)
except ValueError:
POSTGRES_IDLE_SESSIONS_TIMEOUT = POSTGRES_IDLE_SESSIONS_TIMEOUT_DEFAULT
REDIS_SSL = os.getenv("REDIS_SSL", "").lower() == "true"
REDIS_HOST = os.environ.get("REDIS_HOST") or "localhost"
REDIS_PORT = int(os.environ.get("REDIS_PORT", 6379))
@@ -262,6 +251,9 @@ ENABLED_CONNECTOR_TYPES = os.environ.get("ENABLED_CONNECTOR_TYPES") or ""
# for some connectors
ENABLE_EXPENSIVE_EXPERT_CALLS = False
GOOGLE_DRIVE_INCLUDE_SHARED = False
GOOGLE_DRIVE_FOLLOW_SHORTCUTS = False
GOOGLE_DRIVE_ONLY_ORG_PUBLIC = False
# TODO these should be available for frontend configuration, via advanced options expandable
WEB_CONNECTOR_IGNORED_CLASSES = os.environ.get(
@@ -489,17 +481,3 @@ CONTROL_PLANE_API_BASE_URL = os.environ.get(
# JWT configuration
JWT_ALGORITHM = "HS256"
# Super Users
SUPER_USERS = json.loads(os.environ.get("SUPER_USERS", '["pablo@danswer.ai"]'))
SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key")
#####
# API Key Configs
#####
# refers to the rounds described here: https://passlib.readthedocs.io/en/stable/lib/passlib.hash.sha256_crypt.html
_API_KEY_HASH_ROUNDS_RAW = os.environ.get("API_KEY_HASH_ROUNDS")
API_KEY_HASH_ROUNDS = (
int(_API_KEY_HASH_ROUNDS_RAW) if _API_KEY_HASH_ROUNDS_RAW else None
)

View File

@@ -125,8 +125,6 @@ class DocumentSource(str, Enum):
OCI_STORAGE = "oci_storage"
XENFORO = "xenforo"
NOT_APPLICABLE = "not_applicable"
FRESHDESK = "freshdesk"
FIREFLIES = "fireflies"
DocumentSourceRequiringTenantContext: list[DocumentSource] = [DocumentSource.FILE]
@@ -226,9 +224,6 @@ class DanswerRedisLocks:
PRUNING_LOCK_PREFIX = "da_lock:pruning"
INDEXING_METADATA_PREFIX = "da_metadata:indexing"
SLACK_BOT_LOCK = "da_lock:slack_bot"
SLACK_BOT_HEARTBEAT_PREFIX = "da_heartbeat:slack_bot"
class DanswerCeleryPriority(int, Enum):
HIGHEST = 0

View File

@@ -17,7 +17,6 @@ from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import GenerateSlimDocumentOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.interfaces import SlimConnector
from danswer.connectors.models import BasicExpertInfo
from danswer.connectors.models import ConnectorMissingCredentialError
@@ -250,11 +249,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
self.cql_time_filter += f" and lastmodified <= '{formatted_end_time}'"
return self._fetch_document_batches()
def retrieve_all_slim_documents(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
def retrieve_all_slim_documents(self) -> GenerateSlimDocumentOutput:
if self.confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")

View File

@@ -23,16 +23,7 @@ def datetime_to_utc(dt: datetime) -> datetime:
def time_str_to_utc(datetime_str: str) -> datetime:
try:
dt = parse(datetime_str)
except ValueError:
# Handle malformed timezone by attempting to fix common format issues
if "0000" in datetime_str:
# Convert "0000" to "+0000" for proper timezone parsing
fixed_dt_str = datetime_str.replace(" 0000", " +0000")
dt = parse(fixed_dt_str)
else:
raise
dt = parse(datetime_str)
return datetime_to_utc(dt)

View File

@@ -16,8 +16,6 @@ from danswer.connectors.discourse.connector import DiscourseConnector
from danswer.connectors.document360.connector import Document360Connector
from danswer.connectors.dropbox.connector import DropboxConnector
from danswer.connectors.file.connector import LocalFileConnector
from danswer.connectors.fireflies.connector import FirefliesConnector
from danswer.connectors.freshdesk.connector import FreshdeskConnector
from danswer.connectors.github.connector import GithubConnector
from danswer.connectors.gitlab.connector import GitlabConnector
from danswer.connectors.gmail.connector import GmailConnector
@@ -101,8 +99,6 @@ def identify_connector_class(
DocumentSource.GOOGLE_CLOUD_STORAGE: BlobStorageConnector,
DocumentSource.OCI_STORAGE: BlobStorageConnector,
DocumentSource.XENFORO: XenforoConnector,
DocumentSource.FRESHDESK: FreshdeskConnector,
DocumentSource.FIREFLIES: FirefliesConnector,
}
connector_by_source = connector_map.get(source, {})

View File

@@ -27,8 +27,8 @@ from danswer.file_processing.extract_file_text import read_pdf_file
from danswer.file_processing.extract_file_text import read_text_file
from danswer.file_store.file_store import get_default_file_store
from danswer.utils.logger import setup_logger
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
@@ -123,13 +123,9 @@ def _process_file(
"filename",
"file_display_name",
"title",
"connector_type",
]
}
source_type_str = all_metadata.get("connector_type")
source_type = DocumentSource(source_type_str) if source_type_str else None
p_owner_names = all_metadata.get("primary_owners")
s_owner_names = all_metadata.get("secondary_owners")
p_owners = (
@@ -149,7 +145,7 @@ def _process_file(
sections=[
Section(link=all_metadata.get("link"), text=file_content_raw.strip())
],
source=source_type or DocumentSource.FILE,
source=DocumentSource.FILE,
semantic_identifier=file_display_name,
title=title,
doc_updated_at=final_time_updated,

View File

@@ -1,182 +0,0 @@
from collections.abc import Iterator
from datetime import datetime
from datetime import timezone
from typing import List
import requests
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.models import BasicExpertInfo
from danswer.connectors.models import ConnectorMissingCredentialError
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.utils.logger import setup_logger
logger = setup_logger()
_FIREFLIES_ID_PREFIX = "FIREFLIES_"
_FIREFLIES_API_URL = "https://api.fireflies.ai/graphql"
_FIREFLIES_TRANSCRIPT_QUERY_SIZE = 50 # Max page size is 50
_FIREFLIES_API_QUERY = """
query Transcripts($fromDate: DateTime, $toDate: DateTime, $limit: Int!, $skip: Int!) {
transcripts(fromDate: $fromDate, toDate: $toDate, limit: $limit, skip: $skip) {
id
title
host_email
participants
date
transcript_url
sentences {
text
speaker_name
}
}
}
"""
def _create_doc_from_transcript(transcript: dict) -> Document | None:
meeting_text = ""
sentences = transcript.get("sentences", [])
if sentences:
for sentence in sentences:
meeting_text += sentence.get("speaker_name") or "Unknown Speaker"
meeting_text += ": " + sentence.get("text", "") + "\n\n"
else:
return None
meeting_link = transcript["transcript_url"]
fireflies_id = _FIREFLIES_ID_PREFIX + transcript["id"]
meeting_title = transcript["title"] or "No Title"
meeting_date_unix = transcript["date"]
meeting_date = datetime.fromtimestamp(meeting_date_unix / 1000, tz=timezone.utc)
meeting_host_email = transcript["host_email"]
host_email_user_info = [BasicExpertInfo(email=meeting_host_email)]
meeting_participants_email_list = []
for participant in transcript.get("participants", []):
if participant != meeting_host_email and participant:
meeting_participants_email_list.append(BasicExpertInfo(email=participant))
return Document(
id=fireflies_id,
sections=[
Section(
link=meeting_link,
text=meeting_text,
)
],
source=DocumentSource.FIREFLIES,
semantic_identifier=meeting_title,
metadata={},
doc_updated_at=meeting_date,
primary_owners=host_email_user_info,
secondary_owners=meeting_participants_email_list,
)
class FirefliesConnector(PollConnector, LoadConnector):
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
self.batch_size = batch_size
def load_credentials(self, credentials: dict[str, str]) -> None:
api_key = credentials.get("fireflies_api_key")
if not isinstance(api_key, str):
raise ConnectorMissingCredentialError(
"The Fireflies API key must be a string"
)
self.api_key = api_key
return None
def _fetch_transcripts(
self, start_datetime: str | None = None, end_datetime: str | None = None
) -> Iterator[List[dict]]:
if self.api_key is None:
raise ConnectorMissingCredentialError("Missing API key")
headers = {
"Content-Type": "application/json",
"Authorization": "Bearer " + self.api_key,
}
skip = 0
variables: dict[str, int | str] = {
"limit": _FIREFLIES_TRANSCRIPT_QUERY_SIZE,
}
if start_datetime:
variables["fromDate"] = start_datetime
if end_datetime:
variables["toDate"] = end_datetime
while True:
variables["skip"] = skip
response = requests.post(
_FIREFLIES_API_URL,
headers=headers,
json={"query": _FIREFLIES_API_QUERY, "variables": variables},
)
response.raise_for_status()
if response.status_code == 204:
break
recieved_transcripts = response.json()
parsed_transcripts = recieved_transcripts.get("data", {}).get(
"transcripts", []
)
yield parsed_transcripts
if len(parsed_transcripts) < _FIREFLIES_TRANSCRIPT_QUERY_SIZE:
break
skip += _FIREFLIES_TRANSCRIPT_QUERY_SIZE
def _process_transcripts(
self, start: str | None = None, end: str | None = None
) -> GenerateDocumentsOutput:
doc_batch: List[Document] = []
for transcript_batch in self._fetch_transcripts(start, end):
for transcript in transcript_batch:
if doc := _create_doc_from_transcript(transcript):
doc_batch.append(doc)
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
if doc_batch:
yield doc_batch
def load_from_state(self) -> GenerateDocumentsOutput:
return self._process_transcripts()
def poll_source(
self, start_unixtime: SecondsSinceUnixEpoch, end_unixtime: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
start_datetime = datetime.fromtimestamp(
start_unixtime, tz=timezone.utc
).strftime("%Y-%m-%dT%H:%M:%S.000Z")
end_datetime = datetime.fromtimestamp(end_unixtime, tz=timezone.utc).strftime(
"%Y-%m-%dT%H:%M:%S.000Z"
)
yield from self._process_transcripts(start_datetime, end_datetime)

View File

@@ -1,239 +0,0 @@
import json
from collections.abc import Iterator
from datetime import datetime
from datetime import timezone
from typing import List
import requests
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.models import ConnectorMissingCredentialError
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.file_processing.html_utils import parse_html_page_basic
from danswer.utils.logger import setup_logger
logger = setup_logger()
_FRESHDESK_ID_PREFIX = "FRESHDESK_"
_TICKET_FIELDS_TO_INCLUDE = {
"fr_escalated",
"spam",
"priority",
"source",
"status",
"type",
"is_escalated",
"tags",
"nr_due_by",
"nr_escalated",
"cc_emails",
"fwd_emails",
"reply_cc_emails",
"ticket_cc_emails",
"support_email",
"to_emails",
}
_SOURCE_NUMBER_TYPE_MAP: dict[int, str] = {
1: "Email",
2: "Portal",
3: "Phone",
7: "Chat",
9: "Feedback Widget",
10: "Outbound Email",
}
_PRIORITY_NUMBER_TYPE_MAP: dict[int, str] = {
1: "low",
2: "medium",
3: "high",
4: "urgent",
}
_STATUS_NUMBER_TYPE_MAP: dict[int, str] = {
2: "open",
3: "pending",
4: "resolved",
5: "closed",
}
def _create_metadata_from_ticket(ticket: dict) -> dict:
metadata: dict[str, str | list[str]] = {}
# Combine all emails into a list so there are no repeated emails
email_data: set[str] = set()
for key, value in ticket.items():
# Skip fields that aren't useful for embedding
if key not in _TICKET_FIELDS_TO_INCLUDE:
continue
# Skip empty fields
if not value or value == "[]":
continue
# Convert strings or lists to strings
stringified_value: str | list[str]
if isinstance(value, list):
stringified_value = [str(item) for item in value]
else:
stringified_value = str(value)
if "email" in key:
if isinstance(stringified_value, list):
email_data.update(stringified_value)
else:
email_data.add(stringified_value)
else:
metadata[key] = stringified_value
if email_data:
metadata["emails"] = list(email_data)
# Convert source numbers to human-parsable string
if source_number := ticket.get("source"):
metadata["source"] = _SOURCE_NUMBER_TYPE_MAP.get(
source_number, "Unknown Source Type"
)
# Convert priority numbers to human-parsable string
if priority_number := ticket.get("priority"):
metadata["priority"] = _PRIORITY_NUMBER_TYPE_MAP.get(
priority_number, "Unknown Priority"
)
# Convert status to human-parsable string
if status_number := ticket.get("status"):
metadata["status"] = _STATUS_NUMBER_TYPE_MAP.get(
status_number, "Unknown Status"
)
due_by = datetime.fromisoformat(ticket["due_by"].replace("Z", "+00:00"))
metadata["overdue"] = str(datetime.now(timezone.utc) > due_by)
return metadata
def _create_doc_from_ticket(ticket: dict, domain: str) -> Document:
# Use the ticket description as the text
text = f"Ticket description: {parse_html_page_basic(ticket.get('description_text', ''))}"
metadata = _create_metadata_from_ticket(ticket)
# This is also used in the ID because it is more unique than the just the ticket ID
link = f"https://{domain}.freshdesk.com/helpdesk/tickets/{ticket['id']}"
return Document(
id=_FRESHDESK_ID_PREFIX + link,
sections=[
Section(
link=link,
text=text,
)
],
source=DocumentSource.FRESHDESK,
semantic_identifier=ticket["subject"],
metadata=metadata,
doc_updated_at=datetime.fromisoformat(
ticket["updated_at"].replace("Z", "+00:00")
),
)
class FreshdeskConnector(PollConnector, LoadConnector):
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
self.batch_size = batch_size
def load_credentials(self, credentials: dict[str, str | int]) -> None:
api_key = credentials.get("freshdesk_api_key")
domain = credentials.get("freshdesk_domain")
password = credentials.get("freshdesk_password")
if not all(isinstance(cred, str) for cred in [domain, api_key, password]):
raise ConnectorMissingCredentialError(
"All Freshdesk credentials must be strings"
)
self.api_key = str(api_key)
self.domain = str(domain)
self.password = str(password)
def _fetch_tickets(
self, start: datetime | None = None, end: datetime | None = None
) -> Iterator[List[dict]]:
"""
'end' is not currently used, so we may double fetch tickets created after the indexing
starts but before the actual call is made.
To use 'end' would require us to use the search endpoint but it has limitations,
namely having to fetch all IDs and then individually fetch each ticket because there is no
'include' field available for this endpoint:
https://developers.freshdesk.com/api/#filter_tickets
"""
if self.api_key is None or self.domain is None or self.password is None:
raise ConnectorMissingCredentialError("freshdesk")
base_url = f"https://{self.domain}.freshdesk.com/api/v2/tickets"
params: dict[str, int | str] = {
"include": "description",
"per_page": 50,
"page": 1,
}
if start:
params["updated_since"] = start.isoformat()
while True:
response = requests.get(
base_url, auth=(self.api_key, self.password), params=params
)
response.raise_for_status()
if response.status_code == 204:
break
tickets = json.loads(response.content)
logger.info(
f"Fetched {len(tickets)} tickets from Freshdesk API (Page {params['page']})"
)
yield tickets
if len(tickets) < int(params["per_page"]):
break
params["page"] = int(params["page"]) + 1
def _process_tickets(
self, start: datetime | None = None, end: datetime | None = None
) -> GenerateDocumentsOutput:
doc_batch: List[Document] = []
for ticket_batch in self._fetch_tickets(start, end):
for ticket in ticket_batch:
doc_batch.append(_create_doc_from_ticket(ticket, self.domain))
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
if doc_batch:
yield doc_batch
def load_from_state(self) -> GenerateDocumentsOutput:
return self._process_tickets()
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
start_datetime = datetime.fromtimestamp(start, tz=timezone.utc)
end_datetime = datetime.fromtimestamp(end, tz=timezone.utc)
yield from self._process_tickets(start_datetime, end_datetime)

View File

@@ -1,360 +1,221 @@
from base64 import urlsafe_b64decode
from typing import Any
from typing import cast
from typing import Dict
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
from googleapiclient import discovery # type: ignore
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from danswer.connectors.google_utils.google_auth import get_google_creds
from danswer.connectors.google_utils.google_utils import execute_paginated_retrieval
from danswer.connectors.google_utils.resources import get_admin_service
from danswer.connectors.google_utils.resources import get_gmail_service
from danswer.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
from danswer.connectors.gmail.connector_auth import (
get_gmail_creds_for_authorized_user,
)
from danswer.connectors.gmail.connector_auth import (
get_gmail_creds_for_service_account,
)
from danswer.connectors.gmail.constants import (
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY,
)
from danswer.connectors.gmail.constants import DB_CREDENTIALS_DICT_TOKEN_KEY
from danswer.connectors.gmail.constants import (
GMAIL_DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
)
from danswer.connectors.google_utils.shared_constants import MISSING_SCOPES_ERROR_STR
from danswer.connectors.google_utils.shared_constants import ONYX_SCOPE_INSTRUCTIONS
from danswer.connectors.google_utils.shared_constants import SLIM_BATCH_SIZE
from danswer.connectors.google_utils.shared_constants import USER_FIELDS
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import GenerateSlimDocumentOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.interfaces import SlimConnector
from danswer.connectors.models import BasicExpertInfo
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.connectors.models import SlimDocument
from danswer.utils.logger import setup_logger
from danswer.utils.retry_wrapper import retry_builder
logger = setup_logger()
# This is for the initial list call to get the thread ids
THREAD_LIST_FIELDS = "nextPageToken, threads(id)"
# These are the fields to retrieve using the ID from the initial list call
PARTS_FIELDS = "parts(body(data), mimeType)"
PAYLOAD_FIELDS = f"payload(headers, {PARTS_FIELDS})"
MESSAGES_FIELDS = f"messages(id, {PAYLOAD_FIELDS})"
THREADS_FIELDS = f"threads(id, {MESSAGES_FIELDS})"
THREAD_FIELDS = f"id, {MESSAGES_FIELDS}"
EMAIL_FIELDS = [
"cc",
"bcc",
"from",
"to",
]
add_retries = retry_builder(tries=50, max_delay=30)
def _build_time_range_query(
time_range_start: SecondsSinceUnixEpoch | None = None,
time_range_end: SecondsSinceUnixEpoch | None = None,
) -> str | None:
query = ""
if time_range_start is not None and time_range_start != 0:
query += f"after:{int(time_range_start)}"
if time_range_end is not None and time_range_end != 0:
query += f" before:{int(time_range_end)}"
query = query.strip()
if len(query) == 0:
return None
return query
def _clean_email_and_extract_name(email: str) -> tuple[str, str | None]:
email = email.strip()
if "<" in email and ">" in email:
# Handle format: "Display Name <email@domain.com>"
display_name = email[: email.find("<")].strip()
email_address = email[email.find("<") + 1 : email.find(">")].strip()
return email_address, display_name if display_name else None
else:
# Handle plain email address
return email.strip(), None
def _get_owners_from_emails(emails: dict[str, str | None]) -> list[BasicExpertInfo]:
owners = []
for email, names in emails.items():
if names:
name_parts = names.split(" ")
first_name = " ".join(name_parts[:-1])
last_name = name_parts[-1]
else:
first_name = None
last_name = None
owners.append(
BasicExpertInfo(email=email, first_name=first_name, last_name=last_name)
)
return owners
def _get_message_body(payload: dict[str, Any]) -> str:
parts = payload.get("parts", [])
message_body = ""
for part in parts:
mime_type = part.get("mimeType")
body = part.get("body")
if mime_type == "text/plain" and body:
data = body.get("data", "")
text = urlsafe_b64decode(data).decode()
message_body += text
return message_body
def message_to_section(message: Dict[str, Any]) -> tuple[Section, dict[str, str]]:
link = f"https://mail.google.com/mail/u/0/#inbox/{message['id']}"
payload = message.get("payload", {})
headers = payload.get("headers", [])
metadata: dict[str, Any] = {}
for header in headers:
name = header.get("name").lower()
value = header.get("value")
if name in EMAIL_FIELDS:
metadata[name] = value
if name == "subject":
metadata["subject"] = value
if name == "date":
metadata["updated_at"] = value
if labels := message.get("labelIds"):
metadata["labels"] = labels
message_data = ""
for name, value in metadata.items():
# updated at isnt super useful for the llm
if name != "updated_at":
message_data += f"{name}: {value}\n"
message_body_text: str = _get_message_body(payload)
return Section(link=link, text=message_body_text + message_data), metadata
def thread_to_document(full_thread: Dict[str, Any]) -> Document | None:
all_messages = full_thread.get("messages", [])
if not all_messages:
return None
sections = []
semantic_identifier = ""
updated_at = None
from_emails: dict[str, str | None] = {}
other_emails: dict[str, str | None] = {}
for message in all_messages:
section, message_metadata = message_to_section(message)
sections.append(section)
for name, value in message_metadata.items():
if name in EMAIL_FIELDS:
email, display_name = _clean_email_and_extract_name(value)
if name == "from":
from_emails[email] = (
display_name if not from_emails.get(email) else None
)
else:
other_emails[email] = (
display_name if not other_emails.get(email) else None
)
# If we haven't set the semantic identifier yet, set it to the subject of the first message
if not semantic_identifier:
semantic_identifier = message_metadata.get("subject", "")
if message_metadata.get("updated_at"):
updated_at = message_metadata.get("updated_at")
updated_at_datetime = None
if updated_at:
updated_at_datetime = time_str_to_utc(updated_at)
id = full_thread.get("id")
if not id:
raise ValueError("Thread ID is required")
primary_owners = _get_owners_from_emails(from_emails)
secondary_owners = _get_owners_from_emails(other_emails)
return Document(
id=id,
semantic_identifier=semantic_identifier,
sections=sections,
source=DocumentSource.GMAIL,
# This is used to perform permission sync
primary_owners=primary_owners,
secondary_owners=secondary_owners,
doc_updated_at=updated_at_datetime,
# Not adding emails to metadata because it's already in the sections
metadata={},
)
class GmailConnector(LoadConnector, PollConnector, SlimConnector):
class GmailConnector(LoadConnector, PollConnector):
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
self.batch_size = batch_size
self._creds: OAuthCredentials | ServiceAccountCredentials | None = None
self._primary_admin_email: str | None = None
@property
def primary_admin_email(self) -> str:
if self._primary_admin_email is None:
raise RuntimeError(
"Primary admin email missing, "
"should not call this property "
"before calling load_credentials"
)
return self._primary_admin_email
@property
def google_domain(self) -> str:
if self._primary_admin_email is None:
raise RuntimeError(
"Primary admin email missing, "
"should not call this property "
"before calling load_credentials"
)
return self._primary_admin_email.split("@")[-1]
@property
def creds(self) -> OAuthCredentials | ServiceAccountCredentials:
if self._creds is None:
raise RuntimeError(
"Creds missing, "
"should not call this property "
"before calling load_credentials"
)
return self._creds
self.creds: OAuthCredentials | ServiceAccountCredentials | None = None
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, str] | None:
primary_admin_email = credentials[DB_CREDENTIALS_PRIMARY_ADMIN_KEY]
self._primary_admin_email = primary_admin_email
"""Checks for two different types of credentials.
(1) A credential which holds a token acquired via a user going thorugh
the Google OAuth flow.
(2) A credential which holds a service account key JSON file, which
can then be used to impersonate any user in the workspace.
"""
creds: OAuthCredentials | ServiceAccountCredentials | None = None
new_creds_dict = None
if DB_CREDENTIALS_DICT_TOKEN_KEY in credentials:
access_token_json_str = cast(
str, credentials[DB_CREDENTIALS_DICT_TOKEN_KEY]
)
creds = get_gmail_creds_for_authorized_user(
token_json_str=access_token_json_str
)
self._creds, new_creds_dict = get_google_creds(
credentials=credentials,
source=DocumentSource.GMAIL,
)
# tell caller to update token stored in DB if it has changed
# (e.g. the token has been refreshed)
new_creds_json_str = creds.to_json() if creds else ""
if new_creds_json_str != access_token_json_str:
new_creds_dict = {DB_CREDENTIALS_DICT_TOKEN_KEY: new_creds_json_str}
if GMAIL_DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY in credentials:
service_account_key_json_str = credentials[
GMAIL_DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY
]
creds = get_gmail_creds_for_service_account(
service_account_key_json_str=service_account_key_json_str
)
# "Impersonate" a user if one is specified
delegated_user_email = cast(
str | None, credentials.get(DB_CREDENTIALS_DICT_DELEGATED_USER_KEY)
)
if delegated_user_email:
creds = creds.with_subject(delegated_user_email) if creds else None # type: ignore
if creds is None:
raise PermissionError(
"Unable to access Gmail - unknown credential structure."
)
self.creds = creds
return new_creds_dict
def _get_all_user_emails(self) -> list[str]:
admin_service = get_admin_service(self.creds, self.primary_admin_email)
emails = []
for user in execute_paginated_retrieval(
retrieval_function=admin_service.users().list,
list_key="users",
fields=USER_FIELDS,
domain=self.google_domain,
):
if email := user.get("primaryEmail"):
emails.append(email)
return emails
def _get_email_body(self, payload: dict[str, Any]) -> str:
parts = payload.get("parts", [])
email_body = ""
for part in parts:
mime_type = part.get("mimeType")
body = part.get("body")
if mime_type == "text/plain":
data = body.get("data", "")
text = urlsafe_b64decode(data).decode()
email_body += text
return email_body
def _fetch_threads(
def _email_to_document(self, full_email: Dict[str, Any]) -> Document:
email_id = full_email["id"]
payload = full_email["payload"]
headers = payload.get("headers")
labels = full_email.get("labelIds", [])
metadata = {}
if headers:
for header in headers:
name = header.get("name").lower()
value = header.get("value")
if name in ["from", "to", "subject", "date", "cc", "bcc"]:
metadata[name] = value
email_data = ""
for name, value in metadata.items():
email_data += f"{name}: {value}\n"
metadata["labels"] = labels
logger.debug(f"{email_data}")
email_body_text: str = self._get_email_body(payload)
date_str = metadata.get("date")
email_updated_at = time_str_to_utc(date_str) if date_str else None
link = f"https://mail.google.com/mail/u/0/#inbox/{email_id}"
return Document(
id=email_id,
sections=[Section(link=link, text=email_data + email_body_text)],
source=DocumentSource.GMAIL,
title=metadata.get("subject"),
semantic_identifier=metadata.get("subject", "Untitled Email"),
doc_updated_at=email_updated_at,
metadata=metadata,
)
@staticmethod
def _build_time_range_query(
time_range_start: SecondsSinceUnixEpoch | None = None,
time_range_end: SecondsSinceUnixEpoch | None = None,
) -> str | None:
query = ""
if time_range_start is not None and time_range_start != 0:
query += f"after:{int(time_range_start)}"
if time_range_end is not None and time_range_end != 0:
query += f" before:{int(time_range_end)}"
query = query.strip()
if len(query) == 0:
return None
return query
def _fetch_mails_from_gmail(
self,
time_range_start: SecondsSinceUnixEpoch | None = None,
time_range_end: SecondsSinceUnixEpoch | None = None,
) -> GenerateDocumentsOutput:
query = _build_time_range_query(time_range_start, time_range_end)
doc_batch = []
for user_email in self._get_all_user_emails():
gmail_service = get_gmail_service(self.creds, user_email)
for thread in execute_paginated_retrieval(
retrieval_function=gmail_service.users().threads().list,
list_key="threads",
userId=user_email,
fields=THREAD_LIST_FIELDS,
q=query,
):
full_threads = execute_paginated_retrieval(
retrieval_function=gmail_service.users().threads().get,
list_key=None,
userId=user_email,
fields=THREAD_FIELDS,
id=thread["id"],
if self.creds is None:
raise PermissionError("Not logged into Gmail")
page_token = ""
query = GmailConnector._build_time_range_query(time_range_start, time_range_end)
service = discovery.build("gmail", "v1", credentials=self.creds)
while page_token is not None:
result = (
service.users()
.messages()
.list(
userId="me",
pageToken=page_token,
q=query,
maxResults=self.batch_size,
)
# full_threads is an iterator containing a single thread
# so we need to convert it to a list and grab the first element
full_thread = list(full_threads)[0]
doc = thread_to_document(full_thread)
if doc is None:
continue
.execute()
)
page_token = result.get("nextPageToken")
messages = result.get("messages", [])
doc_batch = []
for message in messages:
message_id = message["id"]
msg = (
service.users()
.messages()
.get(userId="me", id=message_id, format="full")
.execute()
)
doc = self._email_to_document(msg)
doc_batch.append(doc)
if len(doc_batch) > self.batch_size:
yield doc_batch
doc_batch = []
if doc_batch:
yield doc_batch
def _fetch_slim_threads(
self,
time_range_start: SecondsSinceUnixEpoch | None = None,
time_range_end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
query = _build_time_range_query(time_range_start, time_range_end)
doc_batch = []
for user_email in self._get_all_user_emails():
gmail_service = get_gmail_service(self.creds, user_email)
for thread in execute_paginated_retrieval(
retrieval_function=gmail_service.users().threads().list,
list_key="threads",
userId=user_email,
fields=THREAD_LIST_FIELDS,
q=query,
):
doc_batch.append(
SlimDocument(
id=thread["id"],
perm_sync_data={"user_email": user_email},
)
)
if len(doc_batch) > SLIM_BATCH_SIZE:
yield doc_batch
doc_batch = []
if doc_batch:
yield doc_batch
if len(doc_batch) > 0:
yield doc_batch
def load_from_state(self) -> GenerateDocumentsOutput:
try:
yield from self._fetch_threads()
except Exception as e:
if MISSING_SCOPES_ERROR_STR in str(e):
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
raise e
yield from self._fetch_mails_from_gmail()
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
try:
yield from self._fetch_threads(start, end)
except Exception as e:
if MISSING_SCOPES_ERROR_STR in str(e):
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
raise e
def retrieve_all_slim_documents(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
try:
yield from self._fetch_slim_threads(start, end)
except Exception as e:
if MISSING_SCOPES_ERROR_STR in str(e):
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
raise e
yield from self._fetch_mails_from_gmail(start, end)
if __name__ == "__main__":
pass
import json
import os
service_account_json_path = os.environ.get("GOOGLE_SERVICE_ACCOUNT_KEY_JSON_PATH")
if not service_account_json_path:
raise ValueError(
"Please set GOOGLE_SERVICE_ACCOUNT_KEY_JSON_PATH environment variable"
)
with open(service_account_json_path) as f:
creds = json.load(f)
credentials_dict = {
DB_CREDENTIALS_DICT_TOKEN_KEY: json.dumps(creds),
}
delegated_user = os.environ.get("GMAIL_DELEGATED_USER")
if delegated_user:
credentials_dict[DB_CREDENTIALS_DICT_DELEGATED_USER_KEY] = delegated_user
connector = GmailConnector()
connector.load_credentials(
json.loads(credentials_dict[DB_CREDENTIALS_DICT_TOKEN_KEY])
)
document_batch_generator = connector.load_from_state()
for document_batch in document_batch_generator:
print(document_batch)
break

View File

@@ -0,0 +1,197 @@
import json
from typing import cast
from urllib.parse import parse_qs
from urllib.parse import ParseResult
from urllib.parse import urlparse
from google.auth.transport.requests import Request # type: ignore
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore
from sqlalchemy.orm import Session
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import KV_CRED_KEY
from danswer.configs.constants import KV_GMAIL_CRED_KEY
from danswer.configs.constants import KV_GMAIL_SERVICE_ACCOUNT_KEY
from danswer.connectors.gmail.constants import (
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY,
)
from danswer.connectors.gmail.constants import DB_CREDENTIALS_DICT_TOKEN_KEY
from danswer.connectors.gmail.constants import (
GMAIL_DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
)
from danswer.connectors.gmail.constants import SCOPES
from danswer.db.credentials import update_credential_json
from danswer.db.models import User
from danswer.key_value_store.factory import get_kv_store
from danswer.server.documents.models import CredentialBase
from danswer.server.documents.models import GoogleAppCredentials
from danswer.server.documents.models import GoogleServiceAccountKey
from danswer.utils.logger import setup_logger
logger = setup_logger()
def _build_frontend_gmail_redirect() -> str:
return f"{WEB_DOMAIN}/admin/connectors/gmail/auth/callback"
def get_gmail_creds_for_authorized_user(
token_json_str: str,
) -> OAuthCredentials | None:
creds_json = json.loads(token_json_str)
creds = OAuthCredentials.from_authorized_user_info(creds_json, SCOPES)
if creds.valid:
return creds
if creds.expired and creds.refresh_token:
try:
creds.refresh(Request())
if creds.valid:
logger.notice("Refreshed Gmail tokens.")
return creds
except Exception as e:
logger.exception(f"Failed to refresh gmail access token due to: {e}")
return None
return None
def get_gmail_creds_for_service_account(
service_account_key_json_str: str,
) -> ServiceAccountCredentials | None:
service_account_key = json.loads(service_account_key_json_str)
creds = ServiceAccountCredentials.from_service_account_info(
service_account_key, scopes=SCOPES
)
if not creds.valid or not creds.expired:
creds.refresh(Request())
return creds if creds.valid else None
def verify_csrf(credential_id: int, state: str) -> None:
csrf = get_kv_store().load(KV_CRED_KEY.format(str(credential_id)))
if csrf != state:
raise PermissionError(
"State from Gmail Connector callback does not match expected"
)
def get_gmail_auth_url(credential_id: int) -> str:
creds_str = str(get_kv_store().load(KV_GMAIL_CRED_KEY))
credential_json = json.loads(creds_str)
flow = InstalledAppFlow.from_client_config(
credential_json,
scopes=SCOPES,
redirect_uri=_build_frontend_gmail_redirect(),
)
auth_url, _ = flow.authorization_url(prompt="consent")
parsed_url = cast(ParseResult, urlparse(auth_url))
params = parse_qs(parsed_url.query)
get_kv_store().store(
KV_CRED_KEY.format(credential_id), params.get("state", [None])[0], encrypt=True
) # type: ignore
return str(auth_url)
def get_auth_url(credential_id: int) -> str:
creds_str = str(get_kv_store().load(KV_GMAIL_CRED_KEY))
credential_json = json.loads(creds_str)
flow = InstalledAppFlow.from_client_config(
credential_json,
scopes=SCOPES,
redirect_uri=_build_frontend_gmail_redirect(),
)
auth_url, _ = flow.authorization_url(prompt="consent")
parsed_url = cast(ParseResult, urlparse(auth_url))
params = parse_qs(parsed_url.query)
get_kv_store().store(
KV_CRED_KEY.format(credential_id), params.get("state", [None])[0], encrypt=True
) # type: ignore
return str(auth_url)
def update_gmail_credential_access_tokens(
auth_code: str,
credential_id: int,
user: User,
db_session: Session,
) -> OAuthCredentials | None:
app_credentials = get_google_app_gmail_cred()
flow = InstalledAppFlow.from_client_config(
app_credentials.model_dump(),
scopes=SCOPES,
redirect_uri=_build_frontend_gmail_redirect(),
)
flow.fetch_token(code=auth_code)
creds = flow.credentials
token_json_str = creds.to_json()
new_creds_dict = {DB_CREDENTIALS_DICT_TOKEN_KEY: token_json_str}
if not update_credential_json(credential_id, new_creds_dict, user, db_session):
return None
return creds
def build_service_account_creds(
delegated_user_email: str | None = None,
) -> CredentialBase:
service_account_key = get_gmail_service_account_key()
credential_dict = {
GMAIL_DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY: service_account_key.json(),
}
if delegated_user_email:
credential_dict[DB_CREDENTIALS_DICT_DELEGATED_USER_KEY] = delegated_user_email
return CredentialBase(
source=DocumentSource.GMAIL,
credential_json=credential_dict,
admin_public=True,
)
def get_google_app_gmail_cred() -> GoogleAppCredentials:
creds_str = str(get_kv_store().load(KV_GMAIL_CRED_KEY))
return GoogleAppCredentials(**json.loads(creds_str))
def upsert_google_app_gmail_cred(app_credentials: GoogleAppCredentials) -> None:
get_kv_store().store(KV_GMAIL_CRED_KEY, app_credentials.json(), encrypt=True)
def delete_google_app_gmail_cred() -> None:
get_kv_store().delete(KV_GMAIL_CRED_KEY)
def get_gmail_service_account_key() -> GoogleServiceAccountKey:
creds_str = str(get_kv_store().load(KV_GMAIL_SERVICE_ACCOUNT_KEY))
return GoogleServiceAccountKey(**json.loads(creds_str))
def upsert_gmail_service_account_key(
service_account_key: GoogleServiceAccountKey,
) -> None:
get_kv_store().store(
KV_GMAIL_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True
)
def upsert_service_account_key(service_account_key: GoogleServiceAccountKey) -> None:
get_kv_store().store(
KV_GMAIL_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True
)
def delete_gmail_service_account_key() -> None:
get_kv_store().delete(KV_GMAIL_SERVICE_ACCOUNT_KEY)
def delete_service_account_key() -> None:
get_kv_store().delete(KV_GMAIL_SERVICE_ACCOUNT_KEY)

View File

@@ -0,0 +1,4 @@
DB_CREDENTIALS_DICT_TOKEN_KEY = "gmail_tokens"
GMAIL_DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "gmail_service_account_key"
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY = "gmail_delegated_user"
SCOPES = ["https://www.googleapis.com/auth/gmail.readonly"]

View File

@@ -1,400 +1,556 @@
from collections.abc import Callable
import io
from collections.abc import Iterator
from concurrent.futures import as_completed
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from collections.abc import Sequence
from datetime import datetime
from datetime import timezone
from enum import Enum
from itertools import chain
from typing import Any
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
from googleapiclient import discovery # type: ignore
from googleapiclient.errors import HttpError # type: ignore
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
from danswer.configs.app_configs import GOOGLE_DRIVE_FOLLOW_SHORTCUTS
from danswer.configs.app_configs import GOOGLE_DRIVE_INCLUDE_SHARED
from danswer.configs.app_configs import GOOGLE_DRIVE_ONLY_ORG_PUBLIC
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.google_drive.doc_conversion import build_slim_document
from danswer.connectors.google_drive.doc_conversion import (
convert_drive_item_to_document,
from danswer.configs.constants import IGNORE_FOR_QA
from danswer.connectors.google_drive.connector_auth import get_google_drive_creds
from danswer.connectors.google_drive.constants import (
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY,
)
from danswer.connectors.google_drive.file_retrieval import crawl_folders_for_files
from danswer.connectors.google_drive.file_retrieval import get_all_files_in_my_drive
from danswer.connectors.google_drive.file_retrieval import get_files_in_shared_drive
from danswer.connectors.google_drive.models import GoogleDriveFileType
from danswer.connectors.google_utils.google_auth import get_google_creds
from danswer.connectors.google_utils.google_utils import execute_paginated_retrieval
from danswer.connectors.google_utils.resources import get_admin_service
from danswer.connectors.google_utils.resources import get_drive_service
from danswer.connectors.google_utils.resources import get_google_docs_service
from danswer.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
from danswer.connectors.google_drive.constants import (
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
)
from danswer.connectors.google_utils.shared_constants import MISSING_SCOPES_ERROR_STR
from danswer.connectors.google_utils.shared_constants import ONYX_SCOPE_INSTRUCTIONS
from danswer.connectors.google_utils.shared_constants import SCOPE_DOC_URL
from danswer.connectors.google_utils.shared_constants import SLIM_BATCH_SIZE
from danswer.connectors.google_utils.shared_constants import USER_FIELDS
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import GenerateSlimDocumentOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.interfaces import SlimConnector
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.file_processing.extract_file_text import docx_to_text
from danswer.file_processing.extract_file_text import pptx_to_text
from danswer.file_processing.extract_file_text import read_pdf_file
from danswer.file_processing.unstructured import get_unstructured_api_key
from danswer.file_processing.unstructured import unstructured_to_text
from danswer.utils.batching import batch_generator
from danswer.utils.logger import setup_logger
from danswer.utils.retry_wrapper import retry_builder
logger = setup_logger()
# TODO: Improve this by using the batch utility: https://googleapis.github.io/google-api-python-client/docs/batch.html
# All file retrievals could be batched and made at once
DRIVE_FOLDER_TYPE = "application/vnd.google-apps.folder"
DRIVE_SHORTCUT_TYPE = "application/vnd.google-apps.shortcut"
UNSUPPORTED_FILE_TYPE_CONTENT = "" # keep empty for now
def _extract_str_list_from_comma_str(string: str | None) -> list[str]:
if not string:
return []
return [s.strip() for s in string.split(",") if s.strip()]
class GDriveMimeType(str, Enum):
DOC = "application/vnd.google-apps.document"
SPREADSHEET = "application/vnd.google-apps.spreadsheet"
PDF = "application/pdf"
WORD_DOC = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
PPT = "application/vnd.google-apps.presentation"
POWERPOINT = (
"application/vnd.openxmlformats-officedocument.presentationml.presentation"
)
PLAIN_TEXT = "text/plain"
MARKDOWN = "text/markdown"
def _extract_ids_from_urls(urls: list[str]) -> list[str]:
return [url.split("/")[-1] for url in urls]
GoogleDriveFileType = dict[str, Any]
# Google Drive APIs are quite flakey and may 500 for an
# extended period of time. Trying to combat here by adding a very
# long retry period (~20 minutes of trying every minute)
add_retries = retry_builder(tries=50, max_delay=30)
def _convert_single_file(
creds: Any, primary_admin_email: str, file: dict[str, Any]
) -> Any:
user_email = file.get("owners", [{}])[0].get("emailAddress") or primary_admin_email
user_drive_service = get_drive_service(creds, user_email=user_email)
docs_service = get_google_docs_service(creds, user_email=user_email)
return convert_drive_item_to_document(
file=file,
drive_service=user_drive_service,
docs_service=docs_service,
def _run_drive_file_query(
service: discovery.Resource,
query: str,
continue_on_failure: bool,
include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED,
follow_shortcuts: bool = GOOGLE_DRIVE_FOLLOW_SHORTCUTS,
batch_size: int = INDEX_BATCH_SIZE,
) -> Iterator[GoogleDriveFileType]:
next_page_token = ""
while next_page_token is not None:
logger.debug(f"Running Google Drive fetch with query: {query}")
results = add_retries(
lambda: (
service.files()
.list(
corpora="allDrives"
if include_shared
else "user", # needed to search through shared drives
pageSize=batch_size,
supportsAllDrives=include_shared,
includeItemsFromAllDrives=include_shared,
fields=(
"nextPageToken, files(mimeType, id, name, permissions, "
"modifiedTime, webViewLink, shortcutDetails)"
),
pageToken=next_page_token,
q=query,
)
.execute()
)
)()
next_page_token = results.get("nextPageToken")
files = results["files"]
for file in files:
if follow_shortcuts and "shortcutDetails" in file:
try:
file_shortcut_points_to = add_retries(
lambda: (
service.files()
.get(
fileId=file["shortcutDetails"]["targetId"],
supportsAllDrives=include_shared,
fields="mimeType, id, name, modifiedTime, webViewLink, permissions, shortcutDetails",
)
.execute()
)
)()
yield file_shortcut_points_to
except HttpError:
logger.error(
f"Failed to follow shortcut with details: {file['shortcutDetails']}"
)
if continue_on_failure:
continue
raise
else:
yield file
def _get_folder_id(
service: discovery.Resource,
parent_id: str,
folder_name: str,
include_shared: bool,
follow_shortcuts: bool,
) -> str | None:
"""
Get the ID of a folder given its name and the ID of its parent folder.
"""
query = f"'{parent_id}' in parents and name='{folder_name}' and "
if follow_shortcuts:
query += f"(mimeType='{DRIVE_FOLDER_TYPE}' or mimeType='{DRIVE_SHORTCUT_TYPE}')"
else:
query += f"mimeType='{DRIVE_FOLDER_TYPE}'"
# TODO: support specifying folder path in shared drive rather than just `My Drive`
results = add_retries(
lambda: (
service.files()
.list(
q=query,
spaces="drive",
fields="nextPageToken, files(id, name, shortcutDetails)",
supportsAllDrives=include_shared,
includeItemsFromAllDrives=include_shared,
)
.execute()
)
)()
items = results.get("files", [])
folder_id = None
if items:
if follow_shortcuts and "shortcutDetails" in items[0]:
folder_id = items[0]["shortcutDetails"]["targetId"]
else:
folder_id = items[0]["id"]
return folder_id
def _get_folders(
service: discovery.Resource,
continue_on_failure: bool,
folder_id: str | None = None, # if specified, only fetches files within this folder
include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED,
follow_shortcuts: bool = GOOGLE_DRIVE_FOLLOW_SHORTCUTS,
batch_size: int = INDEX_BATCH_SIZE,
) -> Iterator[GoogleDriveFileType]:
query = f"mimeType = '{DRIVE_FOLDER_TYPE}' "
if follow_shortcuts:
query = "(" + query + f" or mimeType = '{DRIVE_SHORTCUT_TYPE}'" + ") "
if folder_id:
query += f"and '{folder_id}' in parents "
query = query.rstrip() # remove the trailing space(s)
for file in _run_drive_file_query(
service=service,
query=query,
continue_on_failure=continue_on_failure,
include_shared=include_shared,
follow_shortcuts=follow_shortcuts,
batch_size=batch_size,
):
# Need to check this since file may have been a target of a shortcut
# and not necessarily a folder
if file["mimeType"] == DRIVE_FOLDER_TYPE:
yield file
else:
pass
def _get_files(
service: discovery.Resource,
continue_on_failure: bool,
time_range_start: SecondsSinceUnixEpoch | None = None,
time_range_end: SecondsSinceUnixEpoch | None = None,
folder_id: str | None = None, # if specified, only fetches files within this folder
include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED,
follow_shortcuts: bool = GOOGLE_DRIVE_FOLLOW_SHORTCUTS,
batch_size: int = INDEX_BATCH_SIZE,
) -> Iterator[GoogleDriveFileType]:
query = f"mimeType != '{DRIVE_FOLDER_TYPE}' "
if time_range_start is not None:
time_start = datetime.utcfromtimestamp(time_range_start).isoformat() + "Z"
query += f"and modifiedTime >= '{time_start}' "
if time_range_end is not None:
time_stop = datetime.utcfromtimestamp(time_range_end).isoformat() + "Z"
query += f"and modifiedTime <= '{time_stop}' "
if folder_id:
query += f"and '{folder_id}' in parents "
query = query.rstrip() # remove the trailing space(s)
files = _run_drive_file_query(
service=service,
query=query,
continue_on_failure=continue_on_failure,
include_shared=include_shared,
follow_shortcuts=follow_shortcuts,
batch_size=batch_size,
)
def _process_files_batch(
files: list[GoogleDriveFileType], convert_func: Callable, batch_size: int
) -> GenerateDocumentsOutput:
doc_batch = []
with ThreadPoolExecutor(max_workers=min(16, len(files))) as executor:
for doc in executor.map(convert_func, files):
if doc:
doc_batch.append(doc)
if len(doc_batch) >= batch_size:
yield doc_batch
doc_batch = []
if doc_batch:
yield doc_batch
return files
class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
def get_all_files_batched(
service: discovery.Resource,
continue_on_failure: bool,
include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED,
follow_shortcuts: bool = GOOGLE_DRIVE_FOLLOW_SHORTCUTS,
batch_size: int = INDEX_BATCH_SIZE,
time_range_start: SecondsSinceUnixEpoch | None = None,
time_range_end: SecondsSinceUnixEpoch | None = None,
folder_id: str | None = None, # if specified, only fetches files within this folder
# if True, will fetch files in sub-folders of the specified folder ID.
# Only applies if folder_id is specified.
traverse_subfolders: bool = True,
folder_ids_traversed: list[str] | None = None,
) -> Iterator[list[GoogleDriveFileType]]:
"""Gets all files matching the criteria specified by the args from Google Drive
in batches of size `batch_size`.
"""
found_files = _get_files(
service=service,
continue_on_failure=continue_on_failure,
time_range_start=time_range_start,
time_range_end=time_range_end,
folder_id=folder_id,
include_shared=include_shared,
follow_shortcuts=follow_shortcuts,
batch_size=batch_size,
)
yield from batch_generator(
items=found_files,
batch_size=batch_size,
pre_batch_yield=lambda batch_files: logger.debug(
f"Parseable Documents in batch: {[file['name'] for file in batch_files]}"
),
)
if traverse_subfolders and folder_id is not None:
folder_ids_traversed = folder_ids_traversed or []
subfolders = _get_folders(
service=service,
folder_id=folder_id,
continue_on_failure=continue_on_failure,
include_shared=include_shared,
follow_shortcuts=follow_shortcuts,
batch_size=batch_size,
)
for subfolder in subfolders:
if subfolder["id"] not in folder_ids_traversed:
logger.info("Fetching all files in subfolder: " + subfolder["name"])
folder_ids_traversed.append(subfolder["id"])
yield from get_all_files_batched(
service=service,
continue_on_failure=continue_on_failure,
include_shared=include_shared,
follow_shortcuts=follow_shortcuts,
batch_size=batch_size,
time_range_start=time_range_start,
time_range_end=time_range_end,
folder_id=subfolder["id"],
traverse_subfolders=traverse_subfolders,
folder_ids_traversed=folder_ids_traversed,
)
else:
logger.debug(
"Skipping subfolder since already traversed: " + subfolder["name"]
)
def extract_text(file: dict[str, str], service: discovery.Resource) -> str:
mime_type = file["mimeType"]
if mime_type not in set(item.value for item in GDriveMimeType):
# Unsupported file types can still have a title, finding this way is still useful
return UNSUPPORTED_FILE_TYPE_CONTENT
if mime_type in [
GDriveMimeType.DOC.value,
GDriveMimeType.PPT.value,
GDriveMimeType.SPREADSHEET.value,
]:
export_mime_type = (
"text/plain"
if mime_type != GDriveMimeType.SPREADSHEET.value
else "text/csv"
)
return (
service.files()
.export(fileId=file["id"], mimeType=export_mime_type)
.execute()
.decode("utf-8")
)
elif mime_type in [
GDriveMimeType.PLAIN_TEXT.value,
GDriveMimeType.MARKDOWN.value,
]:
return service.files().get_media(fileId=file["id"]).execute().decode("utf-8")
if mime_type in [
GDriveMimeType.WORD_DOC.value,
GDriveMimeType.POWERPOINT.value,
GDriveMimeType.PDF.value,
]:
response = service.files().get_media(fileId=file["id"]).execute()
if get_unstructured_api_key():
return unstructured_to_text(
file=io.BytesIO(response), file_name=file.get("name", file["id"])
)
if mime_type == GDriveMimeType.WORD_DOC.value:
return docx_to_text(file=io.BytesIO(response))
elif mime_type == GDriveMimeType.PDF.value:
text, _ = read_pdf_file(file=io.BytesIO(response))
return text
elif mime_type == GDriveMimeType.POWERPOINT.value:
return pptx_to_text(file=io.BytesIO(response))
return UNSUPPORTED_FILE_TYPE_CONTENT
class GoogleDriveConnector(LoadConnector, PollConnector):
def __init__(
self,
include_shared_drives: bool = True,
shared_drive_urls: str | None = None,
include_my_drives: bool = True,
my_drive_emails: str | None = None,
shared_folder_urls: str | None = None,
batch_size: int = INDEX_BATCH_SIZE,
# OLD PARAMETERS
# optional list of folder paths e.g. "[My Folder/My Subfolder]"
# if specified, will only index files in these folders
folder_paths: list[str] | None = None,
include_shared: bool | None = None,
follow_shortcuts: bool | None = None,
only_org_public: bool | None = None,
continue_on_failure: bool | None = None,
batch_size: int = INDEX_BATCH_SIZE,
include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED,
follow_shortcuts: bool = GOOGLE_DRIVE_FOLLOW_SHORTCUTS,
only_org_public: bool = GOOGLE_DRIVE_ONLY_ORG_PUBLIC,
continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE,
) -> None:
# Check for old input parameters
if (
folder_paths is not None
or include_shared is not None
or follow_shortcuts is not None
or only_org_public is not None
or continue_on_failure is not None
):
logger.exception(
"Google Drive connector received old input parameters. "
"Please visit the docs for help with the new setup: "
f"{SCOPE_DOC_URL}"
)
raise ValueError(
"Google Drive connector received old input parameters. "
"Please visit the docs for help with the new setup: "
f"{SCOPE_DOC_URL}"
)
if (
not include_shared_drives
and not include_my_drives
and not shared_folder_urls
):
raise ValueError(
"At least one of include_shared_drives, include_my_drives,"
" or shared_folder_urls must be true"
)
self.folder_paths = folder_paths or []
self.batch_size = batch_size
self.include_shared = include_shared
self.follow_shortcuts = follow_shortcuts
self.only_org_public = only_org_public
self.continue_on_failure = continue_on_failure
self.creds: OAuthCredentials | ServiceAccountCredentials | None = None
self.include_shared_drives = include_shared_drives
shared_drive_url_list = _extract_str_list_from_comma_str(shared_drive_urls)
self._requested_shared_drive_ids = set(
_extract_ids_from_urls(shared_drive_url_list)
)
@staticmethod
def _process_folder_paths(
service: discovery.Resource,
folder_paths: list[str],
include_shared: bool,
follow_shortcuts: bool,
) -> list[str]:
"""['Folder/Sub Folder'] -> ['<FOLDER_ID>']"""
folder_ids: list[str] = []
for path in folder_paths:
folder_names = path.split("/")
parent_id = "root"
for folder_name in folder_names:
found_parent_id = _get_folder_id(
service=service,
parent_id=parent_id,
folder_name=folder_name,
include_shared=include_shared,
follow_shortcuts=follow_shortcuts,
)
if found_parent_id is None:
raise ValueError(
(
f"Folder '{folder_name}' in path '{path}' "
"not found in Google Drive"
)
)
parent_id = found_parent_id
folder_ids.append(parent_id)
self.include_my_drives = include_my_drives
self._requested_my_drive_emails = set(
_extract_str_list_from_comma_str(my_drive_emails)
)
shared_folder_url_list = _extract_str_list_from_comma_str(shared_folder_urls)
self._requested_folder_ids = set(_extract_ids_from_urls(shared_folder_url_list))
self._primary_admin_email: str | None = None
self._creds: OAuthCredentials | ServiceAccountCredentials | None = None
self._retrieved_ids: set[str] = set()
@property
def primary_admin_email(self) -> str:
if self._primary_admin_email is None:
raise RuntimeError(
"Primary admin email missing, "
"should not call this property "
"before calling load_credentials"
)
return self._primary_admin_email
@property
def google_domain(self) -> str:
if self._primary_admin_email is None:
raise RuntimeError(
"Primary admin email missing, "
"should not call this property "
"before calling load_credentials"
)
return self._primary_admin_email.split("@")[-1]
@property
def creds(self) -> OAuthCredentials | ServiceAccountCredentials:
if self._creds is None:
raise RuntimeError(
"Creds missing, "
"should not call this property "
"before calling load_credentials"
)
return self._creds
return folder_ids
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, str] | None:
primary_admin_email = credentials[DB_CREDENTIALS_PRIMARY_ADMIN_KEY]
self._primary_admin_email = primary_admin_email
self._creds, new_creds_dict = get_google_creds(
credentials=credentials,
source=DocumentSource.GOOGLE_DRIVE,
)
"""Checks for two different types of credentials.
(1) A credential which holds a token acquired via a user going thorough
the Google OAuth flow.
(2) A credential which holds a service account key JSON file, which
can then be used to impersonate any user in the workspace.
"""
creds, new_creds_dict = get_google_drive_creds(credentials)
self.creds = creds
return new_creds_dict
def _update_traversed_parent_ids(self, folder_id: str) -> None:
self._retrieved_ids.add(folder_id)
def _get_all_user_emails(self, admins_only: bool) -> list[str]:
admin_service = get_admin_service(
creds=self.creds,
user_email=self.primary_admin_email,
)
query = "isAdmin=true" if admins_only else "isAdmin=false"
emails = []
for user in execute_paginated_retrieval(
retrieval_function=admin_service.users().list,
list_key="users",
fields=USER_FIELDS,
domain=self.google_domain,
query=query,
):
if email := user.get("primaryEmail"):
emails.append(email)
return emails
def _get_all_drive_ids(self) -> set[str]:
primary_drive_service = get_drive_service(
creds=self.creds,
user_email=self.primary_admin_email,
)
all_drive_ids = set()
for drive in execute_paginated_retrieval(
retrieval_function=primary_drive_service.drives().list,
list_key="drives",
useDomainAdminAccess=True,
fields="drives(id)",
):
all_drive_ids.add(drive["id"])
return all_drive_ids
def _initialize_all_class_variables(self) -> None:
# Get all user emails
# Get admins first becuase they are more likely to have access to the most files
user_emails = [self.primary_admin_email]
for admins_only in [True, False]:
for email in self._get_all_user_emails(admins_only=admins_only):
if email not in user_emails:
user_emails.append(email)
self._all_org_emails = user_emails
self._all_drive_ids: set[str] = self._get_all_drive_ids()
# remove drive ids from the folder ids because they are queried differently
self._requested_folder_ids -= self._all_drive_ids
# Remove drive_ids that are not in the all_drive_ids and check them as folders instead
invalid_drive_ids = self._requested_shared_drive_ids - self._all_drive_ids
if invalid_drive_ids:
logger.warning(
f"Some shared drive IDs were not found. IDs: {invalid_drive_ids}"
)
logger.warning("Checking for folder access instead...")
self._requested_folder_ids.update(invalid_drive_ids)
if not self.include_shared_drives:
self._requested_shared_drive_ids = set()
elif not self._requested_shared_drive_ids:
self._requested_shared_drive_ids = self._all_drive_ids
def _impersonate_user_for_retrieval(
self,
user_email: str,
is_slim: bool,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
drive_service = get_drive_service(self.creds, user_email)
if self.include_my_drives and (
not self._requested_my_drive_emails
or user_email in self._requested_my_drive_emails
):
yield from get_all_files_in_my_drive(
service=drive_service,
update_traversed_ids_func=self._update_traversed_parent_ids,
is_slim=is_slim,
start=start,
end=end,
)
remaining_drive_ids = self._requested_shared_drive_ids - self._retrieved_ids
for drive_id in remaining_drive_ids:
yield from get_files_in_shared_drive(
service=drive_service,
drive_id=drive_id,
is_slim=is_slim,
update_traversed_ids_func=self._update_traversed_parent_ids,
start=start,
end=end,
)
remaining_folders = self._requested_folder_ids - self._retrieved_ids
for folder_id in remaining_folders:
yield from crawl_folders_for_files(
service=drive_service,
parent_id=folder_id,
traversed_parent_ids=self._retrieved_ids,
update_traversed_ids_func=self._update_traversed_parent_ids,
start=start,
end=end,
)
def _fetch_drive_items(
self,
is_slim: bool,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
self._initialize_all_class_variables()
# Process users in parallel using ThreadPoolExecutor
with ThreadPoolExecutor(max_workers=10) as executor:
future_to_email = {
executor.submit(
self._impersonate_user_for_retrieval, email, is_slim, start, end
): email
for email in self._all_org_emails
}
# Yield results as they complete
for future in as_completed(future_to_email):
yield from future.result()
remaining_folders = self._requested_folder_ids - self._retrieved_ids
if remaining_folders:
logger.warning(
f"Some folders/drives were not retrieved. IDs: {remaining_folders}"
)
def _extract_docs_from_google_drive(
def _fetch_docs_from_drive(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateDocumentsOutput:
# Create a larger process pool for file conversion
convert_func = partial(
_convert_single_file, self.creds, self.primary_admin_email
if self.creds is None:
raise PermissionError("Not logged into Google Drive")
service = discovery.build("drive", "v3", credentials=self.creds)
folder_ids: Sequence[str | None] = self._process_folder_paths(
service, self.folder_paths, self.include_shared, self.follow_shortcuts
)
if not folder_ids:
folder_ids = [None]
# Process files in larger batches
LARGE_BATCH_SIZE = self.batch_size * 4
files_to_process = []
# Gather the files into batches to be processed in parallel
for file in self._fetch_drive_items(is_slim=False, start=start, end=end):
files_to_process.append(file)
if len(files_to_process) >= LARGE_BATCH_SIZE:
yield from _process_files_batch(
files_to_process, convert_func, self.batch_size
file_batches = chain(
*[
get_all_files_batched(
service=service,
continue_on_failure=self.continue_on_failure,
include_shared=self.include_shared,
follow_shortcuts=self.follow_shortcuts,
batch_size=self.batch_size,
time_range_start=start,
time_range_end=end,
folder_id=folder_id,
traverse_subfolders=True,
)
files_to_process = []
for folder_id in folder_ids
]
)
for files_batch in file_batches:
doc_batch = []
for file in files_batch:
try:
# Skip files that are shortcuts
if file.get("mimeType") == DRIVE_SHORTCUT_TYPE:
logger.info("Ignoring Drive Shortcut Filetype")
continue
# Process any remaining files
if files_to_process:
yield from _process_files_batch(
files_to_process, convert_func, self.batch_size
)
if self.only_org_public:
if "permissions" not in file:
continue
if not any(
permission["type"] == "domain"
for permission in file["permissions"]
):
continue
try:
text_contents = extract_text(file, service) or ""
except HttpError as e:
reason = (
e.error_details[0]["reason"]
if e.error_details
else e.reason
)
message = (
e.error_details[0]["message"]
if e.error_details
else e.reason
)
# these errors don't represent a failure in the connector, but simply files
# that can't / shouldn't be indexed
ERRORS_TO_CONTINUE_ON = [
"cannotExportFile",
"exportSizeLimitExceeded",
"cannotDownloadFile",
]
if e.status_code == 403 and reason in ERRORS_TO_CONTINUE_ON:
logger.warning(
f"Could not export file '{file['name']}' due to '{message}', skipping..."
)
continue
raise
doc_batch.append(
Document(
id=file["webViewLink"],
sections=[
Section(link=file["webViewLink"], text=text_contents)
],
source=DocumentSource.GOOGLE_DRIVE,
semantic_identifier=file["name"],
doc_updated_at=datetime.fromisoformat(
file["modifiedTime"]
).astimezone(timezone.utc),
metadata={} if text_contents else {IGNORE_FOR_QA: "True"},
additional_info=file.get("id"),
)
)
except Exception as e:
if not self.continue_on_failure:
raise e
logger.exception(
"Ran into exception when pulling a file from Google Drive"
)
yield doc_batch
def load_from_state(self) -> GenerateDocumentsOutput:
try:
yield from self._extract_docs_from_google_drive()
except Exception as e:
if MISSING_SCOPES_ERROR_STR in str(e):
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
raise e
yield from self._fetch_docs_from_drive()
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
try:
yield from self._extract_docs_from_google_drive(start, end)
except Exception as e:
if MISSING_SCOPES_ERROR_STR in str(e):
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
raise e
# need to subtract 10 minutes from start time to account for modifiedTime
# propogation if a document is modified, it takes some time for the API to
# reflect these changes if we do not have an offset, then we may "miss" the
# update when polling
yield from self._fetch_docs_from_drive(start, end)
def _extract_slim_docs_from_google_drive(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
slim_batch = []
for file in self._fetch_drive_items(
is_slim=True,
start=start,
end=end,
):
if doc := build_slim_document(file):
slim_batch.append(doc)
if len(slim_batch) >= SLIM_BATCH_SIZE:
yield slim_batch
slim_batch = []
yield slim_batch
def retrieve_all_slim_documents(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
try:
yield from self._extract_slim_docs_from_google_drive(start, end)
except Exception as e:
if MISSING_SCOPES_ERROR_STR in str(e):
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
raise e
if __name__ == "__main__":
import json
import os
service_account_json_path = os.environ.get("GOOGLE_SERVICE_ACCOUNT_KEY_JSON_PATH")
if not service_account_json_path:
raise ValueError(
"Please set GOOGLE_SERVICE_ACCOUNT_KEY_JSON_PATH environment variable"
)
with open(service_account_json_path) as f:
creds = json.load(f)
credentials_dict = {
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY: json.dumps(creds),
}
delegated_user = os.environ.get("GOOGLE_DRIVE_DELEGATED_USER")
if delegated_user:
credentials_dict[DB_CREDENTIALS_DICT_DELEGATED_USER_KEY] = delegated_user
connector = GoogleDriveConnector(include_shared=True, follow_shortcuts=True)
connector.load_credentials(credentials_dict)
document_batch_generator = connector.load_from_state()
for document_batch in document_batch_generator:
print(document_batch)
break

View File

@@ -0,0 +1,229 @@
import json
from typing import cast
from urllib.parse import parse_qs
from urllib.parse import ParseResult
from urllib.parse import urlparse
from google.auth.transport.requests import Request # type: ignore
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore
from sqlalchemy.orm import Session
from danswer.configs.app_configs import ENTERPRISE_EDITION_ENABLED
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import KV_CRED_KEY
from danswer.configs.constants import KV_GOOGLE_DRIVE_CRED_KEY
from danswer.configs.constants import KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY
from danswer.connectors.google_drive.constants import BASE_SCOPES
from danswer.connectors.google_drive.constants import (
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY,
)
from danswer.connectors.google_drive.constants import (
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
)
from danswer.connectors.google_drive.constants import DB_CREDENTIALS_DICT_TOKEN_KEY
from danswer.connectors.google_drive.constants import FETCH_GROUPS_SCOPES
from danswer.connectors.google_drive.constants import FETCH_PERMISSIONS_SCOPES
from danswer.db.credentials import update_credential_json
from danswer.db.models import User
from danswer.key_value_store.factory import get_kv_store
from danswer.server.documents.models import CredentialBase
from danswer.server.documents.models import GoogleAppCredentials
from danswer.server.documents.models import GoogleServiceAccountKey
from danswer.utils.logger import setup_logger
logger = setup_logger()
def build_gdrive_scopes() -> list[str]:
base_scopes: list[str] = BASE_SCOPES
permissions_scopes: list[str] = FETCH_PERMISSIONS_SCOPES
groups_scopes: list[str] = FETCH_GROUPS_SCOPES
if ENTERPRISE_EDITION_ENABLED:
return base_scopes + permissions_scopes + groups_scopes
return base_scopes + permissions_scopes
def _build_frontend_google_drive_redirect() -> str:
return f"{WEB_DOMAIN}/admin/connectors/google-drive/auth/callback"
def get_google_drive_creds_for_authorized_user(
token_json_str: str, scopes: list[str] = build_gdrive_scopes()
) -> OAuthCredentials | None:
creds_json = json.loads(token_json_str)
creds = OAuthCredentials.from_authorized_user_info(creds_json, scopes)
if creds.valid:
return creds
if creds.expired and creds.refresh_token:
try:
creds.refresh(Request())
if creds.valid:
logger.notice("Refreshed Google Drive tokens.")
return creds
except Exception as e:
logger.exception(f"Failed to refresh google drive access token due to: {e}")
return None
return None
def _get_google_drive_creds_for_service_account(
service_account_key_json_str: str, scopes: list[str] = build_gdrive_scopes()
) -> ServiceAccountCredentials | None:
service_account_key = json.loads(service_account_key_json_str)
creds = ServiceAccountCredentials.from_service_account_info(
service_account_key, scopes=scopes
)
if not creds.valid or not creds.expired:
creds.refresh(Request())
return creds if creds.valid else None
def get_google_drive_creds(
credentials: dict[str, str], scopes: list[str] = build_gdrive_scopes()
) -> tuple[ServiceAccountCredentials | OAuthCredentials, dict[str, str] | None]:
oauth_creds = None
service_creds = None
new_creds_dict = None
if DB_CREDENTIALS_DICT_TOKEN_KEY in credentials:
access_token_json_str = cast(str, credentials[DB_CREDENTIALS_DICT_TOKEN_KEY])
oauth_creds = get_google_drive_creds_for_authorized_user(
token_json_str=access_token_json_str, scopes=scopes
)
# tell caller to update token stored in DB if it has changed
# (e.g. the token has been refreshed)
new_creds_json_str = oauth_creds.to_json() if oauth_creds else ""
if new_creds_json_str != access_token_json_str:
new_creds_dict = {DB_CREDENTIALS_DICT_TOKEN_KEY: new_creds_json_str}
elif DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY in credentials:
service_account_key_json_str = credentials[
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY
]
service_creds = _get_google_drive_creds_for_service_account(
service_account_key_json_str=service_account_key_json_str,
scopes=scopes,
)
# "Impersonate" a user if one is specified
delegated_user_email = cast(
str | None, credentials.get(DB_CREDENTIALS_DICT_DELEGATED_USER_KEY)
)
if delegated_user_email:
service_creds = (
service_creds.with_subject(delegated_user_email)
if service_creds
else None
)
creds: ServiceAccountCredentials | OAuthCredentials | None = (
oauth_creds or service_creds
)
if creds is None:
raise PermissionError(
"Unable to access Google Drive - unknown credential structure."
)
return creds, new_creds_dict
def verify_csrf(credential_id: int, state: str) -> None:
csrf = get_kv_store().load(KV_CRED_KEY.format(str(credential_id)))
if csrf != state:
raise PermissionError(
"State from Google Drive Connector callback does not match expected"
)
def get_auth_url(credential_id: int) -> str:
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
credential_json = json.loads(creds_str)
flow = InstalledAppFlow.from_client_config(
credential_json,
scopes=build_gdrive_scopes(),
redirect_uri=_build_frontend_google_drive_redirect(),
)
auth_url, _ = flow.authorization_url(prompt="consent")
parsed_url = cast(ParseResult, urlparse(auth_url))
params = parse_qs(parsed_url.query)
get_kv_store().store(
KV_CRED_KEY.format(credential_id), params.get("state", [None])[0], encrypt=True
) # type: ignore
return str(auth_url)
def update_credential_access_tokens(
auth_code: str,
credential_id: int,
user: User,
db_session: Session,
) -> OAuthCredentials | None:
app_credentials = get_google_app_cred()
flow = InstalledAppFlow.from_client_config(
app_credentials.model_dump(),
scopes=build_gdrive_scopes(),
redirect_uri=_build_frontend_google_drive_redirect(),
)
flow.fetch_token(code=auth_code)
creds = flow.credentials
token_json_str = creds.to_json()
new_creds_dict = {DB_CREDENTIALS_DICT_TOKEN_KEY: token_json_str}
if not update_credential_json(credential_id, new_creds_dict, user, db_session):
return None
return creds
def build_service_account_creds(
source: DocumentSource,
delegated_user_email: str | None = None,
) -> CredentialBase:
service_account_key = get_service_account_key()
credential_dict = {
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY: service_account_key.json(),
}
if delegated_user_email:
credential_dict[DB_CREDENTIALS_DICT_DELEGATED_USER_KEY] = delegated_user_email
return CredentialBase(
credential_json=credential_dict,
admin_public=True,
source=DocumentSource.GOOGLE_DRIVE,
)
def get_google_app_cred() -> GoogleAppCredentials:
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
return GoogleAppCredentials(**json.loads(creds_str))
def upsert_google_app_cred(app_credentials: GoogleAppCredentials) -> None:
get_kv_store().store(KV_GOOGLE_DRIVE_CRED_KEY, app_credentials.json(), encrypt=True)
def delete_google_app_cred() -> None:
get_kv_store().delete(KV_GOOGLE_DRIVE_CRED_KEY)
def get_service_account_key() -> GoogleServiceAccountKey:
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY))
return GoogleServiceAccountKey(**json.loads(creds_str))
def upsert_service_account_key(service_account_key: GoogleServiceAccountKey) -> None:
get_kv_store().store(
KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True
)
def delete_service_account_key() -> None:
get_kv_store().delete(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY)

View File

@@ -1,4 +1,7 @@
UNSUPPORTED_FILE_TYPE_CONTENT = "" # keep empty for now
DRIVE_FOLDER_TYPE = "application/vnd.google-apps.folder"
DRIVE_SHORTCUT_TYPE = "application/vnd.google-apps.shortcut"
DRIVE_FILE_TYPE = "application/vnd.google-apps.file"
DB_CREDENTIALS_DICT_TOKEN_KEY = "google_drive_tokens"
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "google_drive_service_account_key"
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY = "google_drive_delegated_user"
BASE_SCOPES = ["https://www.googleapis.com/auth/drive.readonly"]
FETCH_PERMISSIONS_SCOPES = ["https://www.googleapis.com/auth/drive.metadata.readonly"]
FETCH_GROUPS_SCOPES = ["https://www.googleapis.com/auth/cloud-identity.groups.readonly"]

View File

@@ -1,197 +0,0 @@
import io
from datetime import datetime
from datetime import timezone
from googleapiclient.errors import HttpError # type: ignore
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import IGNORE_FOR_QA
from danswer.connectors.google_drive.constants import DRIVE_FOLDER_TYPE
from danswer.connectors.google_drive.constants import DRIVE_SHORTCUT_TYPE
from danswer.connectors.google_drive.constants import UNSUPPORTED_FILE_TYPE_CONTENT
from danswer.connectors.google_drive.models import GDriveMimeType
from danswer.connectors.google_drive.models import GoogleDriveFileType
from danswer.connectors.google_drive.section_extraction import get_document_sections
from danswer.connectors.google_utils.resources import GoogleDocsService
from danswer.connectors.google_utils.resources import GoogleDriveService
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.connectors.models import SlimDocument
from danswer.file_processing.extract_file_text import docx_to_text
from danswer.file_processing.extract_file_text import pptx_to_text
from danswer.file_processing.extract_file_text import read_pdf_file
from danswer.file_processing.unstructured import get_unstructured_api_key
from danswer.file_processing.unstructured import unstructured_to_text
from danswer.utils.logger import setup_logger
logger = setup_logger()
# these errors don't represent a failure in the connector, but simply files
# that can't / shouldn't be indexed
ERRORS_TO_CONTINUE_ON = [
"cannotExportFile",
"exportSizeLimitExceeded",
"cannotDownloadFile",
]
def _extract_sections_basic(
file: dict[str, str], service: GoogleDriveService
) -> list[Section]:
mime_type = file["mimeType"]
link = file["webViewLink"]
if mime_type not in set(item.value for item in GDriveMimeType):
# Unsupported file types can still have a title, finding this way is still useful
return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]
try:
if mime_type in [
GDriveMimeType.DOC.value,
GDriveMimeType.PPT.value,
GDriveMimeType.SPREADSHEET.value,
]:
export_mime_type = (
"text/plain"
if mime_type != GDriveMimeType.SPREADSHEET.value
else "text/csv"
)
text = (
service.files()
.export(fileId=file["id"], mimeType=export_mime_type)
.execute()
.decode("utf-8")
)
return [Section(link=link, text=text)]
elif mime_type in [
GDriveMimeType.PLAIN_TEXT.value,
GDriveMimeType.MARKDOWN.value,
]:
return [
Section(
link=link,
text=service.files()
.get_media(fileId=file["id"])
.execute()
.decode("utf-8"),
)
]
if mime_type in [
GDriveMimeType.WORD_DOC.value,
GDriveMimeType.POWERPOINT.value,
GDriveMimeType.PDF.value,
]:
response = service.files().get_media(fileId=file["id"]).execute()
if get_unstructured_api_key():
return [
Section(
link=link,
text=unstructured_to_text(
file=io.BytesIO(response),
file_name=file.get("name", file["id"]),
),
)
]
if mime_type == GDriveMimeType.WORD_DOC.value:
return [
Section(link=link, text=docx_to_text(file=io.BytesIO(response)))
]
elif mime_type == GDriveMimeType.PDF.value:
text, _ = read_pdf_file(file=io.BytesIO(response))
return [Section(link=link, text=text)]
elif mime_type == GDriveMimeType.POWERPOINT.value:
return [
Section(link=link, text=pptx_to_text(file=io.BytesIO(response)))
]
return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]
except Exception:
return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]
def convert_drive_item_to_document(
file: GoogleDriveFileType,
drive_service: GoogleDriveService,
docs_service: GoogleDocsService,
) -> Document | None:
try:
# Skip files that are shortcuts
if file.get("mimeType") == DRIVE_SHORTCUT_TYPE:
logger.info("Ignoring Drive Shortcut Filetype")
return None
# Skip files that are folders
if file.get("mimeType") == DRIVE_FOLDER_TYPE:
logger.info("Ignoring Drive Folder Filetype")
return None
sections: list[Section] = []
# Special handling for Google Docs to preserve structure, link
# to headers
if file.get("mimeType") == GDriveMimeType.DOC.value:
try:
sections = get_document_sections(docs_service, file["id"])
except Exception as e:
logger.warning(
f"Ran into exception '{e}' when pulling sections from Google Doc '{file['name']}'."
" Falling back to basic extraction."
)
# NOTE: this will run for either (1) the above failed or (2) the file is not a Google Doc
if not sections:
try:
# For all other file types just extract the text
sections = _extract_sections_basic(file, drive_service)
except HttpError as e:
reason = e.error_details[0]["reason"] if e.error_details else e.reason
message = e.error_details[0]["message"] if e.error_details else e.reason
if e.status_code == 403 and reason in ERRORS_TO_CONTINUE_ON:
logger.warning(
f"Could not export file '{file['name']}' due to '{message}', skipping..."
)
return None
raise
if not sections:
return None
return Document(
id=file["webViewLink"],
sections=sections,
source=DocumentSource.GOOGLE_DRIVE,
semantic_identifier=file["name"],
doc_updated_at=datetime.fromisoformat(file["modifiedTime"]).astimezone(
timezone.utc
),
metadata={}
if any(section.text for section in sections)
else {IGNORE_FOR_QA: "True"},
additional_info=file.get("id"),
)
except Exception as e:
if not CONTINUE_ON_CONNECTOR_FAILURE:
raise e
logger.exception("Ran into exception when pulling a file from Google Drive")
return None
def build_slim_document(file: GoogleDriveFileType) -> SlimDocument | None:
# Skip files that are folders or shortcuts
if file.get("mimeType") in [DRIVE_FOLDER_TYPE, DRIVE_SHORTCUT_TYPE]:
return None
return SlimDocument(
id=file["webViewLink"],
perm_sync_data={
"doc_id": file.get("id"),
"permissions": file.get("permissions", []),
"permission_ids": file.get("permissionIds", []),
"name": file.get("name"),
"owner_email": file.get("owners", [{}])[0].get("emailAddress"),
},
)

View File

@@ -1,222 +0,0 @@
from collections.abc import Callable
from collections.abc import Iterator
from datetime import datetime
from typing import Any
from googleapiclient.discovery import Resource # type: ignore
from danswer.connectors.google_drive.constants import DRIVE_FOLDER_TYPE
from danswer.connectors.google_drive.constants import DRIVE_SHORTCUT_TYPE
from danswer.connectors.google_drive.models import GoogleDriveFileType
from danswer.connectors.google_utils.google_utils import execute_paginated_retrieval
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.utils.logger import setup_logger
logger = setup_logger()
FILE_FIELDS = (
"nextPageToken, files(mimeType, id, name, permissions, modifiedTime, webViewLink, "
"shortcutDetails, owners(emailAddress))"
)
SLIM_FILE_FIELDS = (
"nextPageToken, files(mimeType, id, name, permissions(emailAddress, type), "
"permissionIds, webViewLink, owners(emailAddress))"
)
FOLDER_FIELDS = "nextPageToken, files(id, name, permissions, modifiedTime, webViewLink, shortcutDetails)"
def _generate_time_range_filter(
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> str:
time_range_filter = ""
if start is not None:
time_start = datetime.utcfromtimestamp(start).isoformat() + "Z"
time_range_filter += f" and modifiedTime >= '{time_start}'"
if end is not None:
time_stop = datetime.utcfromtimestamp(end).isoformat() + "Z"
time_range_filter += f" and modifiedTime <= '{time_stop}'"
return time_range_filter
def _get_folders_in_parent(
service: Resource,
parent_id: str | None = None,
) -> Iterator[GoogleDriveFileType]:
# Follow shortcuts to folders
query = f"(mimeType = '{DRIVE_FOLDER_TYPE}' or mimeType = '{DRIVE_SHORTCUT_TYPE}')"
query += " and trashed = false"
if parent_id:
query += f" and '{parent_id}' in parents"
for file in execute_paginated_retrieval(
retrieval_function=service.files().list,
list_key="files",
continue_on_404_or_403=True,
corpora="allDrives",
supportsAllDrives=True,
includeItemsFromAllDrives=True,
fields=FOLDER_FIELDS,
q=query,
):
yield file
def _get_files_in_parent(
service: Resource,
parent_id: str,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
is_slim: bool = False,
) -> Iterator[GoogleDriveFileType]:
query = f"mimeType != '{DRIVE_FOLDER_TYPE}' and '{parent_id}' in parents"
query += " and trashed = false"
query += _generate_time_range_filter(start, end)
for file in execute_paginated_retrieval(
retrieval_function=service.files().list,
list_key="files",
continue_on_404_or_403=True,
corpora="allDrives",
supportsAllDrives=True,
includeItemsFromAllDrives=True,
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
q=query,
):
yield file
def crawl_folders_for_files(
service: Resource,
parent_id: str,
traversed_parent_ids: set[str],
update_traversed_ids_func: Callable[[str], None],
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
"""
This function starts crawling from any folder. It is slower though.
"""
if parent_id in traversed_parent_ids:
logger.info(f"Skipping subfolder since already traversed: {parent_id}")
return
found_files = False
for file in _get_files_in_parent(
service=service,
start=start,
end=end,
parent_id=parent_id,
):
found_files = True
yield file
if found_files:
update_traversed_ids_func(parent_id)
for subfolder in _get_folders_in_parent(
service=service,
parent_id=parent_id,
):
logger.info("Fetching all files in subfolder: " + subfolder["name"])
yield from crawl_folders_for_files(
service=service,
parent_id=subfolder["id"],
traversed_parent_ids=traversed_parent_ids,
update_traversed_ids_func=update_traversed_ids_func,
start=start,
end=end,
)
def get_files_in_shared_drive(
service: Resource,
drive_id: str,
is_slim: bool = False,
update_traversed_ids_func: Callable[[str], None] = lambda _: None,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
# If we know we are going to folder crawl later, we can cache the folders here
# Get all folders being queried and add them to the traversed set
query = f"mimeType = '{DRIVE_FOLDER_TYPE}'"
query += " and trashed = false"
found_folders = False
for file in execute_paginated_retrieval(
retrieval_function=service.files().list,
list_key="files",
continue_on_404_or_403=True,
corpora="drive",
driveId=drive_id,
supportsAllDrives=True,
includeItemsFromAllDrives=True,
fields="nextPageToken, files(id)",
q=query,
):
update_traversed_ids_func(file["id"])
found_folders = True
if found_folders:
update_traversed_ids_func(drive_id)
# Get all files in the shared drive
query = f"mimeType != '{DRIVE_FOLDER_TYPE}'"
query += " and trashed = false"
query += _generate_time_range_filter(start, end)
yield from execute_paginated_retrieval(
retrieval_function=service.files().list,
list_key="files",
continue_on_404_or_403=True,
corpora="drive",
driveId=drive_id,
supportsAllDrives=True,
includeItemsFromAllDrives=True,
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
q=query,
)
def get_all_files_in_my_drive(
service: Any,
update_traversed_ids_func: Callable,
is_slim: bool = False,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
# If we know we are going to folder crawl later, we can cache the folders here
# Get all folders being queried and add them to the traversed set
query = "trashed = false and 'me' in owners"
found_folders = False
for file in execute_paginated_retrieval(
retrieval_function=service.files().list,
list_key="files",
corpora="user",
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
q=query,
):
update_traversed_ids_func(file["id"])
found_folders = True
if found_folders:
update_traversed_ids_func(get_root_folder_id(service))
# Then get the files
query = "trashed = false and 'me' in owners"
query += _generate_time_range_filter(start, end)
fields = "files(id, name, mimeType, webViewLink, modifiedTime, createdTime)"
if not is_slim:
fields += ", files(permissions, permissionIds, owners)"
yield from execute_paginated_retrieval(
retrieval_function=service.files().list,
list_key="files",
corpora="user",
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
q=query,
)
# Just in case we need to get the root folder id
def get_root_folder_id(service: Resource) -> str:
# we dont paginate here because there is only one root folder per user
# https://developers.google.com/drive/api/guides/v2-to-v3-reference
return service.files().get(fileId="root", fields="id").execute()["id"]

View File

@@ -1,18 +0,0 @@
from enum import Enum
from typing import Any
class GDriveMimeType(str, Enum):
DOC = "application/vnd.google-apps.document"
SPREADSHEET = "application/vnd.google-apps.spreadsheet"
PDF = "application/pdf"
WORD_DOC = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
PPT = "application/vnd.google-apps.presentation"
POWERPOINT = (
"application/vnd.openxmlformats-officedocument.presentationml.presentation"
)
PLAIN_TEXT = "text/plain"
MARKDOWN = "text/markdown"
GoogleDriveFileType = dict[str, Any]

View File

@@ -1,105 +0,0 @@
from typing import Any
from pydantic import BaseModel
from danswer.connectors.google_utils.resources import GoogleDocsService
from danswer.connectors.models import Section
class CurrentHeading(BaseModel):
id: str
text: str
def _build_gdoc_section_link(doc_id: str, heading_id: str) -> str:
"""Builds a Google Doc link that jumps to a specific heading"""
# NOTE: doesn't support docs with multiple tabs atm, if we need that ask
# @Chris
return (
f"https://docs.google.com/document/d/{doc_id}/edit?tab=t.0#heading={heading_id}"
)
def _extract_id_from_heading(paragraph: dict[str, Any]) -> str:
"""Extracts the id from a heading paragraph element"""
return paragraph["paragraphStyle"]["headingId"]
def _extract_text_from_paragraph(paragraph: dict[str, Any]) -> str:
"""Extracts the text content from a paragraph element"""
text_elements = []
for element in paragraph.get("elements", []):
if "textRun" in element:
text_elements.append(element["textRun"].get("content", ""))
return "".join(text_elements)
def get_document_sections(
docs_service: GoogleDocsService,
doc_id: str,
) -> list[Section]:
"""Extracts sections from a Google Doc, including their headings and content"""
# Fetch the document structure
doc = docs_service.documents().get(documentId=doc_id).execute()
# Get the content
content = doc.get("body", {}).get("content", [])
sections: list[Section] = []
current_section: list[str] = []
current_heading: CurrentHeading | None = None
for element in content:
if "paragraph" not in element:
continue
paragraph = element["paragraph"]
# Check if this is a heading
if (
"paragraphStyle" in paragraph
and "namedStyleType" in paragraph["paragraphStyle"]
):
style = paragraph["paragraphStyle"]["namedStyleType"]
is_heading = style.startswith("HEADING_")
is_title = style.startswith("TITLE")
if is_heading or is_title:
# If we were building a previous section, add it to sections list
if current_heading is not None and current_section:
heading_text = current_heading.text
section_text = f"{heading_text}\n" + "\n".join(current_section)
sections.append(
Section(
text=section_text.strip(),
link=_build_gdoc_section_link(doc_id, current_heading.id),
)
)
current_section = []
# Start new heading
heading_id = _extract_id_from_heading(paragraph)
heading_text = _extract_text_from_paragraph(paragraph)
current_heading = CurrentHeading(
id=heading_id,
text=heading_text,
)
continue
# Add content to current section
if current_heading is not None:
text = _extract_text_from_paragraph(paragraph)
if text.strip():
current_section.append(text)
# Don't forget to add the last section
if current_heading is not None and current_section:
section_text = f"{current_heading.text}\n" + "\n".join(current_section)
sections.append(
Section(
text=section_text.strip(),
link=_build_gdoc_section_link(doc_id, current_heading.id),
)
)
return sections

View File

@@ -1,107 +0,0 @@
import json
from typing import cast
from google.auth.transport.requests import Request # type: ignore
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
from danswer.configs.constants import DocumentSource
from danswer.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
)
from danswer.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_DICT_TOKEN_KEY,
)
from danswer.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
)
from danswer.connectors.google_utils.shared_constants import (
GOOGLE_SCOPES,
)
from danswer.utils.logger import setup_logger
logger = setup_logger()
def get_google_oauth_creds(
token_json_str: str, source: DocumentSource
) -> OAuthCredentials | None:
creds_json = json.loads(token_json_str)
creds = OAuthCredentials.from_authorized_user_info(
info=creds_json,
scopes=GOOGLE_SCOPES[source],
)
if creds.valid:
return creds
if creds.expired and creds.refresh_token:
try:
creds.refresh(Request())
if creds.valid:
logger.notice("Refreshed Google Drive tokens.")
return creds
except Exception:
logger.exception("Failed to refresh google drive access token due to:")
return None
return None
def get_google_creds(
credentials: dict[str, str],
source: DocumentSource,
) -> tuple[ServiceAccountCredentials | OAuthCredentials, dict[str, str] | None]:
"""Checks for two different types of credentials.
(1) A credential which holds a token acquired via a user going thorough
the Google OAuth flow.
(2) A credential which holds a service account key JSON file, which
can then be used to impersonate any user in the workspace.
"""
oauth_creds = None
service_creds = None
new_creds_dict = None
if DB_CREDENTIALS_DICT_TOKEN_KEY in credentials:
# OAUTH
access_token_json_str = cast(str, credentials[DB_CREDENTIALS_DICT_TOKEN_KEY])
oauth_creds = get_google_oauth_creds(
token_json_str=access_token_json_str, source=source
)
# tell caller to update token stored in DB if it has changed
# (e.g. the token has been refreshed)
new_creds_json_str = oauth_creds.to_json() if oauth_creds else ""
if new_creds_json_str != access_token_json_str:
new_creds_dict = {
DB_CREDENTIALS_DICT_TOKEN_KEY: new_creds_json_str,
DB_CREDENTIALS_PRIMARY_ADMIN_KEY: credentials[
DB_CREDENTIALS_PRIMARY_ADMIN_KEY
],
}
elif DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY in credentials:
# SERVICE ACCOUNT
service_account_key_json_str = credentials[
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY
]
service_account_key = json.loads(service_account_key_json_str)
service_creds = ServiceAccountCredentials.from_service_account_info(
service_account_key, scopes=GOOGLE_SCOPES[source]
)
if not service_creds.valid or not service_creds.expired:
service_creds.refresh(Request())
if not service_creds.valid:
raise PermissionError(
f"Unable to access {source} - service account credentials are invalid."
)
creds: ServiceAccountCredentials | OAuthCredentials | None = (
oauth_creds or service_creds
)
if creds is None:
raise PermissionError(
f"Unable to access {source} - unknown credential structure."
)
return creds, new_creds_dict

View File

@@ -1,237 +0,0 @@
import json
from typing import cast
from urllib.parse import parse_qs
from urllib.parse import ParseResult
from urllib.parse import urlparse
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore
from sqlalchemy.orm import Session
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import KV_CRED_KEY
from danswer.configs.constants import KV_GMAIL_CRED_KEY
from danswer.configs.constants import KV_GMAIL_SERVICE_ACCOUNT_KEY
from danswer.configs.constants import KV_GOOGLE_DRIVE_CRED_KEY
from danswer.configs.constants import KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY
from danswer.connectors.google_utils.resources import get_drive_service
from danswer.connectors.google_utils.resources import get_gmail_service
from danswer.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
)
from danswer.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_DICT_TOKEN_KEY,
)
from danswer.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
)
from danswer.connectors.google_utils.shared_constants import (
GOOGLE_SCOPES,
)
from danswer.connectors.google_utils.shared_constants import (
MISSING_SCOPES_ERROR_STR,
)
from danswer.connectors.google_utils.shared_constants import (
ONYX_SCOPE_INSTRUCTIONS,
)
from danswer.db.credentials import update_credential_json
from danswer.db.models import User
from danswer.key_value_store.factory import get_kv_store
from danswer.server.documents.models import CredentialBase
from danswer.server.documents.models import GoogleAppCredentials
from danswer.server.documents.models import GoogleServiceAccountKey
from danswer.utils.logger import setup_logger
logger = setup_logger()
def _build_frontend_google_drive_redirect(source: DocumentSource) -> str:
if source == DocumentSource.GOOGLE_DRIVE:
return f"{WEB_DOMAIN}/admin/connectors/google-drive/auth/callback"
elif source == DocumentSource.GMAIL:
return f"{WEB_DOMAIN}/admin/connectors/gmail/auth/callback"
else:
raise ValueError(f"Unsupported source: {source}")
def _get_current_oauth_user(creds: OAuthCredentials, source: DocumentSource) -> str:
if source == DocumentSource.GOOGLE_DRIVE:
drive_service = get_drive_service(creds)
user_info = (
drive_service.about()
.get(
fields="user(emailAddress)",
)
.execute()
)
email = user_info.get("user", {}).get("emailAddress")
elif source == DocumentSource.GMAIL:
gmail_service = get_gmail_service(creds)
user_info = (
gmail_service.users()
.getProfile(
userId="me",
fields="emailAddress",
)
.execute()
)
email = user_info.get("emailAddress")
else:
raise ValueError(f"Unsupported source: {source}")
return email
def verify_csrf(credential_id: int, state: str) -> None:
csrf = get_kv_store().load(KV_CRED_KEY.format(str(credential_id)))
if csrf != state:
raise PermissionError(
"State from Google Drive Connector callback does not match expected"
)
def update_credential_access_tokens(
auth_code: str,
credential_id: int,
user: User,
db_session: Session,
source: DocumentSource,
) -> OAuthCredentials | None:
app_credentials = get_google_app_cred(source)
flow = InstalledAppFlow.from_client_config(
app_credentials.model_dump(),
scopes=GOOGLE_SCOPES[source],
redirect_uri=_build_frontend_google_drive_redirect(source),
)
flow.fetch_token(code=auth_code)
creds = flow.credentials
token_json_str = creds.to_json()
# Get user email from Google API so we know who
# the primary admin is for this connector
try:
email = _get_current_oauth_user(creds, source)
except Exception as e:
if MISSING_SCOPES_ERROR_STR in str(e):
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
raise e
new_creds_dict = {
DB_CREDENTIALS_DICT_TOKEN_KEY: token_json_str,
DB_CREDENTIALS_PRIMARY_ADMIN_KEY: email,
}
if not update_credential_json(credential_id, new_creds_dict, user, db_session):
return None
return creds
def build_service_account_creds(
source: DocumentSource,
primary_admin_email: str | None = None,
) -> CredentialBase:
service_account_key = get_service_account_key(source=source)
credential_dict = {
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY: service_account_key.json(),
}
if primary_admin_email:
credential_dict[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] = primary_admin_email
return CredentialBase(
credential_json=credential_dict,
admin_public=True,
source=source,
)
def get_auth_url(credential_id: int, source: DocumentSource) -> str:
if source == DocumentSource.GOOGLE_DRIVE:
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
elif source == DocumentSource.GMAIL:
creds_str = str(get_kv_store().load(KV_GMAIL_CRED_KEY))
else:
raise ValueError(f"Unsupported source: {source}")
credential_json = json.loads(creds_str)
flow = InstalledAppFlow.from_client_config(
credential_json,
scopes=GOOGLE_SCOPES[source],
redirect_uri=_build_frontend_google_drive_redirect(source),
)
auth_url, _ = flow.authorization_url(prompt="consent")
parsed_url = cast(ParseResult, urlparse(auth_url))
params = parse_qs(parsed_url.query)
get_kv_store().store(
KV_CRED_KEY.format(credential_id), params.get("state", [None])[0], encrypt=True
) # type: ignore
return str(auth_url)
def get_google_app_cred(source: DocumentSource) -> GoogleAppCredentials:
if source == DocumentSource.GOOGLE_DRIVE:
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
elif source == DocumentSource.GMAIL:
creds_str = str(get_kv_store().load(KV_GMAIL_CRED_KEY))
else:
raise ValueError(f"Unsupported source: {source}")
return GoogleAppCredentials(**json.loads(creds_str))
def upsert_google_app_cred(
app_credentials: GoogleAppCredentials, source: DocumentSource
) -> None:
if source == DocumentSource.GOOGLE_DRIVE:
get_kv_store().store(
KV_GOOGLE_DRIVE_CRED_KEY, app_credentials.json(), encrypt=True
)
elif source == DocumentSource.GMAIL:
get_kv_store().store(KV_GMAIL_CRED_KEY, app_credentials.json(), encrypt=True)
else:
raise ValueError(f"Unsupported source: {source}")
def delete_google_app_cred(source: DocumentSource) -> None:
if source == DocumentSource.GOOGLE_DRIVE:
get_kv_store().delete(KV_GOOGLE_DRIVE_CRED_KEY)
elif source == DocumentSource.GMAIL:
get_kv_store().delete(KV_GMAIL_CRED_KEY)
else:
raise ValueError(f"Unsupported source: {source}")
def get_service_account_key(source: DocumentSource) -> GoogleServiceAccountKey:
if source == DocumentSource.GOOGLE_DRIVE:
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY))
elif source == DocumentSource.GMAIL:
creds_str = str(get_kv_store().load(KV_GMAIL_SERVICE_ACCOUNT_KEY))
else:
raise ValueError(f"Unsupported source: {source}")
return GoogleServiceAccountKey(**json.loads(creds_str))
def upsert_service_account_key(
service_account_key: GoogleServiceAccountKey, source: DocumentSource
) -> None:
if source == DocumentSource.GOOGLE_DRIVE:
get_kv_store().store(
KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY,
service_account_key.json(),
encrypt=True,
)
elif source == DocumentSource.GMAIL:
get_kv_store().store(
KV_GMAIL_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True
)
else:
raise ValueError(f"Unsupported source: {source}")
def delete_service_account_key(source: DocumentSource) -> None:
if source == DocumentSource.GOOGLE_DRIVE:
get_kv_store().delete(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY)
elif source == DocumentSource.GMAIL:
get_kv_store().delete(KV_GMAIL_SERVICE_ACCOUNT_KEY)
else:
raise ValueError(f"Unsupported source: {source}")

View File

@@ -1,125 +0,0 @@
import re
import time
from collections.abc import Callable
from collections.abc import Iterator
from datetime import datetime
from datetime import timezone
from typing import Any
from googleapiclient.errors import HttpError # type: ignore
from danswer.connectors.google_drive.models import GoogleDriveFileType
from danswer.utils.logger import setup_logger
from danswer.utils.retry_wrapper import retry_builder
logger = setup_logger()
# Google Drive APIs are quite flakey and may 500 for an
# extended period of time. Trying to combat here by adding a very
# long retry period (~20 minutes of trying every minute)
add_retries = retry_builder(tries=50, max_delay=30)
def _execute_with_retry(request: Any) -> Any:
max_attempts = 10
attempt = 1
while attempt < max_attempts:
# Note for reasons unknown, the Google API will sometimes return a 429
# and even after waiting the retry period, it will return another 429.
# It could be due to a few possibilities:
# 1. Other things are also requesting from the Gmail API with the same key
# 2. It's a rolling rate limit so the moment we get some amount of requests cleared, we hit it again very quickly
# 3. The retry-after has a maximum and we've already hit the limit for the day
# or it's something else...
try:
return request.execute()
except HttpError as error:
attempt += 1
if error.resp.status == 429:
# Attempt to get 'Retry-After' from headers
retry_after = error.resp.get("Retry-After")
if retry_after:
sleep_time = int(retry_after)
else:
# Extract 'Retry after' timestamp from error message
match = re.search(
r"Retry after (\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d+Z)",
str(error),
)
if match:
retry_after_timestamp = match.group(1)
retry_after_dt = datetime.strptime(
retry_after_timestamp, "%Y-%m-%dT%H:%M:%S.%fZ"
).replace(tzinfo=timezone.utc)
current_time = datetime.now(timezone.utc)
sleep_time = max(
int((retry_after_dt - current_time).total_seconds()),
0,
)
else:
logger.error(
f"No Retry-After header or timestamp found in error message: {error}"
)
sleep_time = 60
sleep_time += 3 # Add a buffer to be safe
logger.info(
f"Rate limit exceeded. Attempt {attempt}/{max_attempts}. Sleeping for {sleep_time} seconds."
)
time.sleep(sleep_time)
else:
raise
# If we've exhausted all attempts
raise Exception(f"Failed to execute request after {max_attempts} attempts")
def execute_paginated_retrieval(
retrieval_function: Callable,
list_key: str | None = None,
continue_on_404_or_403: bool = False,
**kwargs: Any,
) -> Iterator[GoogleDriveFileType]:
"""Execute a paginated retrieval from Google Drive API
Args:
retrieval_function: The specific list function to call (e.g., service.files().list)
**kwargs: Arguments to pass to the list function
"""
next_page_token = ""
while next_page_token is not None:
request_kwargs = kwargs.copy()
if next_page_token:
request_kwargs["pageToken"] = next_page_token
try:
results = retrieval_function(**request_kwargs).execute()
except HttpError as e:
if e.resp.status >= 500:
results = add_retries(
lambda: retrieval_function(**request_kwargs).execute()
)()
elif e.resp.status == 404 or e.resp.status == 403:
if continue_on_404_or_403:
logger.warning(f"Error executing request: {e}")
results = {}
else:
raise e
elif e.resp.status == 429:
results = _execute_with_retry(
lambda: retrieval_function(**request_kwargs).execute()
)
else:
logger.exception("Error executing request:")
raise e
next_page_token = results.get("nextPageToken")
if list_key:
for item in results.get(list_key, []):
yield item
else:
yield results

View File

@@ -1,63 +0,0 @@
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
from googleapiclient.discovery import build # type: ignore
from googleapiclient.discovery import Resource # type: ignore
class GoogleDriveService(Resource):
pass
class GoogleDocsService(Resource):
pass
class AdminService(Resource):
pass
class GmailService(Resource):
pass
def _get_google_service(
service_name: str,
service_version: str,
creds: ServiceAccountCredentials | OAuthCredentials,
user_email: str | None = None,
) -> GoogleDriveService | GoogleDocsService | AdminService | GmailService:
if isinstance(creds, ServiceAccountCredentials):
creds = creds.with_subject(user_email)
service = build(service_name, service_version, credentials=creds)
elif isinstance(creds, OAuthCredentials):
service = build(service_name, service_version, credentials=creds)
return service
def get_google_docs_service(
creds: ServiceAccountCredentials | OAuthCredentials,
user_email: str | None = None,
) -> GoogleDocsService:
return _get_google_service("docs", "v1", creds, user_email)
def get_drive_service(
creds: ServiceAccountCredentials | OAuthCredentials,
user_email: str | None = None,
) -> GoogleDriveService:
return _get_google_service("drive", "v3", creds, user_email)
def get_admin_service(
creds: ServiceAccountCredentials | OAuthCredentials,
user_email: str | None = None,
) -> AdminService:
return _get_google_service("admin", "directory_v1", creds, user_email)
def get_gmail_service(
creds: ServiceAccountCredentials | OAuthCredentials,
user_email: str | None = None,
) -> GmailService:
return _get_google_service("gmail", "v1", creds, user_email)

View File

@@ -1,40 +0,0 @@
from danswer.configs.constants import DocumentSource
# NOTE: do not need https://www.googleapis.com/auth/documents.readonly
# this is counted under `/auth/drive.readonly`
GOOGLE_SCOPES = {
DocumentSource.GOOGLE_DRIVE: [
"https://www.googleapis.com/auth/drive.readonly",
"https://www.googleapis.com/auth/drive.metadata.readonly",
"https://www.googleapis.com/auth/admin.directory.group.readonly",
"https://www.googleapis.com/auth/admin.directory.user.readonly",
],
DocumentSource.GMAIL: [
"https://www.googleapis.com/auth/gmail.readonly",
"https://www.googleapis.com/auth/admin.directory.user.readonly",
"https://www.googleapis.com/auth/admin.directory.group.readonly",
],
}
# This is the Oauth token
DB_CREDENTIALS_DICT_TOKEN_KEY = "google_tokens"
# This is the service account key
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "google_service_account_key"
# The email saved for both auth types
DB_CREDENTIALS_PRIMARY_ADMIN_KEY = "google_primary_admin"
USER_FIELDS = "nextPageToken, users(primaryEmail)"
# Error message substrings
MISSING_SCOPES_ERROR_STR = "client not authorized for any of the scopes requested"
# Documentation and error messages
SCOPE_DOC_URL = "https://docs.danswer.dev/connectors/google_drive/overview"
ONYX_SCOPE_INSTRUCTIONS = (
"You have upgraded Danswer without updating the Google Auth scopes. "
f"Please refer to the documentation to learn how to update the scopes: {SCOPE_DOC_URL}"
)
# This is the maximum number of threads that can be retrieved at once
SLIM_BATCH_SIZE = 500

View File

@@ -56,11 +56,7 @@ class PollConnector(BaseConnector):
class SlimConnector(BaseConnector):
@abc.abstractmethod
def retrieve_all_slim_documents(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
def retrieve_all_slim_documents(self) -> GenerateSlimDocumentOutput:
raise NotImplementedError

View File

@@ -3,7 +3,6 @@ from __future__ import annotations
import builtins
import functools
import itertools
import tempfile
from typing import Any
from unittest import mock
from urllib.parse import urlparse
@@ -19,8 +18,6 @@ from danswer.utils.logger import setup_logger
logger = setup_logger()
pywikibot.config.base_dir = tempfile.TemporaryDirectory().name
@mock.patch.object(
builtins, "print", lambda *args: logger.info("\t".join(map(str, args)))

View File

@@ -2,7 +2,6 @@ from __future__ import annotations
import datetime
import itertools
import tempfile
from collections.abc import Generator
from collections.abc import Iterator
from typing import Any
@@ -26,8 +25,6 @@ from danswer.utils.logger import setup_logger
logger = setup_logger()
pywikibot.config.base_dir = tempfile.TemporaryDirectory().name
def pywikibot_timestamp_to_utc_datetime(
timestamp: pywikibot.time.Timestamp,
@@ -124,6 +121,7 @@ class MediaWikiConnector(LoadConnector, PollConnector):
self.batch_size = batch_size
# short names can only have ascii letters and digits
self.family = family_class_dispatch(hostname, "WikipediaConnector")()
self.site = pywikibot.Site(fam=self.family, code=language_code)
self.categories = [

View File

@@ -251,11 +251,7 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
end_datetime = datetime.utcfromtimestamp(end)
return self._fetch_from_salesforce(start=start_datetime, end=end_datetime)
def retrieve_all_slim_documents(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
def retrieve_all_slim_documents(self) -> GenerateSlimDocumentOutput:
if self.sf_client is None:
raise ConnectorMissingCredentialError("Salesforce")
doc_metadata_list: list[SlimDocument] = []

View File

@@ -391,11 +391,7 @@ class SlackPollConnector(PollConnector, SlimConnector):
self.client = WebClient(token=bot_token)
return None
def retrieve_all_slim_documents(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
def retrieve_all_slim_documents(self) -> GenerateSlimDocumentOutput:
if self.client is None:
raise ConnectorMissingCredentialError("Slack")

View File

@@ -1,7 +1,10 @@
from collections.abc import Iterator
from typing import Any
import requests
from retry import retry
from zenpy import Zenpy # type: ignore
from zenpy.lib.api_objects import Ticket # type: ignore
from zenpy.lib.api_objects.help_centre_objects import Article # type: ignore
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.app_configs import ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS
@@ -17,244 +20,43 @@ from danswer.connectors.models import BasicExpertInfo
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.file_processing.html_utils import parse_html_page_basic
from danswer.utils.retry_wrapper import retry_builder
MAX_PAGE_SIZE = 30 # Zendesk API maximum
class ZendeskCredentialsNotSetUpError(PermissionError):
def __init__(self) -> None:
super().__init__(
"Zendesk Credentials are not set up, was load_credentials called?"
)
class ZendeskClient:
def __init__(self, subdomain: str, email: str, token: str):
self.base_url = f"https://{subdomain}.zendesk.com/api/v2"
self.auth = (f"{email}/token", token)
@retry_builder()
def make_request(self, endpoint: str, params: dict[str, Any]) -> dict[str, Any]:
response = requests.get(
f"{self.base_url}/{endpoint}", auth=self.auth, params=params
)
response.raise_for_status()
return response.json()
def _get_content_tag_mapping(client: ZendeskClient) -> dict[str, str]:
content_tags: dict[str, str] = {}
params = {"page[size]": MAX_PAGE_SIZE}
try:
while True:
data = client.make_request("guide/content_tags", params)
for tag in data.get("records", []):
content_tags[tag["id"]] = tag["name"]
# Check if there are more pages
if data.get("meta", {}).get("has_more", False):
params["page[after]"] = data["meta"]["after_cursor"]
else:
break
return content_tags
except Exception as e:
raise Exception(f"Error fetching content tags: {str(e)}")
def _get_articles(
client: ZendeskClient, start_time: int | None = None, page_size: int = MAX_PAGE_SIZE
) -> Iterator[dict[str, Any]]:
params = (
{"start_time": start_time, "page[size]": page_size}
if start_time
else {"page[size]": page_size}
def _article_to_document(article: Article, content_tags: dict[str, str]) -> Document:
author = BasicExpertInfo(
display_name=article.author.name, email=article.author.email
)
update_time = time_str_to_utc(article.updated_at)
while True:
data = client.make_request("help_center/articles", params)
for article in data["articles"]:
yield article
if not data.get("meta", {}).get("has_more"):
break
params["page[after]"] = data["meta"]["after_cursor"]
def _get_tickets(
client: ZendeskClient, start_time: int | None = None
) -> Iterator[dict[str, Any]]:
params = {"start_time": start_time} if start_time else {"start_time": 0}
while True:
data = client.make_request("incremental/tickets.json", params)
for ticket in data["tickets"]:
yield ticket
if not data.get("end_of_stream", False):
params["start_time"] = data["end_time"]
else:
break
def _fetch_author(client: ZendeskClient, author_id: str) -> BasicExpertInfo | None:
author_data = client.make_request(f"users/{author_id}", {})
user = author_data.get("user")
return (
BasicExpertInfo(display_name=user.get("name"), email=user.get("email"))
if user and user.get("name") and user.get("email")
else None
)
def _article_to_document(
article: dict[str, Any],
content_tags: dict[str, str],
author_map: dict[str, BasicExpertInfo],
client: ZendeskClient,
) -> tuple[dict[str, BasicExpertInfo] | None, Document]:
author_id = article.get("author_id")
if not author_id:
author = None
else:
author = (
author_map.get(author_id)
if author_id in author_map
else _fetch_author(client, author_id)
)
new_author_mapping = {author_id: author} if author_id and author else None
updated_at = article.get("updated_at")
update_time = time_str_to_utc(updated_at) if updated_at else None
# Build metadata
# build metadata
metadata: dict[str, str | list[str]] = {
"labels": [str(label) for label in article.get("label_names", []) if label],
"labels": [str(label) for label in article.label_names if label],
"content_tags": [
content_tags[tag_id]
for tag_id in article.get("content_tag_ids", [])
for tag_id in article.content_tag_ids
if tag_id in content_tags
],
}
# Remove empty values
# remove empty values
metadata = {k: v for k, v in metadata.items() if v}
return new_author_mapping, Document(
id=f"article:{article['id']}",
return Document(
id=f"article:{article.id}",
sections=[
Section(
link=article.get("html_url"),
text=parse_html_page_basic(article["body"]),
)
Section(link=article.html_url, text=parse_html_page_basic(article.body))
],
source=DocumentSource.ZENDESK,
semantic_identifier=article["title"],
semantic_identifier=article.title,
doc_updated_at=update_time,
primary_owners=[author] if author else None,
primary_owners=[author],
metadata=metadata,
)
def _get_comment_text(
comment: dict[str, Any],
author_map: dict[str, BasicExpertInfo],
client: ZendeskClient,
) -> tuple[dict[str, BasicExpertInfo] | None, str]:
author_id = comment.get("author_id")
if not author_id:
author = None
else:
author = (
author_map.get(author_id)
if author_id in author_map
else _fetch_author(client, author_id)
)
new_author_mapping = {author_id: author} if author_id and author else None
comment_text = f"Comment{' by ' + author.display_name if author and author.display_name else ''}"
comment_text += f"{' at ' + comment['created_at'] if comment.get('created_at') else ''}:\n{comment['body']}"
return new_author_mapping, comment_text
def _ticket_to_document(
ticket: dict[str, Any],
author_map: dict[str, BasicExpertInfo],
client: ZendeskClient,
default_subdomain: str,
) -> tuple[dict[str, BasicExpertInfo] | None, Document]:
submitter_id = ticket.get("submitter")
if not submitter_id:
submitter = None
else:
submitter = (
author_map.get(submitter_id)
if submitter_id in author_map
else _fetch_author(client, submitter_id)
)
new_author_mapping = (
{submitter_id: submitter} if submitter_id and submitter else None
)
updated_at = ticket.get("updated_at")
update_time = time_str_to_utc(updated_at) if updated_at else None
metadata: dict[str, str | list[str]] = {}
if status := ticket.get("status"):
metadata["status"] = status
if priority := ticket.get("priority"):
metadata["priority"] = priority
if tags := ticket.get("tags"):
metadata["tags"] = tags
if ticket_type := ticket.get("type"):
metadata["ticket_type"] = ticket_type
# Fetch comments for the ticket
comments_data = client.make_request(f"tickets/{ticket.get('id')}/comments", {})
comments = comments_data.get("comments", [])
comment_texts = []
for comment in comments:
new_author_mapping, comment_text = _get_comment_text(
comment, author_map, client
)
if new_author_mapping:
author_map.update(new_author_mapping)
comment_texts.append(comment_text)
comments_text = "\n\n".join(comment_texts)
subject = ticket.get("subject")
full_text = f"Ticket Subject:\n{subject}\n\nComments:\n{comments_text}"
ticket_url = ticket.get("url")
subdomain = (
ticket_url.split("//")[1].split(".zendesk.com")[0]
if ticket_url
else default_subdomain
)
ticket_display_url = (
f"https://{subdomain}.zendesk.com/agent/tickets/{ticket.get('id')}"
)
return new_author_mapping, Document(
id=f"zendesk_ticket_{ticket['id']}",
sections=[Section(link=ticket_display_url, text=full_text)],
source=DocumentSource.ZENDESK,
semantic_identifier=f"Ticket #{ticket['id']}: {subject or 'No Subject'}",
doc_updated_at=update_time,
primary_owners=[submitter] if submitter else None,
metadata=metadata,
)
class ZendeskClientNotSetUpError(PermissionError):
def __init__(self) -> None:
super().__init__("Zendesk Client is not set up, was load_credentials called?")
class ZendeskConnector(LoadConnector, PollConnector):
@@ -264,10 +66,44 @@ class ZendeskConnector(LoadConnector, PollConnector):
content_type: str = "articles",
) -> None:
self.batch_size = batch_size
self.content_type = content_type
self.subdomain = ""
# Fetch all tags ahead of time
self.zendesk_client: Zenpy | None = None
self.content_tags: dict[str, str] = {}
self.content_type = content_type
@retry(tries=3, delay=2, backoff=2)
def _set_content_tags(
self, subdomain: str, email: str, token: str, page_size: int = 30
) -> None:
# Construct the base URL
base_url = f"https://{subdomain}.zendesk.com/api/v2/guide/content_tags"
# Set up authentication
auth = (f"{email}/token", token)
# Set up pagination parameters
params = {"page[size]": page_size}
try:
while True:
# Make the GET request
response = requests.get(base_url, auth=auth, params=params)
# Check if the request was successful
if response.status_code == 200:
data = response.json()
content_tag_list = data.get("records", [])
for tag in content_tag_list:
self.content_tags[tag["id"]] = tag["name"]
# Check if there are more pages
if data.get("meta", {}).get("has_more", False):
params["page[after]"] = data["meta"]["after_cursor"]
else:
break
else:
raise Exception(f"Error: {response.status_code}\n{response.text}")
except Exception as e:
raise Exception(f"Error fetching content tags: {str(e)}")
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
# Subdomain is actually the whole URL
@@ -276,23 +112,87 @@ class ZendeskConnector(LoadConnector, PollConnector):
.replace("https://", "")
.split(".zendesk.com")[0]
)
self.subdomain = subdomain
self.client = ZendeskClient(
subdomain, credentials["zendesk_email"], credentials["zendesk_token"]
self.zendesk_client = Zenpy(
subdomain=subdomain,
email=credentials["zendesk_email"],
token=credentials["zendesk_token"],
)
self._set_content_tags(
subdomain,
credentials["zendesk_email"],
credentials["zendesk_token"],
)
return None
def load_from_state(self) -> GenerateDocumentsOutput:
return self.poll_source(None, None)
def _ticket_to_document(self, ticket: Ticket) -> Document:
if self.zendesk_client is None:
raise ZendeskClientNotSetUpError()
owner = None
if ticket.requester and ticket.requester.name and ticket.requester.email:
owner = [
BasicExpertInfo(
display_name=ticket.requester.name, email=ticket.requester.email
)
]
update_time = time_str_to_utc(ticket.updated_at) if ticket.updated_at else None
metadata: dict[str, str | list[str]] = {}
if ticket.status is not None:
metadata["status"] = ticket.status
if ticket.priority is not None:
metadata["priority"] = ticket.priority
if ticket.tags:
metadata["tags"] = ticket.tags
if ticket.type is not None:
metadata["ticket_type"] = ticket.type
# Fetch comments for the ticket
comments = self.zendesk_client.tickets.comments(ticket=ticket)
# Combine all comments into a single text
comments_text = "\n\n".join(
[
f"Comment{f' by {comment.author.name}' if comment.author and comment.author.name else ''}"
f"{f' at {comment.created_at}' if comment.created_at else ''}:\n{comment.body}"
for comment in comments
if comment.body
]
)
# Combine ticket description and comments
description = (
ticket.description
if hasattr(ticket, "description") and ticket.description
else ""
)
full_text = f"Ticket Description:\n{description}\n\nComments:\n{comments_text}"
# Extract subdomain from ticket.url
subdomain = ticket.url.split("//")[1].split(".zendesk.com")[0]
# Build the html url for the ticket
ticket_url = f"https://{subdomain}.zendesk.com/agent/tickets/{ticket.id}"
return Document(
id=f"zendesk_ticket_{ticket.id}",
sections=[Section(link=ticket_url, text=full_text)],
source=DocumentSource.ZENDESK,
semantic_identifier=f"Ticket #{ticket.id}: {ticket.subject or 'No Subject'}",
doc_updated_at=update_time,
primary_owners=owner,
metadata=metadata,
)
def poll_source(
self, start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None
) -> GenerateDocumentsOutput:
if self.client is None:
raise ZendeskCredentialsNotSetUpError()
self.content_tags = _get_content_tag_mapping(self.client)
if self.zendesk_client is None:
raise ZendeskClientNotSetUpError()
if self.content_type == "articles":
yield from self._poll_articles(start)
@@ -304,30 +204,26 @@ class ZendeskConnector(LoadConnector, PollConnector):
def _poll_articles(
self, start: SecondsSinceUnixEpoch | None
) -> GenerateDocumentsOutput:
articles = _get_articles(self.client, start_time=int(start) if start else None)
# This one is built on the fly as there may be more many more authors than tags
author_map: dict[str, BasicExpertInfo] = {}
articles = (
self.zendesk_client.help_center.articles(cursor_pagination=True) # type: ignore
if start is None
else self.zendesk_client.help_center.articles.incremental( # type: ignore
start_time=int(start)
)
)
doc_batch = []
for article in articles:
if (
article.get("body") is None
or article.get("draft")
article.body is None
or article.draft
or any(
label in ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS
for label in article.get("label_names", [])
for label in article.label_names
)
):
continue
new_author_map, documents = _article_to_document(
article, self.content_tags, author_map, self.client
)
if new_author_map:
author_map.update(new_author_map)
doc_batch.append(documents)
doc_batch.append(_article_to_document(article, self.content_tags))
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch.clear()
@@ -338,14 +234,10 @@ class ZendeskConnector(LoadConnector, PollConnector):
def _poll_tickets(
self, start: SecondsSinceUnixEpoch | None
) -> GenerateDocumentsOutput:
if self.client is None:
raise ZendeskCredentialsNotSetUpError()
if self.zendesk_client is None:
raise ZendeskClientNotSetUpError()
author_map: dict[str, BasicExpertInfo] = {}
ticket_generator = _get_tickets(
self.client, start_time=int(start) if start else None
)
ticket_generator = self.zendesk_client.tickets.incremental(start_time=start)
while True:
doc_batch = []
@@ -354,20 +246,10 @@ class ZendeskConnector(LoadConnector, PollConnector):
ticket = next(ticket_generator)
# Check if the ticket status is deleted and skip it if so
if ticket.get("status") == "deleted":
if ticket.status == "deleted":
continue
new_author_map, documents = _ticket_to_document(
ticket=ticket,
author_map=author_map,
client=self.client,
default_subdomain=self.subdomain,
)
if new_author_map:
author_map.update(new_author_map)
doc_batch.append(documents)
doc_batch.append(self._ticket_to_document(ticket))
if len(doc_batch) >= self.batch_size:
yield doc_batch
@@ -385,6 +267,7 @@ class ZendeskConnector(LoadConnector, PollConnector):
if __name__ == "__main__":
import os
import time
connector = ZendeskConnector()

View File

@@ -1,5 +1,3 @@
import os
from sqlalchemy.orm import Session
from danswer.db.models import SlackBotConfig
@@ -50,16 +48,3 @@ def validate_channel_names(
)
return cleaned_channel_names
# Scaling configurations for multi-tenant Slack bot handling
TENANT_LOCK_EXPIRATION = 1800 # How long a pod can hold exclusive access to a tenant before other pods can acquire it
TENANT_HEARTBEAT_INTERVAL = (
60 # How often pods send heartbeats to indicate they are still processing a tenant
)
TENANT_HEARTBEAT_EXPIRATION = 180 # How long before a tenant's heartbeat expires, allowing other pods to take over
TENANT_ACQUISITION_INTERVAL = (
60 # How often pods attempt to acquire unprocessed tenants
)
MAX_TENANTS_PER_POD = int(os.getenv("MAX_TENANTS_PER_POD", 50))

View File

@@ -1,34 +1,18 @@
import asyncio
import os
import signal
import sys
import threading
import time
from threading import Event
from types import FrameType
from typing import Any
from typing import cast
from typing import Dict
from typing import Set
from prometheus_client import Gauge
from prometheus_client import start_http_server
from slack_sdk import WebClient
from slack_sdk.socket_mode.request import SocketModeRequest
from slack_sdk.socket_mode.response import SocketModeResponse
from danswer.configs.constants import DanswerRedisLocks
from danswer.configs.constants import MessageType
from danswer.configs.danswerbot_configs import DANSWER_BOT_REPHRASE_MESSAGE
from danswer.configs.danswerbot_configs import DANSWER_BOT_RESPOND_EVERY_CHANNEL
from danswer.configs.danswerbot_configs import NOTIFY_SLACKBOT_NO_ANSWER
from danswer.connectors.slack.utils import expert_info_from_slack_id
from danswer.danswerbot.slack.config import get_slack_bot_config_for_channel
from danswer.danswerbot.slack.config import MAX_TENANTS_PER_POD
from danswer.danswerbot.slack.config import TENANT_ACQUISITION_INTERVAL
from danswer.danswerbot.slack.config import TENANT_HEARTBEAT_EXPIRATION
from danswer.danswerbot.slack.config import TENANT_HEARTBEAT_INTERVAL
from danswer.danswerbot.slack.config import TENANT_LOCK_EXPIRATION
from danswer.danswerbot.slack.constants import DISLIKE_BLOCK_ACTION_ID
from danswer.danswerbot.slack.constants import FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID
from danswer.danswerbot.slack.constants import FOLLOWUP_BUTTON_ACTION_ID
@@ -62,7 +46,6 @@ from danswer.danswerbot.slack.utils import remove_danswer_bot_tag
from danswer.danswerbot.slack.utils import rephrase_slack_message
from danswer.danswerbot.slack.utils import respond_in_thread
from danswer.danswerbot.slack.utils import TenantSocketModeClient
from danswer.db.engine import CURRENT_TENANT_ID_CONTEXTVAR
from danswer.db.engine import get_all_tenant_ids
from danswer.db.engine import get_session_with_tenant
from danswer.db.search_settings import get_current_search_settings
@@ -70,23 +53,17 @@ from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from danswer.one_shot_answer.models import ThreadMessage
from danswer.redis.redis_pool import get_redis_client
from danswer.search.retrieval.search_runner import download_nltk_data
from danswer.server.manage.models import SlackBotTokens
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import SLACK_CHANNEL_ID
logger = setup_logger()
# Prometheus metric for HPA
active_tenants_gauge = Gauge(
"active_tenants", "Number of active tenants handled by this pod"
)
# In rare cases, some users have been experiencing a massive amount of trivial messages coming through
# to the Slack Bot with trivial messages. Adding this to avoid exploding LLM costs while we track down
# the cause.
@@ -100,213 +77,10 @@ _SLACK_GREETINGS_TO_IGNORE = {
":wave:",
}
# This is always (currently) the user id of Slack's official slackbot
# this is always (currently) the user id of Slack's official slackbot
_OFFICIAL_SLACKBOT_USER_ID = "USLACKBOT"
class SlackbotHandler:
def __init__(self) -> None:
logger.info("Initializing SlackbotHandler")
self.tenant_ids: Set[str | None] = set()
self.socket_clients: Dict[str | None, TenantSocketModeClient] = {}
self.slack_bot_tokens: Dict[str | None, SlackBotTokens] = {}
self.running = True
self.pod_id = self.get_pod_id()
self._shutdown_event = Event()
logger.info(f"Pod ID: {self.pod_id}")
# Set up signal handlers for graceful shutdown
signal.signal(signal.SIGTERM, self.shutdown)
signal.signal(signal.SIGINT, self.shutdown)
logger.info("Signal handlers registered")
# Start the Prometheus metrics server
logger.info("Starting Prometheus metrics server")
start_http_server(8000)
logger.info("Prometheus metrics server started")
# Start background threads
logger.info("Starting background threads")
self.acquire_thread = threading.Thread(
target=self.acquire_tenants_loop, daemon=True
)
self.heartbeat_thread = threading.Thread(
target=self.heartbeat_loop, daemon=True
)
self.acquire_thread.start()
self.heartbeat_thread.start()
logger.info("Background threads started")
def get_pod_id(self) -> str:
pod_id = os.environ.get("HOSTNAME", "unknown_pod")
logger.info(f"Retrieved pod ID: {pod_id}")
return pod_id
def acquire_tenants_loop(self) -> None:
while not self._shutdown_event.is_set():
try:
self.acquire_tenants()
active_tenants_gauge.set(len(self.tenant_ids))
logger.debug(f"Current active tenants: {len(self.tenant_ids)}")
except Exception as e:
logger.exception(f"Error in Slack acquisition: {e}")
self._shutdown_event.wait(timeout=TENANT_ACQUISITION_INTERVAL)
def heartbeat_loop(self) -> None:
while not self._shutdown_event.is_set():
try:
self.send_heartbeats()
logger.debug(f"Sent heartbeats for {len(self.tenant_ids)} tenants")
except Exception as e:
logger.exception(f"Error in heartbeat loop: {e}")
self._shutdown_event.wait(timeout=TENANT_HEARTBEAT_INTERVAL)
def acquire_tenants(self) -> None:
tenant_ids = get_all_tenant_ids()
logger.debug(f"Found {len(tenant_ids)} total tenants in Postgres")
for tenant_id in tenant_ids:
if tenant_id in self.tenant_ids:
logger.debug(f"Tenant {tenant_id} already in self.tenant_ids")
continue
if len(self.tenant_ids) >= MAX_TENANTS_PER_POD:
logger.info(
f"Max tenants per pod reached ({MAX_TENANTS_PER_POD}) Not acquiring any more tenants"
)
break
redis_client = get_redis_client(tenant_id=tenant_id)
pod_id = self.pod_id
acquired = redis_client.set(
DanswerRedisLocks.SLACK_BOT_LOCK,
pod_id,
nx=True,
ex=TENANT_LOCK_EXPIRATION,
)
if not acquired:
logger.debug(f"Another pod holds the lock for tenant {tenant_id}")
continue
logger.debug(f"Acquired lock for tenant {tenant_id}")
token = CURRENT_TENANT_ID_CONTEXTVAR.set(
tenant_id or POSTGRES_DEFAULT_SCHEMA
)
try:
with get_session_with_tenant(tenant_id) as db_session:
try:
logger.debug(
f"Setting tenant ID context variable for tenant {tenant_id}"
)
slack_bot_tokens = fetch_tokens()
logger.debug(f"Fetched Slack bot tokens for tenant {tenant_id}")
logger.debug(
f"Reset tenant ID context variable for tenant {tenant_id}"
)
if not slack_bot_tokens:
logger.debug(
f"No Slack bot token found for tenant {tenant_id}"
)
if tenant_id in self.socket_clients:
asyncio.run(self.socket_clients[tenant_id].close())
del self.socket_clients[tenant_id]
del self.slack_bot_tokens[tenant_id]
continue
if (
tenant_id not in self.slack_bot_tokens
or slack_bot_tokens != self.slack_bot_tokens[tenant_id]
):
if tenant_id in self.slack_bot_tokens:
logger.info(
f"Slack Bot tokens have changed for tenant {tenant_id} - reconnecting"
)
else:
search_settings = get_current_search_settings(
db_session
)
embedding_model = EmbeddingModel.from_db_model(
search_settings=search_settings,
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
warm_up_bi_encoder(embedding_model=embedding_model)
self.slack_bot_tokens[tenant_id] = slack_bot_tokens
if tenant_id in self.socket_clients:
asyncio.run(self.socket_clients[tenant_id].close())
self.start_socket_client(tenant_id, slack_bot_tokens)
except KvKeyNotFoundError:
logger.debug(f"Missing Slack Bot tokens for tenant {tenant_id}")
if tenant_id in self.socket_clients:
asyncio.run(self.socket_clients[tenant_id].close())
del self.socket_clients[tenant_id]
del self.slack_bot_tokens[tenant_id]
except Exception as e:
logger.exception(f"Error handling tenant {tenant_id}: {e}")
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
def send_heartbeats(self) -> None:
current_time = int(time.time())
logger.debug(f"Sending heartbeats for {len(self.tenant_ids)} tenants")
for tenant_id in self.tenant_ids:
redis_client = get_redis_client(tenant_id=tenant_id)
heartbeat_key = (
f"{DanswerRedisLocks.SLACK_BOT_HEARTBEAT_PREFIX}:{self.pod_id}"
)
redis_client.set(
heartbeat_key, current_time, ex=TENANT_HEARTBEAT_EXPIRATION
)
def start_socket_client(
self, tenant_id: str | None, slack_bot_tokens: SlackBotTokens
) -> None:
logger.info(f"Starting socket client for tenant {tenant_id}")
socket_client = _get_socket_client(slack_bot_tokens, tenant_id)
# Append the event handler
socket_client.socket_mode_request_listeners.append(process_slack_event) # type: ignore
# Establish a WebSocket connection to the Socket Mode servers
logger.info(f"Connecting socket client for tenant {tenant_id}")
socket_client.connect()
self.socket_clients[tenant_id] = socket_client
self.tenant_ids.add(tenant_id)
logger.info(f"Started SocketModeClient for tenant {tenant_id}")
def stop_socket_clients(self) -> None:
logger.info(f"Stopping {len(self.socket_clients)} socket clients")
for tenant_id, client in self.socket_clients.items():
asyncio.run(client.close())
logger.info(f"Stopped SocketModeClient for tenant {tenant_id}")
def shutdown(self, signum: int | None, frame: FrameType | None) -> None:
if not self.running:
return
logger.info("Shutting down gracefully")
self.running = False
self._shutdown_event.set()
# Stop all socket clients
logger.info(f"Stopping {len(self.socket_clients)} socket clients")
self.stop_socket_clients()
# Wait for background threads to finish (with timeout)
logger.info("Waiting for background threads to finish...")
self.acquire_thread.join(timeout=5)
self.heartbeat_thread.join(timeout=5)
logger.info("Shutdown complete")
sys.exit(0)
def prefilter_requests(req: SocketModeRequest, client: TenantSocketModeClient) -> bool:
"""True to keep going, False to ignore this Slack request"""
if req.type == "events_api":
@@ -398,7 +172,7 @@ def prefilter_requests(req: SocketModeRequest, client: TenantSocketModeClient) -
message_subtype = event.get("subtype")
if message_subtype not in [None, "file_share"]:
channel_specific_logger.info(
f"Ignoring message with subtype '{message_subtype}' since it is a special message type"
f"Ignoring message with subtype '{message_subtype}' since is is a special message type"
)
return False
@@ -473,7 +247,7 @@ def process_feedback(req: SocketModeRequest, client: TenantSocketModeClient) ->
)
query_event_id, _, _ = decompose_action_id(feedback_id)
logger.info(f"Successfully handled QA feedback for event: {query_event_id}")
logger.notice(f"Successfully handled QA feedback for event: {query_event_id}")
def build_request_details(
@@ -495,14 +269,14 @@ def build_request_details(
msg = remove_danswer_bot_tag(msg, client=client.web_client)
if DANSWER_BOT_REPHRASE_MESSAGE:
logger.info(f"Rephrasing Slack message. Original message: {msg}")
logger.notice(f"Rephrasing Slack message. Original message: {msg}")
try:
msg = rephrase_slack_message(msg)
logger.info(f"Rephrased message: {msg}")
logger.notice(f"Rephrased message: {msg}")
except Exception as e:
logger.error(f"Error while trying to rephrase the Slack message: {e}")
else:
logger.info(f"Received Slack message: {msg}")
logger.notice(f"Received Slack message: {msg}")
if tagged:
logger.debug("User tagged DanswerBot")
@@ -703,21 +477,94 @@ def _get_socket_client(
)
def _initialize_socket_client(socket_client: TenantSocketModeClient) -> None:
socket_client.socket_mode_request_listeners.append(process_slack_event) # type: ignore
# Establish a WebSocket connection to the Socket Mode servers
logger.notice(f"Listening for messages from Slack {socket_client.tenant_id }...")
socket_client.connect()
# Follow the guide (https://docs.danswer.dev/slack_bot_setup) to set up
# the slack bot in your workspace, and then add the bot to any channels you want to
# try and answer questions for. Running this file will setup Danswer to listen to all
# messages in those channels and attempt to answer them. As of now, it will only respond
# to messages sent directly in the channel - it will not respond to messages sent within a
# thread.
#
# NOTE: we are using Web Sockets so that you can run this from within a firewalled VPC
# without issue.
if __name__ == "__main__":
# Initialize the tenant handler which will manage tenant connections
logger.info("Starting SlackbotHandler")
tenant_handler = SlackbotHandler()
slack_bot_tokens: dict[str | None, SlackBotTokens] = {}
socket_clients: dict[str | None, TenantSocketModeClient] = {}
set_is_ee_based_on_env_variable()
logger.info("Verifying query preprocessing (NLTK) data is downloaded")
logger.notice("Verifying query preprocessing (NLTK) data is downloaded")
download_nltk_data()
try:
# Keep the main thread alive
while tenant_handler.running:
time.sleep(1)
while True:
try:
tenant_ids = get_all_tenant_ids() # Function to retrieve all tenant IDs
except Exception:
logger.exception("Fatal error in main thread")
tenant_handler.shutdown(None, None)
for tenant_id in tenant_ids:
with get_session_with_tenant(tenant_id) as db_session:
try:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id or "public")
latest_slack_bot_tokens = fetch_tokens()
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
if (
tenant_id not in slack_bot_tokens
or latest_slack_bot_tokens != slack_bot_tokens[tenant_id]
):
if tenant_id in slack_bot_tokens:
logger.notice(
f"Slack Bot tokens have changed for tenant {tenant_id} - reconnecting"
)
else:
# Initial setup for this tenant
search_settings = get_current_search_settings(
db_session
)
embedding_model = EmbeddingModel.from_db_model(
search_settings=search_settings,
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
warm_up_bi_encoder(embedding_model=embedding_model)
slack_bot_tokens[tenant_id] = latest_slack_bot_tokens
# potentially may cause a message to be dropped, but it is complicated
# to avoid + (1) if the user is changing tokens, they are likely okay with some
# "migration downtime" and (2) if a single message is lost it is okay
# as this should be a very rare occurrence
if tenant_id in socket_clients:
socket_clients[tenant_id].close()
socket_client = _get_socket_client(
latest_slack_bot_tokens, tenant_id
)
# Initialize socket client for this tenant. Each tenant has its own
# socket client, allowing for multiple concurrent connections (one
# per tenant) with the tenant ID wrapped in the socket model client.
# Each `connect` stores websocket connection in a separate thread.
_initialize_socket_client(socket_client)
socket_clients[tenant_id] = socket_client
except KvKeyNotFoundError:
logger.debug(f"Missing Slack Bot tokens for tenant {tenant_id}")
if tenant_id in socket_clients:
socket_clients[tenant_id].disconnect()
del socket_clients[tenant_id]
del slack_bot_tokens[tenant_id]
# Wait before checking for updates
Event().wait(timeout=60)
except Exception:
logger.exception("An error occurred outside of main event loop")
time.sleep(60)

View File

@@ -14,7 +14,6 @@ from sqlalchemy.orm import Session
from danswer.auth.invited_users import get_invited_users
from danswer.auth.schemas import UserRole
from danswer.db.api_key import get_api_key_email_pattern
from danswer.db.engine import get_async_session
from danswer.db.engine import get_async_session_with_tenant
from danswer.db.models import AccessToken
@@ -36,16 +35,12 @@ def get_default_admin_user_emails() -> list[str]:
return get_default_admin_user_emails_fn()
def get_total_users_count(db_session: Session) -> int:
def get_total_users(db_session: Session) -> int:
"""
Returns the total number of users in the system.
This is the sum of users and invited users.
"""
user_count = (
db_session.query(User)
.filter(~User.email.endswith(get_api_key_email_pattern())) # type: ignore
.count()
)
user_count = db_session.query(User).count()
invited_users = len(get_invited_users())
return user_count + invited_users

View File

@@ -388,7 +388,7 @@ def get_chat_messages_by_session(
)
if prefetch_tool_calls:
stmt = stmt.options(joinedload(ChatMessage.tool_call))
stmt = stmt.options(joinedload(ChatMessage.tool_calls))
result = db_session.scalars(stmt).unique().all()
else:
result = db_session.scalars(stmt).all()
@@ -474,7 +474,7 @@ def create_new_chat_message(
alternate_assistant_id: int | None = None,
# Maps the citation number [n] to the DB SearchDoc
citations: dict[int, int] | None = None,
tool_call: ToolCall | None = None,
tool_calls: list[ToolCall] | None = None,
commit: bool = True,
reserved_message_id: int | None = None,
overridden_model: str | None = None,
@@ -494,7 +494,7 @@ def create_new_chat_message(
existing_message.message_type = message_type
existing_message.citations = citations
existing_message.files = files
existing_message.tool_call = tool_call
existing_message.tool_calls = tool_calls if tool_calls else []
existing_message.error = error
existing_message.alternate_assistant_id = alternate_assistant_id
existing_message.overridden_model = overridden_model
@@ -513,7 +513,7 @@ def create_new_chat_message(
message_type=message_type,
citations=citations,
files=files,
tool_call=tool_call,
tool_calls=tool_calls if tool_calls else [],
error=error,
alternate_assistant_id=alternate_assistant_id,
overridden_model=overridden_model,
@@ -749,13 +749,14 @@ def translate_db_message_to_chat_message_detail(
time_sent=chat_message.time_sent,
citations=chat_message.citations,
files=chat_message.files or [],
tool_call=ToolCallFinalResult(
tool_name=chat_message.tool_call.tool_name,
tool_args=chat_message.tool_call.tool_arguments,
tool_result=chat_message.tool_call.tool_result,
)
if chat_message.tool_call
else None,
tool_calls=[
ToolCallFinalResult(
tool_name=tool_call.tool_name,
tool_args=tool_call.tool_arguments,
tool_result=tool_call.tool_result,
)
for tool_call in chat_message.tool_calls
],
alternate_assistant_id=chat_message.alternate_assistant_id,
overridden_model=chat_message.overridden_model,
)

View File

@@ -25,8 +25,8 @@ from danswer.db.models import UserGroup__ConnectorCredentialPair
from danswer.db.models import UserRole
from danswer.server.models import StatusResponse
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import fetch_ee_implementation_or_noop
from ee.danswer.db.external_perm import delete_user__ext_group_for_cc_pair__no_commit
from ee.danswer.external_permissions.sync_params import check_if_valid_sync_source
logger = setup_logger()
@@ -351,11 +351,7 @@ def add_credential_to_connector(
raise HTTPException(status_code=404, detail="Connector does not exist")
if access_type == AccessType.SYNC:
if not fetch_ee_implementation_or_noop(
"danswer.external_permissions.sync_params",
"check_if_valid_sync_source",
noop_return_value=True,
)(connector.source):
if not check_if_valid_sync_source(connector.source):
raise HTTPException(
status_code=400,
detail=f"Connector of type {connector.source} does not support SYNC access type",
@@ -442,10 +438,7 @@ def remove_credential_from_connector(
)
if association is not None:
fetch_ee_implementation_or_noop(
"danswer.db.external_perm",
"delete_user__ext_group_for_cc_pair__no_commit",
)(
delete_user__ext_group_for_cc_pair__no_commit(
db_session=db_session,
cc_pair_id=association.id,
)

View File

@@ -10,7 +10,10 @@ from sqlalchemy.sql.expression import or_
from danswer.auth.schemas import UserRole
from danswer.configs.constants import DocumentSource
from danswer.connectors.google_utils.shared_constants import (
from danswer.connectors.gmail.constants import (
GMAIL_DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
)
from danswer.connectors.google_drive.constants import (
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
)
from danswer.db.models import ConnectorCredentialPair
@@ -421,15 +424,25 @@ def cleanup_google_drive_credentials(db_session: Session) -> None:
db_session.commit()
def delete_service_account_credentials(
user: User | None, db_session: Session, source: DocumentSource
def delete_gmail_service_account_credentials(
user: User | None, db_session: Session
) -> None:
credentials = fetch_credentials(db_session=db_session, user=user)
for credential in credentials:
if (
credential.credential_json.get(DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY)
and credential.source == source
if credential.credential_json.get(
GMAIL_DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY
):
db_session.delete(credential)
db_session.commit()
def delete_google_drive_service_account_credentials(
user: User | None, db_session: Session
) -> None:
credentials = fetch_credentials(db_session=db_session, user=user)
for credential in credentials:
if credential.credential_json.get(DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY):
db_session.delete(credential)
db_session.commit()

View File

@@ -29,7 +29,6 @@ from danswer.configs.app_configs import POSTGRES_API_SERVER_POOL_OVERFLOW
from danswer.configs.app_configs import POSTGRES_API_SERVER_POOL_SIZE
from danswer.configs.app_configs import POSTGRES_DB
from danswer.configs.app_configs import POSTGRES_HOST
from danswer.configs.app_configs import POSTGRES_IDLE_SESSIONS_TIMEOUT
from danswer.configs.app_configs import POSTGRES_PASSWORD
from danswer.configs.app_configs import POSTGRES_POOL_PRE_PING
from danswer.configs.app_configs import POSTGRES_POOL_RECYCLE
@@ -38,10 +37,10 @@ from danswer.configs.app_configs import POSTGRES_USER
from danswer.configs.app_configs import USER_AUTH_SECRET
from danswer.configs.constants import POSTGRES_UNKNOWN_APP_NAME
from danswer.utils.logger import setup_logger
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import TENANT_ID_PREFIX
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
@@ -310,12 +309,8 @@ async def get_async_session_with_tenant(
try:
# Set the search_path to the tenant's schema
await session.execute(text(f'SET search_path = "{tenant_id}"'))
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
await session.execute(
f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
)
except Exception:
logger.exception("Error setting search_path.")
except Exception as e:
logger.error(f"Error setting search_path: {str(e)}")
# You can choose to re-raise the exception or handle it
# Here, we'll re-raise to prevent proceeding with an incorrect session
raise
@@ -323,77 +318,47 @@ async def get_async_session_with_tenant(
yield session
@contextmanager
def get_session_with_default_tenant() -> Generator[Session, None, None]:
"""
Get a database session using the current tenant ID from the context variable.
"""
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
with get_session_with_tenant(tenant_id) as session:
yield session
@contextmanager
def get_session_with_tenant(
tenant_id: str | None = None,
) -> Generator[Session, None, None]:
"""
Generate a database session for a specific tenant.
This function:
1. Sets the database schema to the specified tenant's schema.
2. Preserves the tenant ID across the session.
3. Reverts to the previous tenant ID after the session is closed.
4. Uses the default schema if no tenant ID is provided.
"""
"""Generate a database session bound to a connection with the appropriate tenant schema set."""
engine = get_sqlalchemy_engine()
# Store the previous tenant ID
previous_tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() or POSTGRES_DEFAULT_SCHEMA
if tenant_id is None:
tenant_id = POSTGRES_DEFAULT_SCHEMA
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
else:
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
event.listen(engine, "checkout", set_search_path_on_checkout)
if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID")
try:
# Establish a raw connection
with engine.connect() as connection:
# Access the raw DBAPI connection and set the search_path
dbapi_connection = connection.connection
# Establish a raw connection
with engine.connect() as connection:
# Access the raw DBAPI connection and set the search_path
dbapi_connection = connection.connection
# Set the search_path outside of any transaction
cursor = dbapi_connection.cursor()
# Set the search_path outside of any transaction
cursor = dbapi_connection.cursor()
try:
cursor.execute(f'SET search_path = "{tenant_id}"')
finally:
cursor.close()
# Bind the session to the connection
with Session(bind=connection, expire_on_commit=False) as session:
try:
cursor.execute(f'SET search_path = "{tenant_id}"')
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
cursor.execute(
f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
)
yield session
finally:
cursor.close()
# Bind the session to the connection
with Session(bind=connection, expire_on_commit=False) as session:
try:
yield session
finally:
# Reset search_path to default after the session is used
if MULTI_TENANT:
cursor = dbapi_connection.cursor()
try:
cursor.execute('SET search_path TO "$user", public')
finally:
cursor.close()
finally:
# Restore the previous tenant ID
CURRENT_TENANT_ID_CONTEXTVAR.set(previous_tenant_id)
# Reset search_path to default after the session is used
if MULTI_TENANT:
cursor = dbapi_connection.cursor()
try:
cursor.execute('SET search_path TO "$user", public')
finally:
cursor.close()
def set_search_path_on_checkout(

View File

@@ -219,7 +219,7 @@ def mark_attempt_partially_succeeded(
def mark_attempt_failed(
index_attempt_id: int,
index_attempt: IndexAttempt,
db_session: Session,
failure_reason: str = "Unknown",
full_exception_trace: str | None = None,
@@ -227,7 +227,7 @@ def mark_attempt_failed(
try:
attempt = db_session.execute(
select(IndexAttempt)
.where(IndexAttempt.id == index_attempt_id)
.where(IndexAttempt.id == index_attempt.id)
.with_for_update()
).scalar_one()

View File

@@ -135,9 +135,6 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
hidden_assistants: Mapped[list[int]] = mapped_column(
postgresql.JSONB(), nullable=False, default=[]
)
recent_assistants: Mapped[list[dict]] = mapped_column(
postgresql.JSONB(), nullable=False, default=list, server_default="[]"
)
oidc_expiry: Mapped[datetime.datetime] = mapped_column(
TIMESTAMPAware(timezone=True), nullable=True
@@ -737,10 +734,9 @@ class IndexAttempt(Base):
full_exception_trace: Mapped[str | None] = mapped_column(Text, default=None)
# Nullable because in the past, we didn't allow swapping out embedding models live
search_settings_id: Mapped[int] = mapped_column(
ForeignKey("search_settings.id", ondelete="SET NULL"),
nullable=True,
ForeignKey("search_settings.id"),
nullable=False,
)
time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
@@ -760,7 +756,7 @@ class IndexAttempt(Base):
"ConnectorCredentialPair", back_populates="index_attempts"
)
search_settings: Mapped[SearchSettings | None] = relationship(
search_settings: Mapped[SearchSettings] = relationship(
"SearchSettings", back_populates="index_attempts"
)
@@ -921,15 +917,10 @@ class ToolCall(Base):
tool_arguments: Mapped[dict[str, JSON_ro]] = mapped_column(postgresql.JSONB())
tool_result: Mapped[JSON_ro] = mapped_column(postgresql.JSONB())
message_id: Mapped[int | None] = mapped_column(
ForeignKey("chat_message.id"), nullable=False
)
message_id: Mapped[int] = mapped_column(ForeignKey("chat_message.id"))
# Update the relationship
message: Mapped["ChatMessage"] = relationship(
"ChatMessage",
back_populates="tool_call",
uselist=False,
"ChatMessage", back_populates="tool_calls"
)
@@ -1060,13 +1051,12 @@ class ChatMessage(Base):
secondary=ChatMessage__SearchDoc.__table__,
back_populates="chat_messages",
)
tool_call: Mapped["ToolCall"] = relationship(
# NOTE: Should always be attached to the `assistant` message.
# represents the tool calls used to generate this message
tool_calls: Mapped[list["ToolCall"]] = relationship(
"ToolCall",
back_populates="message",
uselist=False,
)
standard_answers: Mapped[list["StandardAnswer"]] = relationship(
"StandardAnswer",
secondary=ChatMessage__StandardAnswer.__table__,
@@ -1324,6 +1314,7 @@ class StarterMessage(TypedDict):
in Postgres"""
name: str
description: str
message: str

View File

@@ -12,7 +12,7 @@ from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
from danswer.configs.model_configs import OLD_DEFAULT_DOCUMENT_ENCODER_MODEL
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
from danswer.db.engine import get_session_with_default_tenant
from danswer.db.engine import get_session_with_tenant
from danswer.db.llm import fetch_embedding_provider
from danswer.db.models import CloudEmbeddingProvider
from danswer.db.models import IndexAttempt
@@ -152,7 +152,7 @@ def get_all_search_settings(db_session: Session) -> list[SearchSettings]:
def get_multilingual_expansion(db_session: Session | None = None) -> list[str]:
if db_session is None:
with get_session_with_default_tenant() as db_session:
with get_session_with_tenant() as db_session:
search_settings = get_current_search_settings(db_session)
else:
search_settings = get_current_search_settings(db_session)

View File

@@ -14,6 +14,7 @@ from danswer.db.search_settings import get_secondary_search_settings
from danswer.db.search_settings import update_search_settings_status
from danswer.key_value_store.factory import get_kv_store
from danswer.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
@@ -22,14 +23,7 @@ logger = setup_logger()
def check_index_swap(db_session: Session) -> SearchSettings | None:
"""Get count of cc-pairs and count of successful index_attempts for the
new model grouped by connector + credential, if it's the same, then assume
new index is done building. If so, swap the indices and expire the old one.
Returns None if search settings did not change, or the old search settings if they
did change.
"""
old_search_settings = None
new index is done building. If so, swap the indices and expire the old one."""
# Default CC-pair created for Ingestion API unused here
all_cc_pairs = get_connector_credential_pairs(db_session)
cc_pair_count = max(len(all_cc_pairs) - 1, 0)
@@ -49,9 +43,9 @@ def check_index_swap(db_session: Session) -> SearchSettings | None:
if cc_pair_count == 0 or cc_pair_count == unique_cc_indexings:
# Swap indices
current_search_settings = get_current_search_settings(db_session)
now_old_search_settings = get_current_search_settings(db_session)
update_search_settings_status(
search_settings=current_search_settings,
search_settings=now_old_search_settings,
new_status=IndexModelStatus.PAST,
db_session=db_session,
)
@@ -73,6 +67,6 @@ def check_index_swap(db_session: Session) -> SearchSettings | None:
for cc_pair in all_cc_pairs:
resync_cc_pair(cc_pair, db_session=db_session)
old_search_settings = current_search_settings
return old_search_settings
if MULTI_TENANT:
return now_old_search_settings
return None

View File

@@ -1,111 +0,0 @@
from collections.abc import Sequence
from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.configs.constants import TokenRateLimitScope
from danswer.db.models import TokenRateLimit
from danswer.db.models import TokenRateLimit__UserGroup
from danswer.server.token_rate_limits.models import TokenRateLimitArgs
def fetch_all_user_token_rate_limits(
db_session: Session,
enabled_only: bool = False,
ordered: bool = True,
) -> Sequence[TokenRateLimit]:
query = select(TokenRateLimit).where(
TokenRateLimit.scope == TokenRateLimitScope.USER
)
if enabled_only:
query = query.where(TokenRateLimit.enabled.is_(True))
if ordered:
query = query.order_by(TokenRateLimit.created_at.desc())
return db_session.scalars(query).all()
def fetch_all_global_token_rate_limits(
db_session: Session,
enabled_only: bool = False,
ordered: bool = True,
) -> Sequence[TokenRateLimit]:
query = select(TokenRateLimit).where(
TokenRateLimit.scope == TokenRateLimitScope.GLOBAL
)
if enabled_only:
query = query.where(TokenRateLimit.enabled.is_(True))
if ordered:
query = query.order_by(TokenRateLimit.created_at.desc())
token_rate_limits = db_session.scalars(query).all()
return token_rate_limits
def insert_user_token_rate_limit(
db_session: Session,
token_rate_limit_settings: TokenRateLimitArgs,
) -> TokenRateLimit:
token_limit = TokenRateLimit(
enabled=token_rate_limit_settings.enabled,
token_budget=token_rate_limit_settings.token_budget,
period_hours=token_rate_limit_settings.period_hours,
scope=TokenRateLimitScope.USER,
)
db_session.add(token_limit)
db_session.commit()
return token_limit
def insert_global_token_rate_limit(
db_session: Session,
token_rate_limit_settings: TokenRateLimitArgs,
) -> TokenRateLimit:
token_limit = TokenRateLimit(
enabled=token_rate_limit_settings.enabled,
token_budget=token_rate_limit_settings.token_budget,
period_hours=token_rate_limit_settings.period_hours,
scope=TokenRateLimitScope.GLOBAL,
)
db_session.add(token_limit)
db_session.commit()
return token_limit
def update_token_rate_limit(
db_session: Session,
token_rate_limit_id: int,
token_rate_limit_settings: TokenRateLimitArgs,
) -> TokenRateLimit:
token_limit = db_session.get(TokenRateLimit, token_rate_limit_id)
if token_limit is None:
raise ValueError(f"TokenRateLimit with id '{token_rate_limit_id}' not found")
token_limit.enabled = token_rate_limit_settings.enabled
token_limit.token_budget = token_rate_limit_settings.token_budget
token_limit.period_hours = token_rate_limit_settings.period_hours
db_session.commit()
return token_limit
def delete_token_rate_limit(
db_session: Session,
token_rate_limit_id: int,
) -> None:
token_limit = db_session.get(TokenRateLimit, token_rate_limit_id)
if token_limit is None:
raise ValueError(f"TokenRateLimit with id '{token_rate_limit_id}' not found")
db_session.query(TokenRateLimit__UserGroup).filter(
TokenRateLimit__UserGroup.rate_limit_id == token_rate_limit_id
).delete()
db_session.delete(token_limit)
db_session.commit()

View File

@@ -24,13 +24,6 @@ def get_tool_by_id(tool_id: int, db_session: Session) -> Tool:
return tool
def get_tool_by_name(tool_name: str, db_session: Session) -> Tool:
tool = db_session.scalar(select(Tool).where(Tool.name == tool_name))
if not tool:
raise ValueError("Tool by specified name does not exist")
return tool
def create_tool(
name: str,
description: str | None,
@@ -44,7 +37,7 @@ def create_tool(
description=description,
in_code_tool_id=None,
openapi_schema=openapi_schema,
custom_headers=[header.model_dump() for header in custom_headers]
custom_headers=[header.dict() for header in custom_headers]
if custom_headers
else [],
user_id=user_id,

View File

@@ -147,7 +147,7 @@ class VespaIndex(DocumentIndex):
return None
deploy_url = f"{VESPA_APPLICATION_ENDPOINT}/tenant/default/prepareandactivate"
logger.notice(f"Deploying Vespa application package to {deploy_url}")
logger.info(f"Deploying Vespa application package to {deploy_url}")
vespa_schema_path = os.path.join(
os.getcwd(), "danswer", "document_index", "vespa", "app_config"

View File

@@ -13,7 +13,6 @@ class ChatFileType(str, Enum):
DOC = "document"
# Plain text only contain the text
PLAIN_TEXT = "plain_text"
CSV = "csv"
class FileDescriptor(TypedDict):

View File

@@ -8,13 +8,12 @@ import requests
from sqlalchemy.orm import Session
from danswer.configs.constants import FileOrigin
from danswer.db.engine import get_session_with_tenant
from danswer.db.engine import get_session_context_manager
from danswer.db.models import ChatMessage
from danswer.file_store.file_store import get_default_file_store
from danswer.file_store.models import FileDescriptor
from danswer.file_store.models import InMemoryChatFile
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
def load_chat_file(
@@ -53,11 +52,11 @@ def load_all_chat_files(
return files
def save_file_from_url(url: str, tenant_id: str) -> str:
def save_file_from_url(url: str) -> str:
"""NOTE: using multiple sessions here, since this is often called
using multithreading. In practice, sharing a session has resulted in
weird errors."""
with get_session_with_tenant(tenant_id) as db_session:
with get_session_context_manager() as db_session:
response = requests.get(url)
response.raise_for_status()
@@ -76,10 +75,7 @@ def save_file_from_url(url: str, tenant_id: str) -> str:
def save_files_from_urls(urls: list[str]) -> list[str]:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
funcs: list[tuple[Callable[..., Any], tuple[Any, ...]]] = [
(save_file_from_url, (url, tenant_id)) for url in urls
(save_file_from_url, (url,)) for url in urls
]
# Must pass in tenant_id here, since this is called by multithreading
return run_functions_tuples_in_parallel(funcs)

View File

@@ -16,9 +16,9 @@ from danswer.key_value_store.interface import KeyValueStore
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()

View File

@@ -1,44 +1,72 @@
import itertools
from collections.abc import Callable
from collections.abc import Iterator
from typing import Any
from typing import cast
from uuid import uuid4
from langchain.schema.messages import BaseMessage
from langchain_core.messages import AIMessageChunk
from langchain_core.messages import ToolCall
from langchain_core.messages import HumanMessage
from danswer.chat.chat_utils import llm_doc_from_inference_section
from danswer.chat.models import AnswerQuestionPossibleReturn
from danswer.chat.models import CitationInfo
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import LlmDoc
from danswer.chat.models import StreamStopInfo
from danswer.chat.models import StreamStopReason
from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE
from danswer.file_store.utils import InMemoryChatFile
from danswer.llm.answering.llm_response_handler import LLMCall
from danswer.llm.answering.llm_response_handler import LLMResponseHandlerManager
from danswer.llm.answering.models import AnswerStyleConfig
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.models import StreamProcessor
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
from danswer.llm.answering.prompts.build import default_build_system_message
from danswer.llm.answering.prompts.build import default_build_user_message
from danswer.llm.answering.stream_processing.answer_response_handler import (
AnswerResponseHandler,
from danswer.llm.answering.prompts.citations_prompt import (
build_citations_system_message,
)
from danswer.llm.answering.stream_processing.answer_response_handler import (
CitationResponseHandler,
from danswer.llm.answering.prompts.citations_prompt import build_citations_user_message
from danswer.llm.answering.prompts.quotes_prompt import build_quotes_user_message
from danswer.llm.answering.stream_processing.citation_processing import (
build_citation_processor,
)
from danswer.llm.answering.stream_processing.answer_response_handler import (
DummyAnswerResponseHandler,
)
from danswer.llm.answering.stream_processing.answer_response_handler import (
QuotesResponseHandler,
from danswer.llm.answering.stream_processing.quotes_processing import (
build_quotes_processor,
)
from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping
from danswer.llm.answering.stream_processing.utils import map_document_id_order
from danswer.llm.answering.tool.tool_response_handler import ToolResponseHandler
from danswer.llm.interfaces import LLM
from danswer.llm.interfaces import ToolChoiceOptions
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.tools.custom.custom_tool_prompt_builder import (
build_user_message_for_custom_tool_for_non_tool_calling_llm,
)
from danswer.tools.force import filter_tools_for_force_tool_use
from danswer.tools.force import ForceUseTool
from danswer.tools.models import ToolResponse
from danswer.tools.images.image_generation_tool import IMAGE_GENERATION_RESPONSE_ID
from danswer.tools.images.image_generation_tool import ImageGenerationResponse
from danswer.tools.images.image_generation_tool import ImageGenerationTool
from danswer.tools.images.prompt import build_image_generation_user_prompt
from danswer.tools.internet_search.internet_search_tool import InternetSearchTool
from danswer.tools.message import build_tool_message
from danswer.tools.message import ToolCallSummary
from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS_ID
from danswer.tools.search.search_tool import SEARCH_DOC_CONTENT_ID
from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID
from danswer.tools.search.search_tool import SearchResponseSummary
from danswer.tools.search.search_tool import SearchTool
from danswer.tools.tool import Tool
from danswer.tools.tool_implementations.search.search_tool import SearchTool
from danswer.tools.tool import ToolResponse
from danswer.tools.tool_runner import (
check_which_tools_should_run_for_non_tool_calling_llm,
)
from danswer.tools.tool_runner import ToolCallFinalResult
from danswer.tools.tool_runner import ToolCallKickoff
from danswer.tools.tool_runner import ToolRunner
from danswer.tools.tool_selection import select_single_tool_for_non_tool_calling_llm
from danswer.tools.utils import explicit_tool_calling_supported
from danswer.utils.logger import setup_logger
@@ -46,9 +74,29 @@ from danswer.utils.logger import setup_logger
logger = setup_logger()
def _get_answer_stream_processor(
context_docs: list[LlmDoc],
doc_id_to_rank_map: DocumentIdOrderMapping,
answer_style_configs: AnswerStyleConfig,
) -> StreamProcessor:
if answer_style_configs.citation_config:
return build_citation_processor(
context_docs=context_docs, doc_id_to_rank_map=doc_id_to_rank_map
)
if answer_style_configs.quotes_config:
return build_quotes_processor(
context_docs=context_docs, is_json_prompt=not (QA_PROMPT_OVERRIDE == "weak")
)
raise RuntimeError("Not implemented yet")
AnswerStream = Iterator[AnswerQuestionPossibleReturn | ToolCallKickoff | ToolResponse]
logger = setup_logger()
class Answer:
def __init__(
self,
@@ -88,6 +136,8 @@ class Answer:
self.tools = tools or []
self.force_use_tool = force_use_tool
self.skip_explicit_tool_calling = skip_explicit_tool_calling
self.message_history = message_history or []
# used for QA flow where we only want to send a single message
self.single_message_history = single_message_history
@@ -112,141 +162,335 @@ class Answer:
self.skip_gen_ai_answer_generation = skip_gen_ai_answer_generation
self._is_cancelled = False
self.using_tool_calling_llm = (
explicit_tool_calling_supported(
self.llm.config.model_provider, self.llm.config.model_name
def _update_prompt_builder_for_search_tool(
self, prompt_builder: AnswerPromptBuilder, final_context_documents: list[LlmDoc]
) -> None:
if self.answer_style_config.citation_config:
prompt_builder.update_system_prompt(
build_citations_system_message(self.prompt_config)
)
and not skip_explicit_tool_calling
)
def _get_tools_list(self) -> list[Tool]:
if not self.force_use_tool.force_use:
return self.tools
tool = next(
(t for t in self.tools if t.name == self.force_use_tool.tool_name), None
)
if tool is None:
raise RuntimeError(f"Tool '{self.force_use_tool.tool_name}' not found")
logger.info(
f"Forcefully using tool='{tool.name}'"
+ (
f" with args='{self.force_use_tool.args}'"
if self.force_use_tool.args is not None
else ""
)
)
return [tool]
def _handle_specified_tool_call(
self, llm_calls: list[LLMCall], tool: Tool, tool_args: dict
) -> AnswerStream:
current_llm_call = llm_calls[-1]
# make a dummy tool handler
tool_handler = ToolResponseHandler([tool])
dummy_tool_call_chunk = AIMessageChunk(content="")
dummy_tool_call_chunk.tool_calls = [
ToolCall(name=tool.name, args=tool_args, id=str(uuid4()))
]
response_handler_manager = LLMResponseHandlerManager(
tool_handler, DummyAnswerResponseHandler(), self.is_cancelled
)
yield from response_handler_manager.handle_llm_response(
iter([dummy_tool_call_chunk])
)
new_llm_call = response_handler_manager.next_llm_call(current_llm_call)
if new_llm_call:
yield from self._get_response(llm_calls + [new_llm_call])
else:
raise RuntimeError("Tool call handler did not return a new LLM call")
def _get_response(self, llm_calls: list[LLMCall]) -> AnswerStream:
current_llm_call = llm_calls[-1]
# handle the case where no decision has to be made; we simply run the tool
if (
current_llm_call.force_use_tool.force_use
and current_llm_call.force_use_tool.args is not None
):
tool_name, tool_args = (
current_llm_call.force_use_tool.tool_name,
current_llm_call.force_use_tool.args,
)
tool = next(
(t for t in current_llm_call.tools if t.name == tool_name), None
)
if not tool:
raise RuntimeError(f"Tool '{tool_name}' not found")
yield from self._handle_specified_tool_call(llm_calls, tool, tool_args)
return
# special pre-logic for non-tool calling LLM case
if not self.using_tool_calling_llm and current_llm_call.tools:
chosen_tool_and_args = (
ToolResponseHandler.get_tool_call_for_non_tool_calling_llm(
current_llm_call, self.llm
prompt_builder.update_user_prompt(
build_citations_user_message(
question=self.question,
prompt_config=self.prompt_config,
context_docs=final_context_documents,
files=self.latest_query_files,
all_doc_useful=(
self.answer_style_config.citation_config.all_docs_useful
if self.answer_style_config.citation_config
else False
),
history_message=self.single_message_history or "",
)
)
elif self.answer_style_config.quotes_config:
prompt_builder.update_user_prompt(
build_quotes_user_message(
question=self.question,
context_docs=final_context_documents,
history_str=self.single_message_history or "",
prompt=self.prompt_config,
)
)
if chosen_tool_and_args:
tool, tool_args = chosen_tool_and_args
yield from self._handle_specified_tool_call(llm_calls, tool, tool_args)
return
# if we're skipping gen ai answer generation, we should break
# out unless we're forcing a tool call. If we don't, we might generate an
# answer, which is a no-no!
if (
self.skip_gen_ai_answer_generation
and not current_llm_call.force_use_tool.force_use
):
def _raw_output_for_explicit_tool_calling_llms(
self,
) -> Iterator[
str | StreamStopInfo | ToolCallKickoff | ToolResponse | ToolCallFinalResult
]:
prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)
tool_call_chunk: AIMessageChunk | None = None
if self.force_use_tool.force_use and self.force_use_tool.args is not None:
# if we are forcing a tool WITH args specified, we don't need to check which tools to run
# / need to generate the args
tool_call_chunk = AIMessageChunk(
content="",
)
tool_call_chunk.tool_calls = [
{
"name": self.force_use_tool.tool_name,
"args": self.force_use_tool.args,
"id": str(uuid4()),
}
]
else:
# if tool calling is supported, first try the raw message
# to see if we don't need to use any tools
prompt_builder.update_system_prompt(
default_build_system_message(self.prompt_config)
)
prompt_builder.update_user_prompt(
default_build_user_message(
self.question, self.prompt_config, self.latest_query_files
)
)
prompt = prompt_builder.build()
final_tool_definitions = [
tool.tool_definition()
for tool in filter_tools_for_force_tool_use(
self.tools, self.force_use_tool
)
]
for message in self.llm.stream(
prompt=prompt,
tools=final_tool_definitions if final_tool_definitions else None,
tool_choice="required" if self.force_use_tool.force_use else None,
structured_response_format=self.answer_style_config.structured_response_format,
):
if isinstance(message, AIMessageChunk) and (
message.tool_call_chunks or message.tool_calls
):
if tool_call_chunk is None:
tool_call_chunk = message
else:
tool_call_chunk += message # type: ignore
else:
if message.content:
if self.is_cancelled:
return
yield cast(str, message.content)
if (
message.additional_kwargs.get("usage_metadata", {}).get("stop")
== "length"
):
yield StreamStopInfo(
stop_reason=StreamStopReason.CONTEXT_LENGTH
)
if not tool_call_chunk:
return # no tool call needed
# if we have a tool call, we need to call the tool
tool_call_requests = tool_call_chunk.tool_calls
for tool_call_request in tool_call_requests:
known_tools_by_name = [
tool for tool in self.tools if tool.name == tool_call_request["name"]
]
if not known_tools_by_name:
logger.error(
"Tool call requested with unknown name field. \n"
f"self.tools: {self.tools}"
f"tool_call_request: {tool_call_request}"
)
if self.tools:
tool = self.tools[0]
else:
continue
else:
tool = known_tools_by_name[0]
tool_args = (
self.force_use_tool.args
if self.force_use_tool.tool_name == tool.name
and self.force_use_tool.args
else tool_call_request["args"]
)
tool_runner = ToolRunner(tool, tool_args)
yield tool_runner.kickoff()
yield from tool_runner.tool_responses()
tool_call_summary = ToolCallSummary(
tool_call_request=tool_call_chunk,
tool_call_result=build_tool_message(
tool_call_request, tool_runner.tool_message_content()
),
)
if tool.name in {SearchTool._NAME, InternetSearchTool._NAME}:
self._update_prompt_builder_for_search_tool(prompt_builder, [])
elif tool.name == ImageGenerationTool._NAME:
img_urls = [
img_generation_result["url"]
for img_generation_result in tool_runner.tool_final_result().tool_result
]
prompt_builder.update_user_prompt(
build_image_generation_user_prompt(
query=self.question, img_urls=img_urls
)
)
yield tool_runner.tool_final_result()
if not self.skip_gen_ai_answer_generation:
prompt = prompt_builder.build(tool_call_summary=tool_call_summary)
yield from self._process_llm_stream(
prompt=prompt,
# as of now, we don't support multiple tool calls in sequence, which is why
# we don't need to pass this in here
# tools=[tool.tool_definition() for tool in self.tools],
)
return
# set up "handlers" to listen to the LLM response stream and
# feed back the processed results + handle tool call requests
# + figure out what the next LLM call should be
tool_call_handler = ToolResponseHandler(current_llm_call.tools)
# This method processes the LLM stream and yields the content or stop information
def _process_llm_stream(
self,
prompt: Any,
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
) -> Iterator[str | StreamStopInfo]:
for message in self.llm.stream(
prompt=prompt,
tools=tools,
tool_choice=tool_choice,
structured_response_format=self.answer_style_config.structured_response_format,
):
if isinstance(message, AIMessageChunk):
if message.content:
if self.is_cancelled:
return StreamStopInfo(stop_reason=StreamStopReason.CANCELLED)
yield cast(str, message.content)
search_result = SearchTool.get_search_result(current_llm_call) or []
if (
message.additional_kwargs.get("usage_metadata", {}).get("stop")
== "length"
):
yield StreamStopInfo(stop_reason=StreamStopReason.CONTEXT_LENGTH)
answer_handler: AnswerResponseHandler
if self.answer_style_config.citation_config:
answer_handler = CitationResponseHandler(
context_docs=search_result,
doc_id_to_rank_map=map_document_id_order(search_result),
def _raw_output_for_non_explicit_tool_calling_llms(
self,
) -> Iterator[
str | StreamStopInfo | ToolCallKickoff | ToolResponse | ToolCallFinalResult
]:
prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)
chosen_tool_and_args: tuple[Tool, dict] | None = None
if self.force_use_tool.force_use:
# if we are forcing a tool, we don't need to check which tools to run
tool = next(
iter(
[
tool
for tool in self.tools
if tool.name == self.force_use_tool.tool_name
]
),
None,
)
elif self.answer_style_config.quotes_config:
answer_handler = QuotesResponseHandler(
context_docs=search_result,
if not tool:
raise RuntimeError(f"Tool '{self.force_use_tool.tool_name}' not found")
tool_args = (
self.force_use_tool.args
if self.force_use_tool.args is not None
else tool.get_args_for_non_tool_calling_llm(
query=self.question,
history=self.message_history,
llm=self.llm,
force_run=True,
)
)
if tool_args is None:
raise RuntimeError(f"Tool '{tool.name}' did not return args")
chosen_tool_and_args = (tool, tool_args)
else:
tool_options = check_which_tools_should_run_for_non_tool_calling_llm(
tools=self.tools,
query=self.question,
history=self.message_history,
llm=self.llm,
)
available_tools_and_args = [
(self.tools[ind], args)
for ind, args in enumerate(tool_options)
if args is not None
]
logger.info(
f"Selecting single tool from tools: {[(tool.name, args) for tool, args in available_tools_and_args]}"
)
chosen_tool_and_args = (
select_single_tool_for_non_tool_calling_llm(
tools_and_args=available_tools_and_args,
history=self.message_history,
query=self.question,
llm=self.llm,
)
if available_tools_and_args
else None
)
logger.notice(f"Chosen tool: {chosen_tool_and_args}")
if not chosen_tool_and_args:
if self.skip_gen_ai_answer_generation:
raise ValueError(
"skip_gen_ai_answer_generation is True, but no tool was chosen; no answer will be generated"
)
prompt_builder.update_system_prompt(
default_build_system_message(self.prompt_config)
)
prompt_builder.update_user_prompt(
default_build_user_message(
self.question, self.prompt_config, self.latest_query_files
)
)
prompt = prompt_builder.build()
yield from self._process_llm_stream(
prompt=prompt,
tools=None,
)
return
tool, tool_args = chosen_tool_and_args
tool_runner = ToolRunner(tool, tool_args)
yield tool_runner.kickoff()
if tool.name in {SearchTool._NAME, InternetSearchTool._NAME}:
final_context_documents = None
for response in tool_runner.tool_responses():
if response.id == FINAL_CONTEXT_DOCUMENTS_ID:
final_context_documents = cast(list[LlmDoc], response.response)
yield response
if final_context_documents is None:
raise RuntimeError(
f"{tool.name} did not return final context documents"
)
self._update_prompt_builder_for_search_tool(
prompt_builder, final_context_documents
)
elif tool.name == ImageGenerationTool._NAME:
img_urls = []
for response in tool_runner.tool_responses():
if response.id == IMAGE_GENERATION_RESPONSE_ID:
img_generation_response = cast(
list[ImageGenerationResponse], response.response
)
img_urls = [img.url for img in img_generation_response]
yield response
prompt_builder.update_user_prompt(
build_image_generation_user_prompt(
query=self.question,
img_urls=img_urls,
)
)
else:
raise ValueError("No answer style config provided")
prompt_builder.update_user_prompt(
HumanMessage(
content=build_user_message_for_custom_tool_for_non_tool_calling_llm(
self.question,
tool.name,
*tool_runner.tool_responses(),
)
)
)
final = tool_runner.tool_final_result()
response_handler_manager = LLMResponseHandlerManager(
tool_call_handler, answer_handler, self.is_cancelled
)
yield final
if not self.skip_gen_ai_answer_generation:
prompt = prompt_builder.build()
# DEBUG: good breakpoint
stream = self.llm.stream(
prompt=current_llm_call.prompt_builder.build(),
tools=[tool.tool_definition() for tool in current_llm_call.tools] or None,
tool_choice=(
"required"
if current_llm_call.tools and current_llm_call.force_use_tool.force_use
else None
),
structured_response_format=self.answer_style_config.structured_response_format,
)
yield from response_handler_manager.handle_llm_response(stream)
new_llm_call = response_handler_manager.next_llm_call(current_llm_call)
if new_llm_call:
yield from self._get_response(llm_calls + [new_llm_call])
yield from self._process_llm_stream(prompt=prompt, tools=None)
@property
def processed_streamed_output(self) -> AnswerStream:
@@ -254,30 +498,94 @@ class Answer:
yield from self._processed_stream
return
prompt_builder = AnswerPromptBuilder(
user_message=default_build_user_message(
user_query=self.question,
prompt_config=self.prompt_config,
files=self.latest_query_files,
),
message_history=self.message_history,
llm_config=self.llm.config,
single_message_history=self.single_message_history,
)
prompt_builder.update_system_prompt(
default_build_system_message(self.prompt_config)
)
llm_call = LLMCall(
prompt_builder=prompt_builder,
tools=self._get_tools_list(),
force_use_tool=self.force_use_tool,
files=self.latest_query_files,
tool_call_info=[],
using_tool_calling_llm=self.using_tool_calling_llm,
output_generator = (
self._raw_output_for_explicit_tool_calling_llms()
if explicit_tool_calling_supported(
self.llm.config.model_provider, self.llm.config.model_name
)
and not self.skip_explicit_tool_calling
else self._raw_output_for_non_explicit_tool_calling_llms()
)
def _process_stream(
stream: Iterator[ToolCallKickoff | ToolResponse | str | StreamStopInfo],
) -> AnswerStream:
message = None
# special things we need to keep track of for the SearchTool
# raw results that will be displayed to the user
search_results: list[LlmDoc] | None = None
# processed docs to feed into the LLM
final_context_docs: list[LlmDoc] | None = None
for message in stream:
if isinstance(message, ToolCallKickoff) or isinstance(
message, ToolCallFinalResult
):
yield message
elif isinstance(message, ToolResponse):
if message.id == SEARCH_RESPONSE_SUMMARY_ID:
# We don't need to run section merging in this flow, this variable is only used
# below to specify the ordering of the documents for the purpose of matching
# citations to the right search documents. The deduplication logic is more lightweight
# there and we don't need to do it twice
search_results = [
llm_doc_from_inference_section(section)
for section in cast(
SearchResponseSummary, message.response
).top_sections
]
elif message.id == FINAL_CONTEXT_DOCUMENTS_ID:
final_context_docs = cast(list[LlmDoc], message.response)
yield message
elif (
message.id == SEARCH_DOC_CONTENT_ID
and not self._return_contexts
):
continue
yield message
else:
# assumes all tool responses will come first, then the final answer
break
if not self.skip_gen_ai_answer_generation:
process_answer_stream_fn = _get_answer_stream_processor(
context_docs=final_context_docs or [],
# if doc selection is enabled, then search_results will be None,
# so we need to use the final_context_docs
doc_id_to_rank_map=map_document_id_order(
search_results or final_context_docs or []
),
answer_style_configs=self.answer_style_config,
)
stream_stop_info = None
def _stream() -> Iterator[str]:
nonlocal stream_stop_info
for item in itertools.chain([message], stream):
if isinstance(item, StreamStopInfo):
stream_stop_info = item
return
# this should never happen, but we're seeing weird behavior here so handling for now
if not isinstance(item, str):
logger.error(
f"Received non-string item in answer stream: {item}. Skipping."
)
continue
yield item
yield from process_answer_stream_fn(_stream())
if stream_stop_info:
yield stream_stop_info
processed_stream = []
for processed_packet in self._get_response([llm_call]):
for processed_packet in _process_stream(output_generator):
processed_stream.append(processed_packet)
yield processed_packet
@@ -301,6 +609,7 @@ class Answer:
return citations
@property
def is_cancelled(self) -> bool:
if self._is_cancelled:
return True

View File

@@ -1,84 +0,0 @@
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Iterator
from typing import TYPE_CHECKING
from langchain_core.messages import BaseMessage
from pydantic.v1 import BaseModel as BaseModel__v1
from danswer.chat.models import CitationInfo
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import DanswerQuotes
from danswer.chat.models import StreamStopInfo
from danswer.chat.models import StreamStopReason
from danswer.file_store.models import InMemoryChatFile
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
from danswer.tools.force import ForceUseTool
from danswer.tools.models import ToolCallFinalResult
from danswer.tools.models import ToolCallKickoff
from danswer.tools.models import ToolResponse
from danswer.tools.tool import Tool
if TYPE_CHECKING:
from danswer.llm.answering.stream_processing.answer_response_handler import (
AnswerResponseHandler,
)
from danswer.llm.answering.tool.tool_response_handler import ToolResponseHandler
ResponsePart = (
DanswerAnswerPiece
| CitationInfo
| DanswerQuotes
| ToolCallKickoff
| ToolResponse
| ToolCallFinalResult
| StreamStopInfo
)
class LLMCall(BaseModel__v1):
prompt_builder: AnswerPromptBuilder
tools: list[Tool]
force_use_tool: ForceUseTool
files: list[InMemoryChatFile]
tool_call_info: list[ToolCallKickoff | ToolResponse | ToolCallFinalResult]
using_tool_calling_llm: bool
class Config:
arbitrary_types_allowed = True
class LLMResponseHandlerManager:
def __init__(
self,
tool_handler: "ToolResponseHandler",
answer_handler: "AnswerResponseHandler",
is_cancelled: Callable[[], bool],
):
self.tool_handler = tool_handler
self.answer_handler = answer_handler
self.is_cancelled = is_cancelled
def handle_llm_response(
self,
stream: Iterator[BaseMessage],
) -> Generator[ResponsePart, None, None]:
all_messages: list[BaseMessage] = []
for message in stream:
if self.is_cancelled():
yield StreamStopInfo(stop_reason=StreamStopReason.CANCELLED)
return
# tool handler doesn't do anything until the full message is received
# NOTE: still need to run list() to get this to run
list(self.tool_handler.handle_response_part(message, all_messages))
yield from self.answer_handler.handle_response_part(message, all_messages)
all_messages.append(message)
# potentially give back all info on the selected tool call + its result
yield from self.tool_handler.handle_response_part(None, all_messages)
yield from self.answer_handler.handle_response_part(None, all_messages)
def next_llm_call(self, llm_call: LLMCall) -> LLMCall | None:
return self.tool_handler.next_llm_call(llm_call)

View File

@@ -33,7 +33,7 @@ class PreviousMessage(BaseModel):
token_count: int
message_type: MessageType
files: list[InMemoryChatFile]
tool_call: ToolCallFinalResult | None
tool_calls: list[ToolCallFinalResult]
@classmethod
def from_chat_message(
@@ -51,13 +51,14 @@ class PreviousMessage(BaseModel):
for file in available_files
if str(file.file_id) in message_file_ids
],
tool_call=ToolCallFinalResult(
tool_name=chat_message.tool_call.tool_name,
tool_args=chat_message.tool_call.tool_arguments,
tool_result=chat_message.tool_call.tool_result,
)
if chat_message.tool_call
else None,
tool_calls=[
ToolCallFinalResult(
tool_name=tool_call.tool_name,
tool_args=tool_call.tool_arguments,
tool_result=tool_call.tool_result,
)
for tool_call in chat_message.tool_calls
],
)
def to_langchain_msg(self) -> BaseMessage:

View File

@@ -12,12 +12,12 @@ from danswer.llm.answering.prompts.citations_prompt import compute_max_llm_input
from danswer.llm.interfaces import LLMConfig
from danswer.llm.utils import build_content_with_imgs
from danswer.llm.utils import check_message_tokens
from danswer.llm.utils import message_to_prompt_and_imgs
from danswer.llm.utils import translate_history_to_basemessages
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
from danswer.prompts.prompt_utils import add_date_time_to_prompt
from danswer.prompts.prompt_utils import drop_messages_history_overflow
from danswer.tools.message import ToolCallSummary
def default_build_system_message(
@@ -54,14 +54,18 @@ def default_build_user_message(
class AnswerPromptBuilder:
def __init__(
self,
user_message: HumanMessage,
message_history: list[PreviousMessage],
llm_config: LLMConfig,
single_message_history: str | None = None,
self, message_history: list[PreviousMessage], llm_config: LLMConfig
) -> None:
self.max_tokens = compute_max_llm_input_tokens(llm_config)
(
self.message_history,
self.history_token_cnts,
) = translate_history_to_basemessages(message_history)
self.system_message_and_token_cnt: tuple[SystemMessage, int] | None = None
self.user_message_and_token_cnt: tuple[HumanMessage, int] | None = None
llm_tokenizer = get_tokenizer(
provider_type=llm_config.model_provider,
model_name=llm_config.model_name,
@@ -70,24 +74,6 @@ class AnswerPromptBuilder:
Callable[[str], list[int]], llm_tokenizer.encode
)
self.raw_message_history = message_history
(
self.message_history,
self.history_token_cnts,
) = translate_history_to_basemessages(message_history)
# for cases where like the QA flow where we want to condense the chat history
# into a single message rather than a sequence of User / Assistant messages
self.single_message_history = single_message_history
self.system_message_and_token_cnt: tuple[SystemMessage, int] | None = None
self.user_message_and_token_cnt = (
user_message,
check_message_tokens(user_message, self.llm_tokenizer_encode_func),
)
self.new_messages_and_token_cnts: list[tuple[BaseMessage, int]] = []
def update_system_prompt(self, system_message: SystemMessage | None) -> None:
if not system_message:
self.system_message_and_token_cnt = None
@@ -99,21 +85,18 @@ class AnswerPromptBuilder:
)
def update_user_prompt(self, user_message: HumanMessage) -> None:
if not user_message:
self.user_message_and_token_cnt = None
return
self.user_message_and_token_cnt = (
user_message,
check_message_tokens(user_message, self.llm_tokenizer_encode_func),
)
def append_message(self, message: BaseMessage) -> None:
"""Append a new message to the message history."""
token_count = check_message_tokens(message, self.llm_tokenizer_encode_func)
self.new_messages_and_token_cnts.append((message, token_count))
def get_user_message_content(self) -> str:
query, _ = message_to_prompt_and_imgs(self.user_message_and_token_cnt[0])
return query
def build(self) -> list[BaseMessage]:
def build(
self, tool_call_summary: ToolCallSummary | None = None
) -> list[BaseMessage]:
if not self.user_message_and_token_cnt:
raise ValueError("User message must be set before building prompt")
@@ -130,8 +113,25 @@ class AnswerPromptBuilder:
final_messages_with_tokens.append(self.user_message_and_token_cnt)
if self.new_messages_and_token_cnts:
final_messages_with_tokens.extend(self.new_messages_and_token_cnts)
if tool_call_summary:
final_messages_with_tokens.append(
(
tool_call_summary.tool_call_request,
check_message_tokens(
tool_call_summary.tool_call_request,
self.llm_tokenizer_encode_func,
),
)
)
final_messages_with_tokens.append(
(
tool_call_summary.tool_call_result,
check_message_tokens(
tool_call_summary.tool_call_result,
self.llm_tokenizer_encode_func,
),
)
)
return drop_messages_history_overflow(
final_messages_with_tokens, self.max_tokens

View File

@@ -6,6 +6,7 @@ from danswer.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MA
from danswer.db.models import Persona
from danswer.db.persona import get_default_prompt__read_only
from danswer.db.search_settings import get_multilingual_expansion
from danswer.file_store.utils import InMemoryChatFile
from danswer.llm.answering.models import PromptConfig
from danswer.llm.factory import get_llms_for_persona
from danswer.llm.factory import get_main_llm_from_tuple
@@ -13,7 +14,6 @@ from danswer.llm.interfaces import LLMConfig
from danswer.llm.utils import build_content_with_imgs
from danswer.llm.utils import check_number_of_tokens
from danswer.llm.utils import get_max_input_tokens
from danswer.llm.utils import message_to_prompt_and_imgs
from danswer.prompts.chat_prompts import REQUIRE_CITATION_STATEMENT
from danswer.prompts.constants import DEFAULT_IGNORE_STATEMENT
from danswer.prompts.direct_qa_prompts import CITATIONS_PROMPT
@@ -132,9 +132,10 @@ def build_citations_system_message(
def build_citations_user_message(
message: HumanMessage,
question: str,
prompt_config: PromptConfig,
context_docs: list[LlmDoc] | list[InferenceChunk],
files: list[InMemoryChatFile],
all_doc_useful: bool,
history_message: str = "",
) -> HumanMessage:
@@ -148,7 +149,6 @@ def build_citations_user_message(
if history_message
else ""
)
query, img_urls = message_to_prompt_and_imgs(message)
if context_docs:
context_docs_str = build_complete_context_str(context_docs)
@@ -158,22 +158,20 @@ def build_citations_user_message(
optional_ignore_statement=optional_ignore,
context_docs_str=context_docs_str,
task_prompt=task_prompt_with_reminder,
user_query=query,
user_query=question,
history_block=history_block,
)
else:
# if no context docs provided, assume we're in the tool calling flow
user_prompt = CITATIONS_PROMPT_FOR_TOOL_CALLING.format(
task_prompt=task_prompt_with_reminder,
user_query=query,
user_query=question,
history_block=history_block,
)
user_prompt = user_prompt.strip()
user_msg = HumanMessage(
content=build_content_with_imgs(user_prompt, img_urls=img_urls)
if img_urls
else user_prompt
content=build_content_with_imgs(user_prompt, files) if files else user_prompt
)
return user_msg

View File

@@ -5,7 +5,6 @@ from danswer.configs.chat_configs import LANGUAGE_HINT
from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE
from danswer.db.search_settings import get_multilingual_expansion
from danswer.llm.answering.models import PromptConfig
from danswer.llm.utils import message_to_prompt_and_imgs
from danswer.prompts.direct_qa_prompts import CONTEXT_BLOCK
from danswer.prompts.direct_qa_prompts import HISTORY_BLOCK
from danswer.prompts.direct_qa_prompts import JSON_PROMPT
@@ -76,7 +75,7 @@ def _build_strong_llm_quotes_prompt(
def build_quotes_user_message(
message: HumanMessage,
question: str,
context_docs: list[LlmDoc] | list[InferenceChunk],
history_str: str,
prompt: PromptConfig,
@@ -87,10 +86,28 @@ def build_quotes_user_message(
else _build_strong_llm_quotes_prompt
)
query, _ = message_to_prompt_and_imgs(message)
return prompt_builder(
question=query,
question=question,
context_docs=context_docs,
history_str=history_str,
prompt=prompt,
)
def build_quotes_prompt(
question: str,
context_docs: list[LlmDoc] | list[InferenceChunk],
history_str: str,
prompt: PromptConfig,
) -> HumanMessage:
prompt_builder = (
_build_weak_llm_quotes_prompt
if QA_PROMPT_OVERRIDE == "weak"
else _build_strong_llm_quotes_prompt
)
return prompt_builder(
question=question,
context_docs=context_docs,
history_str=history_str,
prompt=prompt,

View File

@@ -19,7 +19,7 @@ from danswer.natural_language_processing.utils import tokenizer_trim_content
from danswer.prompts.prompt_utils import build_doc_context_str
from danswer.search.models import InferenceChunk
from danswer.search.models import InferenceSection
from danswer.tools.tool_implementations.search.search_utils import section_to_dict
from danswer.tools.search.search_utils import section_to_dict
from danswer.utils.logger import setup_logger

View File

@@ -1,91 +0,0 @@
import abc
from collections.abc import Generator
from langchain_core.messages import BaseMessage
from danswer.chat.models import CitationInfo
from danswer.chat.models import LlmDoc
from danswer.llm.answering.llm_response_handler import ResponsePart
from danswer.llm.answering.stream_processing.citation_processing import (
CitationProcessor,
)
from danswer.llm.answering.stream_processing.quotes_processing import (
QuotesProcessor,
)
from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping
class AnswerResponseHandler(abc.ABC):
@abc.abstractmethod
def handle_response_part(
self,
response_item: BaseMessage | None,
previous_response_items: list[BaseMessage],
) -> Generator[ResponsePart, None, None]:
raise NotImplementedError
class DummyAnswerResponseHandler(AnswerResponseHandler):
def handle_response_part(
self,
response_item: BaseMessage | None,
previous_response_items: list[BaseMessage],
) -> Generator[ResponsePart, None, None]:
# This is a dummy handler that returns nothing
yield from []
class CitationResponseHandler(AnswerResponseHandler):
def __init__(
self, context_docs: list[LlmDoc], doc_id_to_rank_map: DocumentIdOrderMapping
):
self.context_docs = context_docs
self.doc_id_to_rank_map = doc_id_to_rank_map
self.citation_processor = CitationProcessor(
context_docs=self.context_docs,
doc_id_to_rank_map=self.doc_id_to_rank_map,
)
self.processed_text = ""
self.citations: list[CitationInfo] = []
def handle_response_part(
self,
response_item: BaseMessage | None,
previous_response_items: list[BaseMessage],
) -> Generator[ResponsePart, None, None]:
if response_item is None:
return
content = (
response_item.content if isinstance(response_item.content, str) else ""
)
# Process the new content through the citation processor
yield from self.citation_processor.process_token(content)
class QuotesResponseHandler(AnswerResponseHandler):
def __init__(
self,
context_docs: list[LlmDoc],
is_json_prompt: bool = True,
):
self.quotes_processor = QuotesProcessor(
context_docs=context_docs,
is_json_prompt=is_json_prompt,
)
def handle_response_part(
self,
response_item: BaseMessage | None,
previous_response_items: list[BaseMessage],
) -> Generator[ResponsePart, None, None]:
if response_item is None:
yield from self.quotes_processor.process_token(None)
return
content = (
response_item.content if isinstance(response_item.content, str) else ""
)
yield from self.quotes_processor.process_token(content)

View File

@@ -1,10 +1,12 @@
import re
from collections.abc import Generator
from collections.abc import Iterator
from danswer.chat.models import AnswerQuestionStreamReturn
from danswer.chat.models import CitationInfo
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import LlmDoc
from danswer.configs.chat_configs import STOP_STREAM_PAT
from danswer.llm.answering.models import StreamProcessor
from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping
from danswer.prompts.constants import TRIPLE_BACKTICK
from danswer.utils.logger import setup_logger
@@ -17,104 +19,128 @@ def in_code_block(llm_text: str) -> bool:
return count % 2 != 0
class CitationProcessor:
def __init__(
self,
context_docs: list[LlmDoc],
doc_id_to_rank_map: DocumentIdOrderMapping,
stop_stream: str | None = STOP_STREAM_PAT,
):
self.context_docs = context_docs
self.doc_id_to_rank_map = doc_id_to_rank_map
self.stop_stream = stop_stream
self.order_mapping = doc_id_to_rank_map.order_mapping
self.llm_out = ""
self.max_citation_num = len(context_docs)
self.citation_order: list[int] = []
self.curr_segment = ""
self.cited_inds: set[int] = set()
self.hold = ""
self.current_citations: list[int] = []
self.past_cite_count = 0
def extract_citations_from_stream(
tokens: Iterator[str],
context_docs: list[LlmDoc],
doc_id_to_rank_map: DocumentIdOrderMapping,
stop_stream: str | None = STOP_STREAM_PAT,
) -> Iterator[DanswerAnswerPiece | CitationInfo]:
"""
Key aspects:
def process_token(
self, token: str | None
) -> Generator[DanswerAnswerPiece | CitationInfo, None, None]:
# None -> end of stream
if token is None:
yield DanswerAnswerPiece(answer_piece=self.curr_segment)
return
1. Stream Processing:
- Processes tokens one by one, allowing for real-time handling of large texts.
if self.stop_stream:
next_hold = self.hold + token
if self.stop_stream in next_hold:
return
if next_hold == self.stop_stream[: len(next_hold)]:
self.hold = next_hold
return
2. Citation Detection:
- Uses regex to find citations in the format [number].
- Example: [1], [2], etc.
3. Citation Mapping:
- Maps detected citation numbers to actual document ranks using doc_id_to_rank_map.
- Example: [1] might become [3] if doc_id_to_rank_map maps it to 3.
4. Citation Formatting:
- Replaces citations with properly formatted versions.
- Adds links if available: [[1]](https://example.com)
- Handles cases where links are not available: [[1]]()
5. Duplicate Handling:
- Skips consecutive citations of the same document to avoid redundancy.
6. Output Generation:
- Yields DanswerAnswerPiece objects for regular text.
- Yields CitationInfo objects for each unique citation encountered.
7. Context Awareness:
- Uses context_docs to access document information for citations.
This function effectively processes a stream of text, identifies and reformats citations,
and provides both the processed text and citation information as output.
"""
order_mapping = doc_id_to_rank_map.order_mapping
llm_out = ""
max_citation_num = len(context_docs)
citation_order = []
curr_segment = ""
cited_inds = set()
hold = ""
raw_out = ""
current_citations: list[int] = []
past_cite_count = 0
for raw_token in tokens:
raw_out += raw_token
if stop_stream:
next_hold = hold + raw_token
if stop_stream in next_hold:
break
if next_hold == stop_stream[: len(next_hold)]:
hold = next_hold
continue
token = next_hold
self.hold = ""
hold = ""
else:
token = raw_token
self.curr_segment += token
self.llm_out += token
curr_segment += token
llm_out += token
# Handle code blocks without language tags
if "`" in self.curr_segment:
if self.curr_segment.endswith("`"):
return
elif "```" in self.curr_segment:
piece_that_comes_after = self.curr_segment.split("```")[1][0]
if piece_that_comes_after == "\n" and in_code_block(self.llm_out):
self.curr_segment = self.curr_segment.replace("```", "```plaintext")
if "`" in curr_segment:
if curr_segment.endswith("`"):
continue
elif "```" in curr_segment:
piece_that_comes_after = curr_segment.split("```")[1][0]
if piece_that_comes_after == "\n" and in_code_block(llm_out):
curr_segment = curr_segment.replace("```", "```plaintext")
citation_pattern = r"\[(\d+)\]"
citations_found = list(re.finditer(citation_pattern, self.curr_segment))
citations_found = list(re.finditer(citation_pattern, curr_segment))
possible_citation_pattern = r"(\[\d*$)" # [1, [, etc
possible_citation_found = re.search(
possible_citation_pattern, self.curr_segment
)
possible_citation_found = re.search(possible_citation_pattern, curr_segment)
if len(citations_found) == 0 and len(self.llm_out) - self.past_cite_count > 5:
self.current_citations = []
# `past_cite_count`: number of characters since past citation
# 5 to ensure a citation hasn't occured
if len(citations_found) == 0 and len(llm_out) - past_cite_count > 5:
current_citations = []
result = "" # Initialize result here
if citations_found and not in_code_block(self.llm_out):
if citations_found and not in_code_block(llm_out):
last_citation_end = 0
length_to_add = 0
while len(citations_found) > 0:
citation = citations_found.pop(0)
numerical_value = int(citation.group(1))
if 1 <= numerical_value <= self.max_citation_num:
context_llm_doc = self.context_docs[numerical_value - 1]
real_citation_num = self.order_mapping[context_llm_doc.document_id]
if 1 <= numerical_value <= max_citation_num:
context_llm_doc = context_docs[numerical_value - 1]
real_citation_num = order_mapping[context_llm_doc.document_id]
if real_citation_num not in self.citation_order:
self.citation_order.append(real_citation_num)
if real_citation_num not in citation_order:
citation_order.append(real_citation_num)
target_citation_num = (
self.citation_order.index(real_citation_num) + 1
)
target_citation_num = citation_order.index(real_citation_num) + 1
# Skip consecutive citations of the same work
if target_citation_num in self.current_citations:
if target_citation_num in current_citations:
start, end = citation.span()
real_start = length_to_add + start
diff = end - start
self.curr_segment = (
self.curr_segment[: length_to_add + start]
+ self.curr_segment[real_start + diff :]
curr_segment = (
curr_segment[: length_to_add + start]
+ curr_segment[real_start + diff :]
)
length_to_add -= diff
continue
# Handle edge case where LLM outputs citation itself
if self.curr_segment.startswith("[["):
match = re.match(r"\[\[(\d+)\]\]", self.curr_segment)
# by allowing it to generate citations on its own.
if curr_segment.startswith("[["):
match = re.match(r"\[\[(\d+)\]\]", curr_segment)
if match:
try:
doc_id = int(match.group(1))
context_llm_doc = self.context_docs[doc_id - 1]
context_llm_doc = context_docs[doc_id - 1]
yield CitationInfo(
citation_num=target_citation_num,
document_id=context_llm_doc.document_id,
@@ -124,57 +150,75 @@ class CitationProcessor:
f"Manual LLM citation didn't properly cite documents {e}"
)
else:
# Will continue attempt on next loops
logger.warning(
"Manual LLM citation wasn't able to close brackets"
)
continue
link = context_llm_doc.link
# Replace the citation in the current segment
start, end = citation.span()
self.curr_segment = (
self.curr_segment[: start + length_to_add]
curr_segment = (
curr_segment[: start + length_to_add]
+ f"[{target_citation_num}]"
+ self.curr_segment[end + length_to_add :]
+ curr_segment[end + length_to_add :]
)
self.past_cite_count = len(self.llm_out)
self.current_citations.append(target_citation_num)
past_cite_count = len(llm_out)
current_citations.append(target_citation_num)
if target_citation_num not in self.cited_inds:
self.cited_inds.add(target_citation_num)
if target_citation_num not in cited_inds:
cited_inds.add(target_citation_num)
yield CitationInfo(
citation_num=target_citation_num,
document_id=context_llm_doc.document_id,
)
if link:
prev_length = len(self.curr_segment)
self.curr_segment = (
self.curr_segment[: start + length_to_add]
prev_length = len(curr_segment)
curr_segment = (
curr_segment[: start + length_to_add]
+ f"[[{target_citation_num}]]({link})"
+ self.curr_segment[end + length_to_add :]
+ curr_segment[end + length_to_add :]
)
length_to_add += len(self.curr_segment) - prev_length
length_to_add += len(curr_segment) - prev_length
else:
prev_length = len(self.curr_segment)
self.curr_segment = (
self.curr_segment[: start + length_to_add]
prev_length = len(curr_segment)
curr_segment = (
curr_segment[: start + length_to_add]
+ f"[[{target_citation_num}]]()"
+ self.curr_segment[end + length_to_add :]
+ curr_segment[end + length_to_add :]
)
length_to_add += len(self.curr_segment) - prev_length
length_to_add += len(curr_segment) - prev_length
last_citation_end = end + length_to_add
if last_citation_end > 0:
result += self.curr_segment[:last_citation_end]
self.curr_segment = self.curr_segment[last_citation_end:]
yield DanswerAnswerPiece(answer_piece=curr_segment[:last_citation_end])
curr_segment = curr_segment[last_citation_end:]
if possible_citation_found:
continue
yield DanswerAnswerPiece(answer_piece=curr_segment)
curr_segment = ""
if not possible_citation_found:
result += self.curr_segment
self.curr_segment = ""
if curr_segment:
yield DanswerAnswerPiece(answer_piece=curr_segment)
if result:
yield DanswerAnswerPiece(answer_piece=result)
def build_citation_processor(
context_docs: list[LlmDoc], doc_id_to_rank_map: DocumentIdOrderMapping
) -> StreamProcessor:
def stream_processor(
tokens: Iterator[str],
) -> AnswerQuestionStreamReturn:
yield from extract_citations_from_stream(
tokens=tokens,
context_docs=context_docs,
doc_id_to_rank_map=doc_id_to_rank_map,
)
return stream_processor

View File

@@ -1,11 +1,14 @@
import math
import re
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Iterator
from json import JSONDecodeError
from typing import Optional
import regex
from danswer.chat.models import AnswerQuestionStreamReturn
from danswer.chat.models import DanswerAnswer
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import DanswerQuote
@@ -154,7 +157,7 @@ def separate_answer_quotes(
return _extract_answer_quotes_freeform(clean_up_code_blocks(answer_raw))
def _process_answer(
def process_answer(
answer_raw: str,
docs: list[LlmDoc],
is_json_prompt: bool = True,
@@ -192,7 +195,7 @@ def _stream_json_answer_end(answer_so_far: str, next_token: str) -> bool:
def _extract_quotes_from_completed_token_stream(
model_output: str, context_docs: list[LlmDoc], is_json_prompt: bool = True
) -> DanswerQuotes:
answer, quotes = _process_answer(model_output, context_docs, is_json_prompt)
answer, quotes = process_answer(model_output, context_docs, is_json_prompt)
if answer:
logger.notice(answer)
elif model_output:
@@ -201,101 +204,94 @@ def _extract_quotes_from_completed_token_stream(
return quotes
class QuotesProcessor:
def __init__(
self,
context_docs: list[LlmDoc],
is_json_prompt: bool = True,
):
self.context_docs = context_docs
self.is_json_prompt = is_json_prompt
def process_model_tokens(
tokens: Iterator[str],
context_docs: list[LlmDoc],
is_json_prompt: bool = True,
) -> Generator[DanswerAnswerPiece | DanswerQuotes, None, None]:
"""Used in the streaming case to process the model output
into an Answer and Quotes
self.found_answer_start = False if is_json_prompt else True
self.found_answer_end = False
self.hold_quote = ""
self.model_output = ""
self.hold = ""
Yields Answer tokens back out in a dict for streaming to frontend
When Answer section ends, yields dict with answer_finished key
Collects all the tokens at the end to form the complete model output"""
quote_pat = f"\n{QUOTE_PAT}"
# Sometimes worse model outputs new line instead of :
quote_loose = f"\n{quote_pat[:-1]}\n"
# Sometime model outputs two newlines before quote section
quote_pat_full = f"\n{quote_pat}"
model_output: str = ""
found_answer_start = False if is_json_prompt else True
found_answer_end = False
hold_quote = ""
def process_token(
self, token: str | None
) -> Generator[DanswerAnswerPiece | DanswerQuotes, None, None]:
# None -> end of stream
if token is None:
if self.model_output:
yield _extract_quotes_from_completed_token_stream(
model_output=self.model_output,
context_docs=self.context_docs,
is_json_prompt=self.is_json_prompt,
)
return
for token in tokens:
model_previous = model_output
model_output += token
model_previous = self.model_output
self.model_output += token
if not self.found_answer_start:
m = answer_pattern.search(self.model_output)
if not found_answer_start:
m = answer_pattern.search(model_output)
if m:
self.found_answer_start = True
found_answer_start = True
# Prevent heavy cases of hallucinations
if self.is_json_prompt and len(self.model_output) > 70:
# Prevent heavy cases of hallucinations where model is never providing a JSON
# We want to quickly update the user - not stream forever
if is_json_prompt and len(model_output) > 70:
logger.warning("LLM did not produce json as prompted")
self.found_answer_end = True
return
found_answer_end = True
continue
remaining = self.model_output[m.end() :]
# Look for an unescaped quote, which means the answer is entirely contained
# in this token e.g. if the token is `{"answer": "blah", "qu`
quote_indices = [i for i, char in enumerate(remaining) if char == '"']
for quote_idx in quote_indices:
# Check if quote is escaped by counting backslashes before it
num_backslashes = 0
pos = quote_idx - 1
while pos >= 0 and remaining[pos] == "\\":
num_backslashes += 1
pos -= 1
# If even number of backslashes, quote is not escaped
if num_backslashes % 2 == 0:
yield DanswerAnswerPiece(answer_piece=remaining[:quote_idx])
return
# If no unescaped quote found, yield the remaining string
remaining = model_output[m.end() :]
if len(remaining) > 0:
yield DanswerAnswerPiece(answer_piece=remaining)
return
continue
if self.found_answer_start and not self.found_answer_end:
if self.is_json_prompt and _stream_json_answer_end(model_previous, token):
self.found_answer_end = True
if found_answer_start and not found_answer_end:
if is_json_prompt and _stream_json_answer_end(model_previous, token):
found_answer_end = True
# return the remaining part of the answer e.g. token might be 'd.", ' and we should yield 'd.'
if token:
try:
answer_token_section = token.index('"')
yield DanswerAnswerPiece(
answer_piece=self.hold_quote + token[:answer_token_section]
answer_piece=hold_quote + token[:answer_token_section]
)
except ValueError:
logger.error("Quotation mark not found in token")
yield DanswerAnswerPiece(answer_piece=self.hold_quote + token)
yield DanswerAnswerPiece(answer_piece=hold_quote + token)
yield DanswerAnswerPiece(answer_piece=None)
return
elif not self.is_json_prompt:
quote_pat = f"\n{QUOTE_PAT}"
quote_loose = f"\n{quote_pat[:-1]}\n"
quote_pat_full = f"\n{quote_pat}"
if (
quote_pat in self.hold_quote + token
or quote_loose in self.hold_quote + token
):
self.found_answer_end = True
continue
elif not is_json_prompt:
if quote_pat in hold_quote + token or quote_loose in hold_quote + token:
found_answer_end = True
yield DanswerAnswerPiece(answer_piece=None)
return
if self.hold_quote + token in quote_pat_full:
self.hold_quote += token
return
continue
if hold_quote + token in quote_pat_full:
hold_quote += token
continue
yield DanswerAnswerPiece(answer_piece=hold_quote + token)
hold_quote = ""
yield DanswerAnswerPiece(answer_piece=self.hold_quote + token)
self.hold_quote = ""
logger.debug(f"Raw Model QnA Output: {model_output}")
yield _extract_quotes_from_completed_token_stream(
model_output=model_output,
context_docs=context_docs,
is_json_prompt=is_json_prompt,
)
def build_quotes_processor(
context_docs: list[LlmDoc], is_json_prompt: bool
) -> Callable[[Iterator[str]], AnswerQuestionStreamReturn]:
def stream_processor(
tokens: Iterator[str],
) -> AnswerQuestionStreamReturn:
yield from process_model_tokens(
tokens=tokens,
context_docs=context_docs,
is_json_prompt=is_json_prompt,
)
return stream_processor

View File

@@ -1,207 +0,0 @@
from collections.abc import Generator
from langchain_core.messages import AIMessageChunk
from langchain_core.messages import BaseMessage
from langchain_core.messages import ToolCall
from danswer.llm.answering.llm_response_handler import LLMCall
from danswer.llm.answering.llm_response_handler import ResponsePart
from danswer.llm.interfaces import LLM
from danswer.tools.force import ForceUseTool
from danswer.tools.message import build_tool_message
from danswer.tools.message import ToolCallSummary
from danswer.tools.models import ToolCallFinalResult
from danswer.tools.models import ToolCallKickoff
from danswer.tools.models import ToolResponse
from danswer.tools.tool import Tool
from danswer.tools.tool_runner import (
check_which_tools_should_run_for_non_tool_calling_llm,
)
from danswer.tools.tool_runner import ToolRunner
from danswer.tools.tool_selection import select_single_tool_for_non_tool_calling_llm
from danswer.utils.logger import setup_logger
logger = setup_logger()
class ToolResponseHandler:
def __init__(self, tools: list[Tool]):
self.tools = tools
self.tool_call_chunk: AIMessageChunk | None = None
self.tool_call_requests: list[ToolCall] = []
self.tool_runner: ToolRunner | None = None
self.tool_call_summary: ToolCallSummary | None = None
self.tool_kickoff: ToolCallKickoff | None = None
self.tool_responses: list[ToolResponse] = []
self.tool_final_result: ToolCallFinalResult | None = None
@classmethod
def get_tool_call_for_non_tool_calling_llm(
cls, llm_call: LLMCall, llm: LLM
) -> tuple[Tool, dict] | None:
if llm_call.force_use_tool.force_use:
# if we are forcing a tool, we don't need to check which tools to run
tool = next(
(
t
for t in llm_call.tools
if t.name == llm_call.force_use_tool.tool_name
),
None,
)
if not tool:
raise RuntimeError(
f"Tool '{llm_call.force_use_tool.tool_name}' not found"
)
tool_args = (
llm_call.force_use_tool.args
if llm_call.force_use_tool.args is not None
else tool.get_args_for_non_tool_calling_llm(
query=llm_call.prompt_builder.get_user_message_content(),
history=llm_call.prompt_builder.raw_message_history,
llm=llm,
force_run=True,
)
)
if tool_args is None:
raise RuntimeError(f"Tool '{tool.name}' did not return args")
return (tool, tool_args)
else:
tool_options = check_which_tools_should_run_for_non_tool_calling_llm(
tools=llm_call.tools,
query=llm_call.prompt_builder.get_user_message_content(),
history=llm_call.prompt_builder.raw_message_history,
llm=llm,
)
available_tools_and_args = [
(llm_call.tools[ind], args)
for ind, args in enumerate(tool_options)
if args is not None
]
logger.info(
f"Selecting single tool from tools: {[(tool.name, args) for tool, args in available_tools_and_args]}"
)
chosen_tool_and_args = (
select_single_tool_for_non_tool_calling_llm(
tools_and_args=available_tools_and_args,
history=llm_call.prompt_builder.raw_message_history,
query=llm_call.prompt_builder.get_user_message_content(),
llm=llm,
)
if available_tools_and_args
else None
)
logger.notice(f"Chosen tool: {chosen_tool_and_args}")
return chosen_tool_and_args
def _handle_tool_call(self) -> Generator[ResponsePart, None, None]:
if not self.tool_call_chunk or not self.tool_call_chunk.tool_calls:
return
self.tool_call_requests = self.tool_call_chunk.tool_calls
selected_tool: Tool | None = None
selected_tool_call_request: ToolCall | None = None
for tool_call_request in self.tool_call_requests:
known_tools_by_name = [
tool for tool in self.tools if tool.name == tool_call_request["name"]
]
if not known_tools_by_name:
logger.error(
"Tool call requested with unknown name field. \n"
f"self.tools: {self.tools}"
f"tool_call_request: {tool_call_request}"
)
continue
else:
selected_tool = known_tools_by_name[0]
selected_tool_call_request = tool_call_request
if selected_tool and selected_tool_call_request:
break
if not selected_tool or not selected_tool_call_request:
return
logger.info(f"Selected tool: {selected_tool.name}")
logger.debug(f"Selected tool call request: {selected_tool_call_request}")
self.tool_runner = ToolRunner(selected_tool, selected_tool_call_request["args"])
self.tool_kickoff = self.tool_runner.kickoff()
yield self.tool_kickoff
for response in self.tool_runner.tool_responses():
self.tool_responses.append(response)
yield response
self.tool_final_result = self.tool_runner.tool_final_result()
yield self.tool_final_result
self.tool_call_summary = ToolCallSummary(
tool_call_request=self.tool_call_chunk,
tool_call_result=build_tool_message(
selected_tool_call_request, self.tool_runner.tool_message_content()
),
)
def handle_response_part(
self,
response_item: BaseMessage | None,
previous_response_items: list[BaseMessage],
) -> Generator[ResponsePart, None, None]:
if response_item is None:
yield from self._handle_tool_call()
if isinstance(response_item, AIMessageChunk) and (
response_item.tool_call_chunks or response_item.tool_calls
):
if self.tool_call_chunk is None:
self.tool_call_chunk = response_item
else:
self.tool_call_chunk += response_item # type: ignore
return
def next_llm_call(self, current_llm_call: LLMCall) -> LLMCall | None:
if (
self.tool_runner is None
or self.tool_call_summary is None
or self.tool_kickoff is None
or self.tool_final_result is None
):
return None
tool_runner = self.tool_runner
new_prompt_builder = tool_runner.tool.build_next_prompt(
prompt_builder=current_llm_call.prompt_builder,
tool_call_summary=self.tool_call_summary,
tool_responses=self.tool_responses,
using_tool_calling_llm=current_llm_call.using_tool_calling_llm,
)
return LLMCall(
prompt_builder=new_prompt_builder,
tools=[], # for now, only allow one tool call per response
force_use_tool=ForceUseTool(
force_use=False,
tool_name="",
args=None,
),
files=current_llm_call.files,
using_tool_calling_llm=current_llm_call.using_tool_calling_llm,
tool_call_info=[
self.tool_kickoff,
*self.tool_responses,
self.tool_final_result,
],
)

View File

@@ -83,10 +83,8 @@ def _convert_litellm_message_to_langchain_message(
"args": json.loads(tool_call.function.arguments),
"id": tool_call.id,
}
for tool_call in tool_calls
]
if tool_calls
else [],
for tool_call in (tool_calls if tool_calls else [])
],
)
elif role == "system":
return SystemMessage(content=content)

Some files were not shown because too many files have changed in this diff Show More