Compare commits

..

22 Commits

Author SHA1 Message Date
pablodanswer
8f67f1715c minor typing 2024-10-20 14:48:19 -07:00
pablodanswer
3b365509e2 k 2024-10-20 14:41:12 -07:00
pablodanswer
022cbdfccf robustified cloud auth type 2024-10-20 14:28:22 -07:00
pablodanswer
ebec6f6b10 k 2024-10-20 13:43:08 -07:00
pablodanswer
1cad9c7b3d add cloud auth type 2024-10-20 13:43:08 -07:00
pablodanswer
b4e975013c k 2024-10-20 13:42:38 -07:00
pablodanswer
dd26f92206 nit 2024-10-20 13:41:41 -07:00
pablodanswer
4d00ec45ad remove comments + notice logs 2024-10-20 13:34:13 -07:00
pablodanswer
1a81c67a67 k 2024-10-20 13:22:00 -07:00
pablodanswer
04f965e656 k 2024-10-20 11:52:24 -07:00
pablodanswer
277d37e0ee fix 2024-10-20 11:45:00 -07:00
pablodanswer
3cd260131b k 2024-10-20 10:16:19 -07:00
pablodanswer
ad21ee0e9a fix mysterious syncing issue! 2024-10-19 19:26:57 -07:00
pablodanswer
c7dc0e9af0 k 2024-10-19 19:15:55 -07:00
pablodanswer
75c5de802b ensure tenant id passed 2024-10-19 19:15:55 -07:00
pablodanswer
c39f590d0d k 2024-10-19 19:15:55 -07:00
pablodanswer
82a9fda846 add types 2024-10-19 19:15:55 -07:00
pablodanswer
842d4ab2a8 k 2024-10-19 19:15:55 -07:00
pablodanswer
cddcec4ea4 k 2024-10-19 19:15:55 -07:00
pablodanswer
09dd7b424c validated workaround for flush + reset 2024-10-19 19:15:55 -07:00
pablodanswer
a2fd8d5e0a add some more multi tenancy 2024-10-19 19:15:55 -07:00
pablodanswer
802dc00f78 k 2024-10-19 19:15:55 -07:00
931 changed files with 19977 additions and 53147 deletions

View File

@@ -6,24 +6,20 @@
[Describe the tests you ran to verify your changes]
## Accepted Risk (provide if relevant)
N/A
## Accepted Risk
[Any know risks or failure modes to point out to reviewers]
## Related Issue(s) (provide if relevant)
N/A
## Related Issue(s)
[If applicable, link to the issue(s) this PR addresses]
## Mental Checklist:
- All of the automated tests pass
- All PR comments are addressed and marked resolved
- If there are migrations, they have been rebased to latest main
- If there are new dependencies, they are added to the requirements
- If there are new environment variables, they are added to all of the deployment methods
- If there are new APIs that don't require auth, they are added to PUBLIC_ENDPOINT_SPECS
- Docker images build and basic functionalities work
- Author has done a final read through of the PR right before merge
## Backporting (check the box to trigger backport action)
Note: You have to check that the action passes, otherwise resolve the conflicts manually and tag the patches.
- [ ] This PR should be backported (make sure to check that the backport attempt succeeds)
## Checklist:
- [ ] All of the automated tests pass
- [ ] All PR comments are addressed and marked resolved
- [ ] If there are migrations, they have been rebased to latest main
- [ ] If there are new dependencies, they are added to the requirements
- [ ] If there are new environment variables, they are added to all of the deployment methods
- [ ] If there are new APIs that don't require auth, they are added to PUBLIC_ENDPOINT_SPECS
- [ ] Docker images build and basic functionalities work
- [ ] Author has done a final read through of the PR right before merge

View File

@@ -3,61 +3,61 @@ name: Build and Push Backend Image on Tag
on:
push:
tags:
- "*"
- '*'
env:
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'danswer/danswer-backend-cloud' || 'danswer/danswer-backend' }}
REGISTRY_IMAGE: danswer/danswer-backend
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
jobs:
build-and-push:
# TODO: investigate a matrix build like the web container
# TODO: investigate a matrix build like the web container
# See https://runs-on.com/runners/linux/
runs-on: [runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}"]
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Install build-essential
run: |
sudo apt-get update
sudo apt-get install -y build-essential
- name: Install build-essential
run: |
sudo apt-get update
sudo apt-get install -y build-essential
- name: Backend Image Docker Build and Push
uses: docker/build-push-action@v5
with:
context: ./backend
file: ./backend/Dockerfile
platforms: linux/amd64,linux/arm64
push: true
tags: |
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
build-args: |
DANSWER_VERSION=${{ github.ref_name }}
- name: Backend Image Docker Build and Push
uses: docker/build-push-action@v5
with:
context: ./backend
file: ./backend/Dockerfile
platforms: linux/amd64,linux/arm64
push: true
tags: |
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
build-args: |
DANSWER_VERSION=${{ github.ref_name }}
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
env:
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
with:
# To run locally: trivy image --severity HIGH,CRITICAL danswer/danswer-backend
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
severity: "CRITICAL,HIGH"
trivyignores: ./backend/.trivyignore
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
# To run locally: trivy image --severity HIGH,CRITICAL danswer/danswer-backend
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
severity: 'CRITICAL,HIGH'
trivyignores: ./backend/.trivyignore

View File

@@ -1,137 +0,0 @@
name: Build and Push Cloud Web Image on Tag
# Identical to the web container build, but with correct image tag and build args
on:
push:
tags:
- "*"
env:
REGISTRY_IMAGE: danswer/danswer-web-server-cloud
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
jobs:
build:
runs-on:
- runs-on
- runner=${{ matrix.platform == 'linux/amd64' && '8cpu-linux-x64' || '8cpu-linux-arm64' }}
- run-id=${{ github.run_id }}
- tag=platform-${{ matrix.platform }}
strategy:
fail-fast: false
matrix:
platform:
- linux/amd64
- linux/arm64
steps:
- name: Prepare
run: |
platform=${{ matrix.platform }}
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
- name: Checkout
uses: actions/checkout@v4
- name: Docker meta
id: meta
uses: docker/metadata-action@v5
with:
images: ${{ env.REGISTRY_IMAGE }}
tags: |
type=raw,value=${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
type=raw,value=${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Build and push by digest
id: build
uses: docker/build-push-action@v5
with:
context: ./web
file: ./web/Dockerfile
platforms: ${{ matrix.platform }}
push: true
build-args: |
DANSWER_VERSION=${{ github.ref_name }}
NEXT_PUBLIC_CLOUD_ENABLED=true
NEXT_PUBLIC_POSTHOG_KEY=${{ secrets.POSTHOG_KEY }}
NEXT_PUBLIC_POSTHOG_HOST=${{ secrets.POSTHOG_HOST }}
NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
NEXT_PUBLIC_GTM_ENABLED=true
# needed due to weird interactions with the builds for different platforms
no-cache: true
labels: ${{ steps.meta.outputs.labels }}
outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
- name: Export digest
run: |
mkdir -p /tmp/digests
digest="${{ steps.build.outputs.digest }}"
touch "/tmp/digests/${digest#sha256:}"
- name: Upload digest
uses: actions/upload-artifact@v4
with:
name: digests-${{ env.PLATFORM_PAIR }}
path: /tmp/digests/*
if-no-files-found: error
retention-days: 1
merge:
runs-on: ubuntu-latest
needs:
- build
steps:
- name: Download digests
uses: actions/download-artifact@v4
with:
path: /tmp/digests
pattern: digests-*
merge-multiple: true
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Docker meta
id: meta
uses: docker/metadata-action@v5
with:
images: ${{ env.REGISTRY_IMAGE }}
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Create manifest list and push
working-directory: /tmp/digests
run: |
docker buildx imagetools create $(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \
$(printf '${{ env.REGISTRY_IMAGE }}@sha256:%s ' *)
- name: Inspect image
run: |
docker buildx imagetools inspect ${{ env.REGISTRY_IMAGE }}:${{ steps.meta.outputs.version }}
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
env:
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
with:
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
severity: "CRITICAL,HIGH"

View File

@@ -3,53 +3,53 @@ name: Build and Push Model Server Image on Tag
on:
push:
tags:
- "*"
- '*'
env:
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'danswer/danswer-model-server-cloud' || 'danswer/danswer-model-server' }}
REGISTRY_IMAGE: danswer/danswer-model-server
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
jobs:
build-and-push:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}"]
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Model Server Image Docker Build and Push
uses: docker/build-push-action@v5
with:
context: ./backend
file: ./backend/Dockerfile.model_server
platforms: linux/amd64,linux/arm64
push: true
tags: |
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
build-args: |
DANSWER_VERSION=${{ github.ref_name }}
- name: Model Server Image Docker Build and Push
uses: docker/build-push-action@v5
with:
context: ./backend
file: ./backend/Dockerfile.model_server
platforms: linux/amd64,linux/arm64
push: true
tags: |
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
build-args: |
DANSWER_VERSION=${{ github.ref_name }}
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
env:
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
with:
image-ref: docker.io/danswer/danswer-model-server:${{ github.ref_name }}
severity: "CRITICAL,HIGH"
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
image-ref: docker.io/danswer/danswer-model-server:${{ github.ref_name }}
severity: 'CRITICAL,HIGH'

View File

@@ -1,23 +0,0 @@
name: 'Nightly - Close stale issues and PRs'
on:
schedule:
- cron: '0 11 * * *' # Runs every day at 3 AM PST / 4 AM PDT / 11 AM UTC
permissions:
# contents: write # only for delete-branch option
issues: write
pull-requests: write
jobs:
stale:
runs-on: ubuntu-latest
steps:
- uses: actions/stale@v9
with:
stale-issue-message: 'This issue is stale because it has been open 75 days with no activity. Remove stale label or comment or this will be closed in 15 days.'
stale-pr-message: 'This PR is stale because it has been open 75 days with no activity. Remove stale label or comment or this will be closed in 15 days.'
close-issue-message: 'This issue was closed because it has been stalled for 90 days with no activity.'
close-pr-message: 'This PR was closed because it has been stalled for 90 days with no activity.'
days-before-stale: 75
# days-before-close: 90 # uncomment after we test stale behavior

View File

@@ -1,76 +0,0 @@
# Scan for problematic software licenses
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
name: 'Nightly - Scan licenses'
on:
# schedule:
# - cron: '0 14 * * *' # Runs every day at 6 AM PST / 7 AM PDT / 2 PM UTC
workflow_dispatch: # Allows manual triggering
permissions:
actions: read
contents: read
security-events: write
jobs:
scan-licenses:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}"]
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.11'
cache: 'pip'
cache-dependency-path: |
backend/requirements/default.txt
backend/requirements/dev.txt
backend/requirements/model_server.txt
- name: Get explicit and transitive dependencies
run: |
python -m pip install --upgrade pip
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
pip freeze > requirements-all.txt
- name: Check python
id: license_check_report
uses: pilosus/action-pip-license-checker@v2
with:
requirements: 'requirements-all.txt'
fail: 'Copyleft'
exclude: '(?i)^(pylint|aio[-_]*).*'
- name: Print report
if: ${{ always() }}
run: echo "${{ steps.license_check_report.outputs.report }}"
- name: Install npm dependencies
working-directory: ./web
run: npm ci
- name: Run Trivy vulnerability scanner in repo mode
uses: aquasecurity/trivy-action@0.28.0
with:
scan-type: fs
scanners: license
format: table
# format: sarif
# output: trivy-results.sarif
severity: HIGH,CRITICAL
# - name: Upload Trivy scan results to GitHub Security tab
# uses: github/codeql-action/upload-sarif@v3
# with:
# sarif_file: trivy-results.sarif

View File

@@ -13,10 +13,7 @@ on:
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
jobs:
integration-tests:
# See https://runs-on.com/runners/linux/
@@ -75,7 +72,7 @@ jobs:
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
- name: Build integration test Docker image
uses: ./.github/actions/custom-build-and-push
with:
@@ -88,58 +85,7 @@ jobs:
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
# Start containers for multi-tenant tests
- name: Start Docker containers for multi-tenant tests
run: |
cd deployment/docker_compose
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
MULTI_TENANT=true \
AUTH_TYPE=basic \
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
docker compose -f docker-compose.dev.yml -p danswer-stack up -d
id: start_docker_multi_tenant
# In practice, `cloud` Auth type would require OAUTH credentials to be set.
- name: Run Multi-Tenant Integration Tests
run: |
echo "Running integration tests..."
docker run --rm --network danswer-stack_default \
--name test-runner \
-e POSTGRES_HOST=relational_db \
-e POSTGRES_USER=postgres \
-e POSTGRES_PASSWORD=password \
-e POSTGRES_DB=postgres \
-e VESPA_HOST=index \
-e REDIS_HOST=cache \
-e API_SERVER_HOST=api_server \
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
-e TEST_WEB_HOSTNAME=test-runner \
-e AUTH_TYPE=cloud \
-e MULTI_TENANT=true \
danswer/danswer-integration:test \
/app/tests/integration/multitenant_tests
continue-on-error: true
id: run_multitenant_tests
- name: Check multi-tenant test results
run: |
if [ ${{ steps.run_tests.outcome }} == 'failure' ]; then
echo "Integration tests failed. Exiting with error."
exit 1
else
echo "All integration tests passed successfully."
fi
- name: Stop multi-tenant Docker containers
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
- name: Start Docker containers
- name: Start Docker containers
run: |
cd deployment/docker_compose
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
@@ -184,7 +130,7 @@ jobs:
done
echo "Finished waiting for service."
- name: Run Standard Integration Tests
- name: Run integration tests
run: |
echo "Running integration tests..."
docker run --rm --network danswer-stack_default \
@@ -198,13 +144,8 @@ jobs:
-e API_SERVER_HOST=api_server \
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
-e TEST_WEB_HOSTNAME=test-runner \
danswer/danswer-integration:test \
/app/tests/integration/tests \
/app/tests/integration/connector_job_tests
danswer/danswer-integration:test
continue-on-error: true
id: run_tests
@@ -217,18 +158,12 @@ jobs:
echo "All integration tests passed successfully."
fi
# save before stopping the containers so the logs can be captured
- name: Save Docker logs
if: success() || failure()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack logs > docker-compose.log
mv docker-compose.log ${{ github.workspace }}/docker-compose.log
- name: Stop Docker containers
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
- name: Upload logs
if: success() || failure()

View File

@@ -1,124 +0,0 @@
name: Backport on Merge
# Note this workflow does not trigger the builds, be sure to manually tag the branches to trigger the builds
on:
pull_request:
types: [closed] # Later we check for merge so only PRs that go in can get backported
permissions:
contents: write
actions: write
jobs:
backport:
if: github.event.pull_request.merged == true
runs-on: ubuntu-latest
env:
GITHUB_TOKEN: ${{ secrets.YUHONG_GH_ACTIONS }}
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
ssh-key: "${{ secrets.RKUO_DEPLOY_KEY }}"
fetch-depth: 0
- name: Set up Git user
run: |
git config user.name "Richard Kuo [bot]"
git config user.email "rkuo[bot]@danswer.ai"
git fetch --prune
- name: Check for Backport Checkbox
id: checkbox-check
run: |
PR_BODY="${{ github.event.pull_request.body }}"
if [[ "$PR_BODY" == *"[x] This PR should be backported"* ]]; then
echo "backport=true" >> $GITHUB_OUTPUT
else
echo "backport=false" >> $GITHUB_OUTPUT
fi
- name: List and sort release branches
id: list-branches
run: |
git fetch --all --tags
BRANCHES=$(git for-each-ref --format='%(refname:short)' refs/remotes/origin/release/* | sed 's|origin/release/||' | sort -Vr)
BETA=$(echo "$BRANCHES" | head -n 1)
STABLE=$(echo "$BRANCHES" | head -n 2 | tail -n 1)
echo "beta=release/$BETA" >> $GITHUB_OUTPUT
echo "stable=release/$STABLE" >> $GITHUB_OUTPUT
# Fetch latest tags for beta and stable
LATEST_BETA_TAG=$(git tag -l "v[0-9]*.[0-9]*.[0-9]*-beta.[0-9]*" | grep -E "^v[0-9]+\.[0-9]+\.[0-9]+-beta\.[0-9]+$" | grep -v -- "-cloud" | sort -Vr | head -n 1)
LATEST_STABLE_TAG=$(git tag -l "v[0-9]*.[0-9]*.[0-9]*" | grep -E "^v[0-9]+\.[0-9]+\.[0-9]+$" | sort -Vr | head -n 1)
# Handle case where no beta tags exist
if [[ -z "$LATEST_BETA_TAG" ]]; then
NEW_BETA_TAG="v1.0.0-beta.1"
else
NEW_BETA_TAG=$(echo $LATEST_BETA_TAG | awk -F '[.-]' '{print $1 "." $2 "." $3 "-beta." ($NF+1)}')
fi
# Increment latest stable tag
NEW_STABLE_TAG=$(echo $LATEST_STABLE_TAG | awk -F '.' '{print $1 "." $2 "." ($3+1)}')
echo "latest_beta_tag=$LATEST_BETA_TAG" >> $GITHUB_OUTPUT
echo "latest_stable_tag=$LATEST_STABLE_TAG" >> $GITHUB_OUTPUT
echo "new_beta_tag=$NEW_BETA_TAG" >> $GITHUB_OUTPUT
echo "new_stable_tag=$NEW_STABLE_TAG" >> $GITHUB_OUTPUT
- name: Echo branch and tag information
run: |
echo "Beta branch: ${{ steps.list-branches.outputs.beta }}"
echo "Stable branch: ${{ steps.list-branches.outputs.stable }}"
echo "Latest beta tag: ${{ steps.list-branches.outputs.latest_beta_tag }}"
echo "Latest stable tag: ${{ steps.list-branches.outputs.latest_stable_tag }}"
echo "New beta tag: ${{ steps.list-branches.outputs.new_beta_tag }}"
echo "New stable tag: ${{ steps.list-branches.outputs.new_stable_tag }}"
- name: Trigger Backport
if: steps.checkbox-check.outputs.backport == 'true'
run: |
set -e
echo "Backporting to beta ${{ steps.list-branches.outputs.beta }} and stable ${{ steps.list-branches.outputs.stable }}"
# Echo the merge commit SHA
echo "Merge commit SHA: ${{ github.event.pull_request.merge_commit_sha }}"
# Fetch all history for all branches and tags
git fetch --prune
# Reset and prepare the beta branch
git checkout ${{ steps.list-branches.outputs.beta }}
echo "Last 5 commits on beta branch:"
git log -n 5 --pretty=format:"%H"
echo "" # Newline for formatting
# Cherry-pick the merge commit from the merged PR
git cherry-pick -m 1 ${{ github.event.pull_request.merge_commit_sha }} || {
echo "Cherry-pick to beta failed due to conflicts."
exit 1
}
# Create new beta branch/tag
git tag ${{ steps.list-branches.outputs.new_beta_tag }}
# Push the changes and tag to the beta branch using PAT
git push origin ${{ steps.list-branches.outputs.beta }}
git push origin ${{ steps.list-branches.outputs.new_beta_tag }}
# Reset and prepare the stable branch
git checkout ${{ steps.list-branches.outputs.stable }}
echo "Last 5 commits on stable branch:"
git log -n 5 --pretty=format:"%H"
echo "" # Newline for formatting
# Cherry-pick the merge commit from the merged PR
git cherry-pick -m 1 ${{ github.event.pull_request.merge_commit_sha }} || {
echo "Cherry-pick to stable failed due to conflicts."
exit 1
}
# Create new stable branch/tag
git tag ${{ steps.list-branches.outputs.new_stable_tag }}
# Push the changes and tag to the stable branch using PAT
git push origin ${{ steps.list-branches.outputs.stable }}
git push origin ${{ steps.list-branches.outputs.new_stable_tag }}

View File

@@ -1,225 +0,0 @@
name: Run Chromatic Tests
concurrency:
group: Run-Chromatic-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
cancel-in-progress: true
on: push
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
jobs:
playwright-tests:
name: Playwright Tests
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,ram=16,"run-id=${{ github.run_id }}"]
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.11'
cache: 'pip'
cache-dependency-path: |
backend/requirements/default.txt
backend/requirements/dev.txt
backend/requirements/model_server.txt
- run: |
python -m pip install --upgrade pip
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
- name: Setup node
uses: actions/setup-node@v4
with:
node-version: 22
- name: Install node dependencies
working-directory: ./web
run: npm ci
- name: Install playwright browsers
working-directory: ./web
run: npx playwright install --with-deps
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
# tag every docker image with "test" so that we can spin up the correct set
# of images during testing
# we use the runs-on cache for docker builds
# in conjunction with runs-on runners, it has better speed and unlimited caching
# https://runs-on.com/caching/s3-cache-for-github-actions/
# https://runs-on.com/caching/docker/
# https://github.com/moby/buildkit#s3-cache-experimental
# images are built and run locally for testing purposes. Not pushed.
- name: Build Web Docker image
uses: ./.github/actions/custom-build-and-push
with:
context: ./web
file: ./web/Dockerfile
platforms: linux/amd64
tags: danswer/danswer-web-server:test
push: false
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/web-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/web-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
- name: Build Backend Docker image
uses: ./.github/actions/custom-build-and-push
with:
context: ./backend
file: ./backend/Dockerfile
platforms: linux/amd64
tags: danswer/danswer-backend:test
push: false
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
- name: Build Model Server Docker image
uses: ./.github/actions/custom-build-and-push
with:
context: ./backend
file: ./backend/Dockerfile.model_server
platforms: linux/amd64
tags: danswer/danswer-model-server:test
push: false
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
- name: Start Docker containers
run: |
cd deployment/docker_compose
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
AUTH_TYPE=basic \
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
docker compose -f docker-compose.dev.yml -p danswer-stack up -d
id: start_docker
- name: Wait for service to be ready
run: |
echo "Starting wait-for-service script..."
docker logs -f danswer-stack-api_server-1 &
start_time=$(date +%s)
timeout=300 # 5 minutes in seconds
while true; do
current_time=$(date +%s)
elapsed_time=$((current_time - start_time))
if [ $elapsed_time -ge $timeout ]; then
echo "Timeout reached. Service did not become ready in 5 minutes."
exit 1
fi
# Use curl with error handling to ignore specific exit code 56
response=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:8080/health || echo "curl_error")
if [ "$response" = "200" ]; then
echo "Service is ready!"
break
elif [ "$response" = "curl_error" ]; then
echo "Curl encountered an error, possibly exit code 56. Continuing to retry..."
else
echo "Service not ready yet (HTTP status $response). Retrying in 5 seconds..."
fi
sleep 5
done
echo "Finished waiting for service."
- name: Run pytest playwright test init
working-directory: ./backend
env:
PYTEST_IGNORE_SKIP: true
run: pytest -s tests/integration/tests/playwright/test_playwright.py
- name: Run Playwright tests
working-directory: ./web
run: npx playwright test
- uses: actions/upload-artifact@v4
if: always()
with:
# Chromatic automatically defaults to the test-results directory.
# Replace with the path to your custom directory and adjust the CHROMATIC_ARCHIVE_LOCATION environment variable accordingly.
name: test-results
path: ./web/test-results
retention-days: 30
# save before stopping the containers so the logs can be captured
- name: Save Docker logs
if: success() || failure()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack logs > docker-compose.log
mv docker-compose.log ${{ github.workspace }}/docker-compose.log
- name: Upload logs
if: success() || failure()
uses: actions/upload-artifact@v4
with:
name: docker-logs
path: ${{ github.workspace }}/docker-compose.log
- name: Stop Docker containers
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
chromatic-tests:
name: Chromatic Tests
needs: playwright-tests
runs-on: [runs-on,runner=8cpu-linux-x64,ram=16,"run-id=${{ github.run_id }}"]
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Setup node
uses: actions/setup-node@v4
with:
node-version: 22
- name: Install node dependencies
working-directory: ./web
run: npm ci
- name: Download Playwright test results
uses: actions/download-artifact@v4
with:
name: test-results
path: ./web/test-results
- name: Run Chromatic
uses: chromaui/action@latest
with:
playwright: true
projectToken: ${{ secrets.CHROMATIC_PROJECT_TOKEN }}
workingDir: ./web
env:
CHROMATIC_ARCHIVE_LOCATION: ./test-results

View File

@@ -1,72 +0,0 @@
name: Helm - Lint and Test Charts
on:
merge_group:
pull_request:
branches: [ main ]
workflow_dispatch: # Allows manual triggering
jobs:
helm-chart-check:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,hdd=256,"run-id=${{ github.run_id }}"]
# fetch-depth 0 is required for helm/chart-testing-action
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Set up Helm
uses: azure/setup-helm@v4.2.0
with:
version: v3.14.4
- name: Set up chart-testing
uses: helm/chart-testing-action@v2.6.1
# even though we specify chart-dirs in ct.yaml, it isn't used by ct for the list-changed command...
- name: Run chart-testing (list-changed)
id: list-changed
run: |
echo "default_branch: ${{ github.event.repository.default_branch }}"
changed=$(ct list-changed --remote origin --target-branch ${{ github.event.repository.default_branch }} --chart-dirs deployment/helm/charts)
echo "list-changed output: $changed"
if [[ -n "$changed" ]]; then
echo "changed=true" >> "$GITHUB_OUTPUT"
fi
# rkuo: I don't think we need python?
# - name: Set up Python
# uses: actions/setup-python@v5
# with:
# python-version: '3.11'
# cache: 'pip'
# cache-dependency-path: |
# backend/requirements/default.txt
# backend/requirements/dev.txt
# backend/requirements/model_server.txt
# - run: |
# python -m pip install --upgrade pip
# pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
# pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
# pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
# lint all charts if any changes were detected
- name: Run chart-testing (lint)
if: steps.list-changed.outputs.changed == 'true'
run: ct lint --config ct.yaml --all
# the following would lint only changed charts, but linting isn't expensive
# run: ct lint --config ct.yaml --target-branch ${{ github.event.repository.default_branch }}
- name: Create kind cluster
if: steps.list-changed.outputs.changed == 'true'
uses: helm/kind-action@v1.10.0
- name: Run chart-testing (install)
if: steps.list-changed.outputs.changed == 'true'
run: ct install --all --helm-extra-set-args="--set=nginx.enabled=false" --debug --config ct.yaml
# the following would install only changed charts, but we only have one chart so
# don't worry about that for now
# run: ct install --target-branch ${{ github.event.repository.default_branch }}

View File

@@ -0,0 +1,68 @@
# This workflow is intentionally disabled while we're still working on it
# It's close to ready, but a race condition needs to be fixed with
# API server and Vespa startup, and it needs to have a way to build/test against
# local containers
name: Helm - Lint and Test Charts
on:
merge_group:
pull_request:
branches: [ main ]
jobs:
lint-test:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,hdd=256,"run-id=${{ github.run_id }}"]
# fetch-depth 0 is required for helm/chart-testing-action
steps:
- name: Checkout code
uses: actions/checkout@v3
with:
fetch-depth: 0
- name: Set up Helm
uses: azure/setup-helm@v4.2.0
with:
version: v3.14.4
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.11'
cache: 'pip'
cache-dependency-path: |
backend/requirements/default.txt
backend/requirements/dev.txt
backend/requirements/model_server.txt
- run: |
python -m pip install --upgrade pip
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
- name: Set up chart-testing
uses: helm/chart-testing-action@v2.6.1
- name: Run chart-testing (list-changed)
id: list-changed
run: |
changed=$(ct list-changed --target-branch ${{ github.event.repository.default_branch }})
if [[ -n "$changed" ]]; then
echo "changed=true" >> "$GITHUB_OUTPUT"
fi
- name: Run chart-testing (lint)
# if: steps.list-changed.outputs.changed == 'true'
run: ct lint --all --config ct.yaml --target-branch ${{ github.event.repository.default_branch }}
- name: Create kind cluster
# if: steps.list-changed.outputs.changed == 'true'
uses: helm/kind-action@v1.10.0
- name: Run chart-testing (install)
# if: steps.list-changed.outputs.changed == 'true'
run: ct install --all --config ct.yaml
# run: ct install --target-branch ${{ github.event.repository.default_branch }}

View File

@@ -18,14 +18,6 @@ env:
# Jira
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
# Google
GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR: ${{ secrets.GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR }}
GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR_TEST_USER_1: ${{ secrets.GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR_TEST_USER_1 }}
GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR }}
GOOGLE_GMAIL_SERVICE_ACCOUNT_JSON_STR: ${{ secrets.GOOGLE_GMAIL_SERVICE_ACCOUNT_JSON_STR }}
GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR }}
# Slab
SLAB_BOT_TOKEN: ${{ secrets.SLAB_BOT_TOKEN }}
jobs:
connectors-check:

View File

@@ -15,7 +15,7 @@ env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
jobs:
model-check:
connectors-check:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]

1
.gitignore vendored
View File

@@ -7,4 +7,3 @@
.vscode/
*.sw?
/backend/tests/regression/answer_quality/search_test_config.yaml
/web/test-results/

View File

@@ -6,69 +6,19 @@
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"compounds": [
{
// Dummy entry used to label the group
"name": "--- Compound ---",
"configurations": [
"--- Individual ---"
],
"presentation": {
"group": "1",
}
},
{
"name": "Run All Danswer Services",
"configurations": [
"Web Server",
"Model Server",
"API Server",
"Slack Bot",
"Celery primary",
"Celery light",
"Celery heavy",
"Celery indexing",
"Celery beat",
],
"presentation": {
"group": "1",
}
},
{
"name": "Web / Model / API",
"configurations": [
"Web Server",
"Model Server",
"API Server",
],
"presentation": {
"group": "1",
}
},
{
"name": "Celery (all)",
"configurations": [
"Celery primary",
"Celery light",
"Celery heavy",
"Celery indexing",
"Celery beat"
],
"presentation": {
"group": "1",
}
}
"Indexing",
"Background Jobs",
"Slack Bot"
]
}
],
"configurations": [
{
// Dummy entry used to label the group
"name": "--- Individual ---",
"type": "node",
"request": "launch",
"presentation": {
"group": "2",
"order": 0
}
},
{
"name": "Web Server",
"type": "node",
@@ -79,11 +29,7 @@
"runtimeArgs": [
"run", "dev"
],
"presentation": {
"group": "2",
},
"console": "integratedTerminal",
"consoleTitle": "Web Server Console"
"console": "integratedTerminal"
},
{
"name": "Model Server",
@@ -102,11 +48,7 @@
"--reload",
"--port",
"9000"
],
"presentation": {
"group": "2",
},
"consoleTitle": "Model Server Console"
]
},
{
"name": "API Server",
@@ -126,13 +68,43 @@
"--reload",
"--port",
"8080"
],
"presentation": {
"group": "2",
},
"consoleTitle": "API Server Console"
]
},
// For the listener to access the Slack API,
{
"name": "Indexing",
"consoleName": "Indexing",
"type": "debugpy",
"request": "launch",
"program": "danswer/background/update.py",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"ENABLE_MULTIPASS_INDEXING": "false",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
}
},
// Celery and all async jobs, usually would include indexing as well but this is handled separately above for dev
{
"name": "Background Jobs",
"consoleName": "Background Jobs",
"type": "debugpy",
"request": "launch",
"program": "scripts/dev_run_background_jobs.py",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"--no-indexing"
]
},
// For the listner to access the Slack API,
// DANSWER_BOT_SLACK_APP_TOKEN & DANSWER_BOT_SLACK_BOT_TOKEN need to be set in .env file located in the root of the project
{
"name": "Slack Bot",
@@ -146,151 +118,7 @@
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"presentation": {
"group": "2",
},
"consoleTitle": "Slack Bot Console"
},
{
"name": "Celery primary",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "INFO",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"danswer.background.celery.versioned_apps.primary",
"worker",
"--pool=threads",
"--concurrency=4",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=primary@%n",
"-Q",
"celery",
],
"presentation": {
"group": "2",
},
"consoleTitle": "Celery primary Console"
},
{
"name": "Celery light",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "INFO",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"danswer.background.celery.versioned_apps.light",
"worker",
"--pool=threads",
"--concurrency=64",
"--prefetch-multiplier=8",
"--loglevel=INFO",
"--hostname=light@%n",
"-Q",
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert",
],
"presentation": {
"group": "2",
},
"consoleTitle": "Celery light Console"
},
{
"name": "Celery heavy",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "INFO",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"danswer.background.celery.versioned_apps.heavy",
"worker",
"--pool=threads",
"--concurrency=4",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=heavy@%n",
"-Q",
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync",
],
"presentation": {
"group": "2",
},
"consoleTitle": "Celery heavy Console"
},
{
"name": "Celery indexing",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"ENABLE_MULTIPASS_INDEXING": "false",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"danswer.background.celery.versioned_apps.indexing",
"worker",
"--pool=threads",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=indexing@%n",
"-Q",
"connector_indexing",
],
"presentation": {
"group": "2",
},
"consoleTitle": "Celery indexing Console"
},
{
"name": "Celery beat",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"danswer.background.celery.versioned_apps.beat",
"beat",
"--loglevel=INFO",
],
"presentation": {
"group": "2",
},
"consoleTitle": "Celery beat Console"
}
},
{
"name": "Pytest",
@@ -309,22 +137,8 @@
"-v"
// Specify a sepcific module/test to run or provide nothing to run all tests
//"tests/unit/danswer/llm/answering/test_prune_and_merge.py"
],
"presentation": {
"group": "2",
},
"consoleTitle": "Pytest Console"
]
},
{
// Dummy entry used to label the group
"name": "--- Tasks ---",
"type": "node",
"request": "launch",
"presentation": {
"group": "3",
"order": 0
}
},
{
"name": "Clear and Restart External Volumes and Containers",
"type": "node",
@@ -333,27 +147,7 @@
"runtimeArgs": ["${workspaceFolder}/backend/scripts/restart_containers.sh"],
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"stopOnEntry": true,
"presentation": {
"group": "3",
},
},
{
// Celery jobs launched through a single background script (legacy)
// Recommend using the "Celery (all)" compound launch instead.
"name": "Background Jobs",
"consoleName": "Background Jobs",
"type": "debugpy",
"request": "launch",
"program": "scripts/dev_run_background_jobs.py",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
},
"stopOnEntry": true
}
]
}

View File

@@ -32,7 +32,7 @@ To contribute to this project, please follow the
When opening a pull request, mention related issues and feel free to tag relevant maintainers.
Before creating a pull request please make sure that the new changes conform to the formatting and linting requirements.
See the [Formatting and Linting](#formatting-and-linting) section for how to run these checks locally.
See the [Formatting and Linting](#-formatting-and-linting) section for how to run these checks locally.
### Getting Help 🙋

View File

@@ -1,5 +1,4 @@
<!-- DANSWER_METADATA={"link": "https://github.com/danswer-ai/danswer/blob/main/README.md"} -->
<a name="readme-top"></a>
<h2 align="center">
<a href="https://www.danswer.ai/"> <img width="50%" src="https://github.com/danswer-owners/danswer/blob/1fabd9372d66cd54238847197c33f091a724803b/DanswerWithName.png?raw=true)" /></a>
@@ -12,7 +11,7 @@
<a href="https://docs.danswer.dev/" target="_blank">
<img src="https://img.shields.io/badge/docs-view-blue" alt="Documentation">
</a>
<a href="https://join.slack.com/t/danswer/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA" target="_blank">
<a href="https://join.slack.com/t/danswer/shared_invite/zt-2lcmqw703-071hBuZBfNEOGUsLa5PXvQ" target="_blank">
<img src="https://img.shields.io/badge/slack-join-blue.svg?logo=slack" alt="Slack">
</a>
<a href="https://discord.gg/TDJ59cGV2X" target="_blank">
@@ -69,7 +68,7 @@ We also have built-in support for deployment on Kubernetes. Files for that can b
## 🚧 Roadmap
* Chat/Prompt sharing with specific teammates and user groups.
* Multimodal model support, chat with images, video etc.
* Multi-Model model support, chat with images, video etc.
* Choosing between LLMs and parameters during chat session.
* Tool calling and agent configurations options.
* Organizational understanding and ability to locate and suggest experts from your team.
@@ -128,19 +127,3 @@ To try the Danswer Enterprise Edition:
## 💡 Contributing
Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md) for more details.
## ⭐Star History
[![Star History Chart](https://api.star-history.com/svg?repos=danswer-ai/danswer&type=Date)](https://star-history.com/#danswer-ai/danswer&Date)
## ✨Contributors
<a href="https://github.com/danswer-ai/danswer/graphs/contributors">
<img alt="contributors" src="https://contrib.rocks/image?repo=danswer-ai/danswer"/>
</a>
<p align="right" style="font-size: 14px; color: #555; margin-top: 20px;">
<a href="#readme-top" style="text-decoration: none; color: #007bff; font-weight: bold;">
↑ Back to Top ↑
</a>
</p>

View File

@@ -12,6 +12,7 @@ ARG DANSWER_VERSION=0.8-dev
ENV DANSWER_VERSION=${DANSWER_VERSION} \
DANSWER_RUNNING_IN_DOCKER="true"
ARG CA_CERT_CONTENT=""
RUN echo "DANSWER_VERSION: ${DANSWER_VERSION}"
# Install system dependencies
@@ -38,6 +39,15 @@ RUN apt-get update && \
apt-get clean
# Conditionally write the CA certificate and update certificates
RUN if [ -n "$CA_CERT_CONTENT" ]; then \
echo "Adding custom CA certificate"; \
echo "$CA_CERT_CONTENT" > /usr/local/share/ca-certificates/my-ca.crt && \
chmod 644 /usr/local/share/ca-certificates/my-ca.crt && \
update-ca-certificates; \
else \
echo "No custom CA certificate provided"; \
fi
# Install Python dependencies
# Remove py which is pulled in by retry, py is not needed and is a CVE
@@ -73,11 +83,11 @@ RUN apt-get update && \
rm -rf /var/lib/apt/lists/* && \
rm -f /usr/local/lib/python3.11/site-packages/tornado/test/test.key
# Pre-downloading models for setups with limited egress
RUN python -c "from tokenizers import Tokenizer; \
Tokenizer.from_pretrained('nomic-ai/nomic-embed-text-v1')"
# Pre-downloading NLTK for setups with limited egress
RUN python -c "import nltk; \
nltk.download('stopwords', quiet=True); \

View File

@@ -1,5 +1,5 @@
from sqlalchemy.engine.base import Connection
from typing import Literal
from typing import Any
import asyncio
from logging.config import fileConfig
import logging
@@ -8,14 +8,12 @@ from alembic import context
from sqlalchemy import pool
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.sql import text
from sqlalchemy.sql.schema import SchemaItem
from shared_configs.configs import MULTI_TENANT
from danswer.configs.app_configs import MULTI_TENANT
from danswer.db.engine import build_connection_string
from danswer.db.models import Base
from celery.backends.database.session import ResultModelBase # type: ignore
from danswer.db.engine import get_all_tenant_ids
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from danswer.background.celery.celery_app import get_all_tenant_ids
# Alembic Config object
config = context.config
@@ -36,18 +34,7 @@ logger = logging.getLogger(__name__)
def include_object(
object: SchemaItem,
name: str | None,
type_: Literal[
"schema",
"table",
"column",
"index",
"unique_constraint",
"foreign_key_constraint",
],
reflected: bool,
compare_to: SchemaItem | None,
object: Any, name: str, type_: str, reflected: bool, compare_to: Any
) -> bool:
"""
Determines whether a database object should be included in migrations.
@@ -70,15 +57,11 @@ def get_schema_options() -> tuple[str, bool, bool]:
if "=" in pair:
key, value = pair.split("=", 1)
x_args[key.strip()] = value.strip()
schema_name = x_args.get("schema", POSTGRES_DEFAULT_SCHEMA)
schema_name = x_args.get("schema", "public")
create_schema = x_args.get("create_schema", "true").lower() == "true"
upgrade_all_tenants = x_args.get("upgrade_all_tenants", "false").lower() == "true"
if (
MULTI_TENANT
and schema_name == POSTGRES_DEFAULT_SCHEMA
and not upgrade_all_tenants
):
if MULTI_TENANT and schema_name == "public":
raise ValueError(
"Cannot run default migrations in public schema when multi-tenancy is enabled. "
"Please specify a tenant-specific schema."

View File

@@ -1,59 +0,0 @@
"""display custom llm models
Revision ID: 177de57c21c9
Revises: 4ee1287bd26a
Create Date: 2024-11-21 11:49:04.488677
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from sqlalchemy import and_
revision = "177de57c21c9"
down_revision = "4ee1287bd26a"
branch_labels = None
depends_on = None
depends_on = None
def upgrade() -> None:
conn = op.get_bind()
llm_provider = sa.table(
"llm_provider",
sa.column("id", sa.Integer),
sa.column("provider", sa.String),
sa.column("model_names", postgresql.ARRAY(sa.String)),
sa.column("display_model_names", postgresql.ARRAY(sa.String)),
)
excluded_providers = ["openai", "bedrock", "anthropic", "azure"]
providers_to_update = sa.select(
llm_provider.c.id,
llm_provider.c.model_names,
llm_provider.c.display_model_names,
).where(
and_(
~llm_provider.c.provider.in_(excluded_providers),
llm_provider.c.model_names.isnot(None),
)
)
results = conn.execute(providers_to_update).fetchall()
for provider_id, model_names, display_model_names in results:
if display_model_names is None:
display_model_names = []
combined_model_names = list(set(display_model_names + model_names))
update_stmt = (
llm_provider.update()
.where(llm_provider.c.id == provider_id)
.values(display_model_names=combined_model_names)
)
conn.execute(update_stmt)
def downgrade() -> None:
pass

View File

@@ -1,68 +0,0 @@
"""default chosen assistants to none
Revision ID: 26b931506ecb
Revises: 2daa494a0851
Create Date: 2024-11-12 13:23:29.858995
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "26b931506ecb"
down_revision = "2daa494a0851"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"user", sa.Column("chosen_assistants_new", postgresql.JSONB(), nullable=True)
)
op.execute(
"""
UPDATE "user"
SET chosen_assistants_new =
CASE
WHEN chosen_assistants = '[-2, -1, 0]' THEN NULL
ELSE chosen_assistants
END
"""
)
op.drop_column("user", "chosen_assistants")
op.alter_column(
"user", "chosen_assistants_new", new_column_name="chosen_assistants"
)
def downgrade() -> None:
op.add_column(
"user",
sa.Column(
"chosen_assistants_old",
postgresql.JSONB(),
nullable=False,
server_default="[-2, -1, 0]",
),
)
op.execute(
"""
UPDATE "user"
SET chosen_assistants_old =
CASE
WHEN chosen_assistants IS NULL THEN '[-2, -1, 0]'::jsonb
ELSE chosen_assistants
END
"""
)
op.drop_column("user", "chosen_assistants")
op.alter_column(
"user", "chosen_assistants_old", new_column_name="chosen_assistants"
)

View File

@@ -1,30 +0,0 @@
"""add-group-sync-time
Revision ID: 2daa494a0851
Revises: c0fd6e4da83a
Create Date: 2024-11-11 10:57:22.991157
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "2daa494a0851"
down_revision = "c0fd6e4da83a"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"connector_credential_pair",
sa.Column(
"last_time_external_group_sync",
sa.DateTime(timezone=True),
nullable=True,
),
)
def downgrade() -> None:
op.drop_column("connector_credential_pair", "last_time_external_group_sync")

View File

@@ -1,50 +0,0 @@
"""single tool call per message
Revision ID: 33cb72ea4d80
Revises: 5b29123cd710
Create Date: 2024-11-01 12:51:01.535003
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "33cb72ea4d80"
down_revision = "5b29123cd710"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Step 1: Delete extraneous ToolCall entries
# Keep only the ToolCall with the smallest 'id' for each 'message_id'
op.execute(
sa.text(
"""
DELETE FROM tool_call
WHERE id NOT IN (
SELECT MIN(id)
FROM tool_call
WHERE message_id IS NOT NULL
GROUP BY message_id
);
"""
)
)
# Step 2: Add a unique constraint on message_id
op.create_unique_constraint(
constraint_name="uq_tool_call_message_id",
table_name="tool_call",
columns=["message_id"],
)
def downgrade() -> None:
# Step 1: Drop the unique constraint on message_id
op.drop_constraint(
constraint_name="uq_tool_call_message_id",
table_name="tool_call",
type_="unique",
)

View File

@@ -1,45 +0,0 @@
"""add persona categories
Revision ID: 47e5bef3a1d7
Revises: dfbe9e93d3c7
Create Date: 2024-11-05 18:55:02.221064
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "47e5bef3a1d7"
down_revision = "dfbe9e93d3c7"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Create the persona_category table
op.create_table(
"persona_category",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("name", sa.String(), nullable=False),
sa.Column("description", sa.String(), nullable=True),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("name"),
)
# Add category_id to persona table
op.add_column("persona", sa.Column("category_id", sa.Integer(), nullable=True))
op.create_foreign_key(
"fk_persona_category",
"persona",
"persona_category",
["category_id"],
["id"],
ondelete="SET NULL",
)
def downgrade() -> None:
op.drop_constraint("fk_persona_category", "persona", type_="foreignkey")
op.drop_column("persona", "category_id")
op.drop_table("persona_category")

View File

@@ -1,280 +0,0 @@
"""add_multiple_slack_bot_support
Revision ID: 4ee1287bd26a
Revises: 47e5bef3a1d7
Create Date: 2024-11-06 13:15:53.302644
"""
import logging
from typing import cast
from alembic import op
import sqlalchemy as sa
from sqlalchemy.orm import Session
from danswer.key_value_store.factory import get_kv_store
from danswer.db.models import SlackBot
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "4ee1287bd26a"
down_revision = "47e5bef3a1d7"
branch_labels: None = None
depends_on: None = None
# Configure logging
logger = logging.getLogger("alembic.runtime.migration")
logger.setLevel(logging.INFO)
def upgrade() -> None:
logger.info(f"{revision}: create_table: slack_bot")
# Create new slack_bot table
op.create_table(
"slack_bot",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("name", sa.String(), nullable=False),
sa.Column("enabled", sa.Boolean(), nullable=False, server_default="true"),
sa.Column("bot_token", sa.LargeBinary(), nullable=False),
sa.Column("app_token", sa.LargeBinary(), nullable=False),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("bot_token"),
sa.UniqueConstraint("app_token"),
)
# # Create new slack_channel_config table
op.create_table(
"slack_channel_config",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("slack_bot_id", sa.Integer(), nullable=True),
sa.Column("persona_id", sa.Integer(), nullable=True),
sa.Column("channel_config", postgresql.JSONB(), nullable=False),
sa.Column("response_type", sa.String(), nullable=False),
sa.Column(
"enable_auto_filters", sa.Boolean(), nullable=False, server_default="false"
),
sa.ForeignKeyConstraint(
["slack_bot_id"],
["slack_bot.id"],
),
sa.ForeignKeyConstraint(
["persona_id"],
["persona.id"],
),
sa.PrimaryKeyConstraint("id"),
)
# Handle existing Slack bot tokens first
logger.info(f"{revision}: Checking for existing Slack bot.")
bot_token = None
app_token = None
first_row_id = None
try:
tokens = cast(dict, get_kv_store().load("slack_bot_tokens_config_key"))
except Exception:
logger.warning("No existing Slack bot tokens found.")
tokens = {}
bot_token = tokens.get("bot_token")
app_token = tokens.get("app_token")
if bot_token and app_token:
logger.info(f"{revision}: Found bot and app tokens.")
session = Session(bind=op.get_bind())
new_slack_bot = SlackBot(
name="Slack Bot (Migrated)",
enabled=True,
bot_token=bot_token,
app_token=app_token,
)
session.add(new_slack_bot)
session.commit()
first_row_id = new_slack_bot.id
# Create a default bot if none exists
# This is in case there are no slack tokens but there are channels configured
op.execute(
sa.text(
"""
INSERT INTO slack_bot (name, enabled, bot_token, app_token)
SELECT 'Default Bot', true, '', ''
WHERE NOT EXISTS (SELECT 1 FROM slack_bot)
RETURNING id;
"""
)
)
# Get the bot ID to use (either from existing migration or newly created)
bot_id_query = sa.text(
"""
SELECT COALESCE(
:first_row_id,
(SELECT id FROM slack_bot ORDER BY id ASC LIMIT 1)
) as bot_id;
"""
)
result = op.get_bind().execute(bot_id_query, {"first_row_id": first_row_id})
bot_id = result.scalar()
# CTE (Common Table Expression) that transforms the old slack_bot_config table data
# This splits up the channel_names into their own rows
channel_names_cte = """
WITH channel_names AS (
SELECT
sbc.id as config_id,
sbc.persona_id,
sbc.response_type,
sbc.enable_auto_filters,
jsonb_array_elements_text(sbc.channel_config->'channel_names') as channel_name,
sbc.channel_config->>'respond_tag_only' as respond_tag_only,
sbc.channel_config->>'respond_to_bots' as respond_to_bots,
sbc.channel_config->'respond_member_group_list' as respond_member_group_list,
sbc.channel_config->'answer_filters' as answer_filters,
sbc.channel_config->'follow_up_tags' as follow_up_tags
FROM slack_bot_config sbc
)
"""
# Insert the channel names into the new slack_channel_config table
insert_statement = """
INSERT INTO slack_channel_config (
slack_bot_id,
persona_id,
channel_config,
response_type,
enable_auto_filters
)
SELECT
:bot_id,
channel_name.persona_id,
jsonb_build_object(
'channel_name', channel_name.channel_name,
'respond_tag_only',
COALESCE((channel_name.respond_tag_only)::boolean, false),
'respond_to_bots',
COALESCE((channel_name.respond_to_bots)::boolean, false),
'respond_member_group_list',
COALESCE(channel_name.respond_member_group_list, '[]'::jsonb),
'answer_filters',
COALESCE(channel_name.answer_filters, '[]'::jsonb),
'follow_up_tags',
COALESCE(channel_name.follow_up_tags, '[]'::jsonb)
),
channel_name.response_type,
channel_name.enable_auto_filters
FROM channel_names channel_name;
"""
op.execute(sa.text(channel_names_cte + insert_statement).bindparams(bot_id=bot_id))
# Clean up old tokens if they existed
try:
if bot_token and app_token:
logger.info(f"{revision}: Removing old bot and app tokens.")
get_kv_store().delete("slack_bot_tokens_config_key")
except Exception:
logger.warning("tried to delete tokens in dynamic config but failed")
# Rename the table
op.rename_table(
"slack_bot_config__standard_answer_category",
"slack_channel_config__standard_answer_category",
)
# Rename the column
op.alter_column(
"slack_channel_config__standard_answer_category",
"slack_bot_config_id",
new_column_name="slack_channel_config_id",
)
# Drop the table with CASCADE to handle dependent objects
op.execute("DROP TABLE slack_bot_config CASCADE")
logger.info(f"{revision}: Migration complete.")
def downgrade() -> None:
# Recreate the old slack_bot_config table
op.create_table(
"slack_bot_config",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("persona_id", sa.Integer(), nullable=True),
sa.Column("channel_config", postgresql.JSONB(), nullable=False),
sa.Column("response_type", sa.String(), nullable=False),
sa.Column("enable_auto_filters", sa.Boolean(), nullable=False),
sa.ForeignKeyConstraint(
["persona_id"],
["persona.id"],
),
sa.PrimaryKeyConstraint("id"),
)
# Migrate data back to the old format
# Group by persona_id to combine channel names back into arrays
op.execute(
sa.text(
"""
INSERT INTO slack_bot_config (
persona_id,
channel_config,
response_type,
enable_auto_filters
)
SELECT DISTINCT ON (persona_id)
persona_id,
jsonb_build_object(
'channel_names', (
SELECT jsonb_agg(c.channel_config->>'channel_name')
FROM slack_channel_config c
WHERE c.persona_id = scc.persona_id
),
'respond_tag_only', (channel_config->>'respond_tag_only')::boolean,
'respond_to_bots', (channel_config->>'respond_to_bots')::boolean,
'respond_member_group_list', channel_config->'respond_member_group_list',
'answer_filters', channel_config->'answer_filters',
'follow_up_tags', channel_config->'follow_up_tags'
),
response_type,
enable_auto_filters
FROM slack_channel_config scc
WHERE persona_id IS NOT NULL;
"""
)
)
# Rename the table back
op.rename_table(
"slack_channel_config__standard_answer_category",
"slack_bot_config__standard_answer_category",
)
# Rename the column back
op.alter_column(
"slack_bot_config__standard_answer_category",
"slack_channel_config_id",
new_column_name="slack_bot_config_id",
)
# Try to save the first bot's tokens back to KV store
try:
first_bot = (
op.get_bind()
.execute(
sa.text(
"SELECT bot_token, app_token FROM slack_bot ORDER BY id LIMIT 1"
)
)
.first()
)
if first_bot and first_bot.bot_token and first_bot.app_token:
tokens = {
"bot_token": first_bot.bot_token,
"app_token": first_bot.app_token,
}
get_kv_store().store("slack_bot_tokens_config_key", tokens)
except Exception:
logger.warning("Failed to save tokens back to KV store")
# Drop the new tables in reverse order
op.drop_table("slack_channel_config")
op.drop_table("slack_bot")

View File

@@ -1,70 +0,0 @@
"""nullable search settings for historic index attempts
Revision ID: 5b29123cd710
Revises: 949b4a92a401
Create Date: 2024-10-30 19:37:59.630704
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "5b29123cd710"
down_revision = "949b4a92a401"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Drop the existing foreign key constraint
op.drop_constraint(
"fk_index_attempt_search_settings", "index_attempt", type_="foreignkey"
)
# Modify the column to be nullable
op.alter_column(
"index_attempt", "search_settings_id", existing_type=sa.INTEGER(), nullable=True
)
# Add back the foreign key with ON DELETE SET NULL
op.create_foreign_key(
"fk_index_attempt_search_settings",
"index_attempt",
"search_settings",
["search_settings_id"],
["id"],
ondelete="SET NULL",
)
def downgrade() -> None:
# Warning: This will delete all index attempts that don't have search settings
op.execute(
"""
DELETE FROM index_attempt
WHERE search_settings_id IS NULL
"""
)
# Drop foreign key constraint
op.drop_constraint(
"fk_index_attempt_search_settings", "index_attempt", type_="foreignkey"
)
# Modify the column to be not nullable
op.alter_column(
"index_attempt",
"search_settings_id",
existing_type=sa.INTEGER(),
nullable=False,
)
# Add back the foreign key without ON DELETE SET NULL
op.create_foreign_key(
"fk_index_attempt_search_settings",
"index_attempt",
"search_settings",
["search_settings_id"],
["id"],
)

View File

@@ -1,9 +1,7 @@
"""Migrate chat_session and chat_message tables to use UUID primary keys
"""
Revision ID: 6756efa39ada
Revises: 5d12a446f5c0
Create Date: 2024-10-15 17:47:44.108537
"""
from alembic import op
import sqlalchemy as sa
@@ -14,6 +12,8 @@ branch_labels = None
depends_on = None
"""
Migrate chat_session and chat_message tables to use UUID primary keys.
This script:
1. Adds UUID columns to chat_session and chat_message
2. Populates new columns with UUIDs

View File

@@ -1,45 +0,0 @@
"""remove default bot
Revision ID: 6d562f86c78b
Revises: 177de57c21c9
Create Date: 2024-11-22 11:51:29.331336
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "6d562f86c78b"
down_revision = "177de57c21c9"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.execute(
sa.text(
"""
DELETE FROM slack_bot
WHERE name = 'Default Bot'
AND bot_token = ''
AND app_token = ''
AND NOT EXISTS (
SELECT 1 FROM slack_channel_config
WHERE slack_channel_config.slack_bot_id = slack_bot.id
)
"""
)
)
def downgrade() -> None:
op.execute(
sa.text(
"""
INSERT INTO slack_bot (name, enabled, bot_token, app_token)
SELECT 'Default Bot', true, '', ''
WHERE NOT EXISTS (SELECT 1 FROM slack_bot)
RETURNING id;
"""
)
)

View File

@@ -9,8 +9,8 @@ from alembic import op
import sqlalchemy as sa
from danswer.db.models import IndexModelStatus
from danswer.context.search.enums import RecencyBiasSetting
from danswer.context.search.enums import SearchType
from danswer.search.enums import RecencyBiasSetting
from danswer.search.enums import SearchType
# revision identifiers, used by Alembic.
revision = "776b3bbe9092"

View File

@@ -1,35 +0,0 @@
"""add web ui option to slack config
Revision ID: 93560ba1b118
Revises: 6d562f86c78b
Create Date: 2024-11-24 06:36:17.490612
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "93560ba1b118"
down_revision = "6d562f86c78b"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Add show_continue_in_web_ui with default False to all existing channel_configs
op.execute(
"""
UPDATE slack_channel_config
SET channel_config = channel_config || '{"show_continue_in_web_ui": false}'::jsonb
WHERE NOT channel_config ? 'show_continue_in_web_ui'
"""
)
def downgrade() -> None:
# Remove show_continue_in_web_ui from all channel_configs
op.execute(
"""
UPDATE slack_channel_config
SET channel_config = channel_config - 'show_continue_in_web_ui'
"""
)

View File

@@ -1,72 +0,0 @@
"""remove rt
Revision ID: 949b4a92a401
Revises: 1b10e1fda030
Create Date: 2024-10-26 13:06:06.937969
"""
from alembic import op
from sqlalchemy.orm import Session
from sqlalchemy import text
# Import your models and constants
from danswer.db.models import (
Connector,
ConnectorCredentialPair,
Credential,
IndexAttempt,
)
# revision identifiers, used by Alembic.
revision = "949b4a92a401"
down_revision = "1b10e1fda030"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Deletes all RequestTracker connectors and associated data
bind = op.get_bind()
session = Session(bind=bind)
# Get connectors using raw SQL
result = bind.execute(
text("SELECT id FROM connector WHERE source = 'requesttracker'")
)
connector_ids = [row[0] for row in result]
if connector_ids:
cc_pairs_to_delete = (
session.query(ConnectorCredentialPair)
.filter(ConnectorCredentialPair.connector_id.in_(connector_ids))
.all()
)
cc_pair_ids = [cc_pair.id for cc_pair in cc_pairs_to_delete]
if cc_pair_ids:
session.query(IndexAttempt).filter(
IndexAttempt.connector_credential_pair_id.in_(cc_pair_ids)
).delete(synchronize_session=False)
session.query(ConnectorCredentialPair).filter(
ConnectorCredentialPair.id.in_(cc_pair_ids)
).delete(synchronize_session=False)
credential_ids = [cc_pair.credential_id for cc_pair in cc_pairs_to_delete]
if credential_ids:
session.query(Credential).filter(Credential.id.in_(credential_ids)).delete(
synchronize_session=False
)
session.query(Connector).filter(Connector.id.in_(connector_ids)).delete(
synchronize_session=False
)
session.commit()
def downgrade() -> None:
# No-op downgrade as we cannot restore deleted data
pass

View File

@@ -1,30 +0,0 @@
"""add creator to cc pair
Revision ID: 9cf5c00f72fe
Revises: 26b931506ecb
Create Date: 2024-11-12 15:16:42.682902
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "9cf5c00f72fe"
down_revision = "26b931506ecb"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"connector_credential_pair",
sa.Column(
"creator_id",
sa.UUID(as_uuid=True),
nullable=True,
),
)
def downgrade() -> None:
op.drop_column("connector_credential_pair", "creator_id")

View File

@@ -1,36 +0,0 @@
"""Combine Search and Chat
Revision ID: 9f696734098f
Revises: a8c2065484e6
Create Date: 2024-11-27 15:32:19.694972
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "9f696734098f"
down_revision = "a8c2065484e6"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.alter_column("chat_session", "description", nullable=True)
op.drop_column("chat_session", "one_shot")
op.drop_column("slack_channel_config", "response_type")
def downgrade() -> None:
op.execute("UPDATE chat_session SET description = '' WHERE description IS NULL")
op.alter_column("chat_session", "description", nullable=False)
op.add_column(
"chat_session",
sa.Column("one_shot", sa.Boolean(), nullable=False, server_default=sa.false()),
)
op.add_column(
"slack_channel_config",
sa.Column(
"response_type", sa.String(), nullable=False, server_default="citations"
),
)

View File

@@ -1,27 +0,0 @@
"""add auto scroll to user model
Revision ID: a8c2065484e6
Revises: abe7378b8217
Create Date: 2024-11-22 17:34:09.690295
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "a8c2065484e6"
down_revision = "abe7378b8217"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"user",
sa.Column("auto_scroll", sa.Boolean(), nullable=True, server_default=None),
)
def downgrade() -> None:
op.drop_column("user", "auto_scroll")

View File

@@ -1,30 +0,0 @@
"""add indexing trigger to cc_pair
Revision ID: abe7378b8217
Revises: 6d562f86c78b
Create Date: 2024-11-26 19:09:53.481171
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "abe7378b8217"
down_revision = "93560ba1b118"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"connector_credential_pair",
sa.Column(
"indexing_trigger",
sa.Enum("UPDATE", "REINDEX", name="indexingmode", native_enum=False),
nullable=True,
),
)
def downgrade() -> None:
op.drop_column("connector_credential_pair", "indexing_trigger")

View File

@@ -31,12 +31,6 @@ def upgrade() -> None:
def downgrade() -> None:
# First, update any null values to a default value
op.execute(
"UPDATE connector_credential_pair SET last_attempt_status = 'NOT_STARTED' WHERE last_attempt_status IS NULL"
)
# Then, make the column non-nullable
op.alter_column(
"connector_credential_pair",
"last_attempt_status",

View File

@@ -288,15 +288,6 @@ def upgrade() -> None:
def downgrade() -> None:
# NOTE: you will lose all chat history. This is to satisfy the non-nullable constraints
# below
op.execute("DELETE FROM chat_feedback")
op.execute("DELETE FROM chat_message__search_doc")
op.execute("DELETE FROM document_retrieval_feedback")
op.execute("DELETE FROM document_retrieval_feedback")
op.execute("DELETE FROM chat_message")
op.execute("DELETE FROM chat_session")
op.drop_constraint(
"chat_feedback__chat_message_fk", "chat_feedback", type_="foreignkey"
)

View File

@@ -1,48 +0,0 @@
"""remove description from starter messages
Revision ID: b72ed7a5db0e
Revises: 33cb72ea4d80
Create Date: 2024-11-03 15:55:28.944408
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "b72ed7a5db0e"
down_revision = "33cb72ea4d80"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.execute(
sa.text(
"""
UPDATE persona
SET starter_messages = (
SELECT jsonb_agg(elem - 'description')
FROM jsonb_array_elements(starter_messages) elem
)
WHERE starter_messages IS NOT NULL
AND jsonb_typeof(starter_messages) = 'array'
"""
)
)
def downgrade() -> None:
op.execute(
sa.text(
"""
UPDATE persona
SET starter_messages = (
SELECT jsonb_agg(elem || '{"description": ""}')
FROM jsonb_array_elements(starter_messages) elem
)
WHERE starter_messages IS NOT NULL
AND jsonb_typeof(starter_messages) = 'array'
"""
)
)

View File

@@ -1,29 +0,0 @@
"""add recent assistants
Revision ID: c0fd6e4da83a
Revises: b72ed7a5db0e
Create Date: 2024-11-03 17:28:54.916618
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "c0fd6e4da83a"
down_revision = "b72ed7a5db0e"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"user",
sa.Column(
"recent_assistants", postgresql.JSONB(), server_default="[]", nullable=False
),
)
def downgrade() -> None:
op.drop_column("user", "recent_assistants")

View File

@@ -23,56 +23,6 @@ def upgrade() -> None:
def downgrade() -> None:
# Delete chat messages and feedback first since they reference chat sessions
# Get chat messages from sessions with null persona_id
chat_messages_query = """
SELECT id
FROM chat_message
WHERE chat_session_id IN (
SELECT id
FROM chat_session
WHERE persona_id IS NULL
)
"""
# Delete dependent records first
op.execute(
f"""
DELETE FROM document_retrieval_feedback
WHERE chat_message_id IN (
{chat_messages_query}
)
"""
)
op.execute(
f"""
DELETE FROM chat_message__search_doc
WHERE chat_message_id IN (
{chat_messages_query}
)
"""
)
# Delete chat messages
op.execute(
"""
DELETE FROM chat_message
WHERE chat_session_id IN (
SELECT id
FROM chat_session
WHERE persona_id IS NULL
)
"""
)
# Now we can safely delete the chat sessions
op.execute(
"""
DELETE FROM chat_session
WHERE persona_id IS NULL
"""
)
op.alter_column(
"chat_session",
"persona_id",

View File

@@ -1,42 +0,0 @@
"""extended_role_for_non_web
Revision ID: dfbe9e93d3c7
Revises: 9cf5c00f72fe
Create Date: 2024-11-16 07:54:18.727906
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "dfbe9e93d3c7"
down_revision = "9cf5c00f72fe"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.execute(
"""
UPDATE "user"
SET role = 'EXT_PERM_USER'
WHERE has_web_login = false
"""
)
op.drop_column("user", "has_web_login")
def downgrade() -> None:
op.add_column(
"user",
sa.Column("has_web_login", sa.Boolean(), nullable=False, server_default="true"),
)
op.execute(
"""
UPDATE "user"
SET has_web_login = false,
role = 'BASIC'
WHERE role IN ('SLACK_USER', 'EXT_PERM_USER')
"""
)

View File

@@ -1,6 +1,5 @@
import asyncio
from logging.config import fileConfig
from typing import Literal
from sqlalchemy import pool
from sqlalchemy.engine import Connection
@@ -38,15 +37,8 @@ EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
def include_object(
object: SchemaItem,
name: str | None,
type_: Literal[
"schema",
"table",
"column",
"index",
"unique_constraint",
"foreign_key_constraint",
],
name: str,
type_: str,
reflected: bool,
compare_to: SchemaItem | None,
) -> bool:

View File

@@ -1,3 +1,3 @@
import os
__version__ = os.environ.get("DANSWER_VERSION", "") or "Development"
__version__ = os.environ.get("DANSWER_VERSION", "") or "0.3-dev"

View File

@@ -16,46 +16,6 @@ class ExternalAccess:
is_public: bool
@dataclass(frozen=True)
class DocExternalAccess:
"""
This is just a class to wrap the external access and the document ID
together. It's used for syncing document permissions to Redis.
"""
external_access: ExternalAccess
# The document ID
doc_id: str
def to_dict(self) -> dict:
return {
"external_access": {
"external_user_emails": list(self.external_access.external_user_emails),
"external_user_group_ids": list(
self.external_access.external_user_group_ids
),
"is_public": self.external_access.is_public,
},
"doc_id": self.doc_id,
}
@classmethod
def from_dict(cls, data: dict) -> "DocExternalAccess":
external_access = ExternalAccess(
external_user_emails=set(
data["external_access"].get("external_user_emails", [])
),
external_user_group_ids=set(
data["external_access"].get("external_user_group_ids", [])
),
is_public=data["external_access"]["is_public"],
)
return cls(
external_access=external_access,
doc_id=data["doc_id"],
)
@dataclass(frozen=True)
class DocumentAccess(ExternalAccess):
# User emails for Danswer users, None indicates admin
@@ -110,12 +70,3 @@ class DocumentAccess(ExternalAccess):
user_groups=set(user_groups),
is_public=is_public,
)
default_public_access = DocumentAccess(
external_user_emails=set(),
external_user_group_ids=set(),
user_emails=set(),
user_groups=set(),
is_public=True,
)

View File

@@ -1,42 +0,0 @@
from collections.abc import Hashable
from typing import Union
from langgraph.types import Send
from danswer.agent_search.core_qa_graph.states import BaseQAState
from danswer.agent_search.primary_graph.states import RetrieverState
from danswer.agent_search.primary_graph.states import VerifierState
def sub_continue_to_verifier(state: BaseQAState) -> Union[Hashable, list[Hashable]]:
# Routes each de-douped retrieved doc to the verifier step - in parallel
# Notice the 'Send()' API that takes care of the parallelization
return [
Send(
"sub_verifier",
VerifierState(
document=doc,
#question=state["original_question"],
question=state["sub_question_str"],
graph_start_time=state["graph_start_time"],
),
)
for doc in state["sub_question_deduped_retrieval_docs"]
]
def sub_continue_to_retrieval(state: BaseQAState) -> Union[Hashable, list[Hashable]]:
# Routes re-written queries to the (parallel) retrieval steps
# Notice the 'Send()' API that takes care of the parallelization
rewritten_queries = state["sub_question_search_queries"].rewritten_queries + [state["sub_question_str"]]
return [
Send(
"sub_custom_retrieve",
RetrieverState(
rewritten_query=query,
graph_start_time=state["graph_start_time"],
),
)
for query in rewritten_queries
]

View File

@@ -1,132 +0,0 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from danswer.agent_search.core_qa_graph.edges import sub_continue_to_retrieval
from danswer.agent_search.core_qa_graph.edges import sub_continue_to_verifier
from danswer.agent_search.core_qa_graph.nodes.combine_retrieved_docs import (
sub_combine_retrieved_docs,
)
from danswer.agent_search.core_qa_graph.nodes.custom_retrieve import (
sub_custom_retrieve,
)
from danswer.agent_search.core_qa_graph.nodes.dummy import sub_dummy
from danswer.agent_search.core_qa_graph.nodes.final_format import (
sub_final_format,
)
from danswer.agent_search.core_qa_graph.nodes.generate import sub_generate
from danswer.agent_search.core_qa_graph.nodes.qa_check import sub_qa_check
from danswer.agent_search.core_qa_graph.nodes.rewrite import sub_rewrite
from danswer.agent_search.core_qa_graph.nodes.verifier import sub_verifier
from danswer.agent_search.core_qa_graph.states import BaseQAOutputState
from danswer.agent_search.core_qa_graph.states import BaseQAState
from danswer.agent_search.core_qa_graph.states import CoreQAInputState
def build_core_qa_graph() -> StateGraph:
sub_answers_initial = StateGraph(
state_schema=BaseQAState,
output=BaseQAOutputState,
)
### Add nodes ###
sub_answers_initial.add_node(node="sub_dummy", action=sub_dummy)
sub_answers_initial.add_node(node="sub_rewrite", action=sub_rewrite)
sub_answers_initial.add_node(
node="sub_custom_retrieve",
action=sub_custom_retrieve,
)
sub_answers_initial.add_node(
node="sub_combine_retrieved_docs",
action=sub_combine_retrieved_docs,
)
sub_answers_initial.add_node(
node="sub_verifier",
action=sub_verifier,
)
sub_answers_initial.add_node(
node="sub_generate",
action=sub_generate,
)
sub_answers_initial.add_node(
node="sub_qa_check",
action=sub_qa_check,
)
sub_answers_initial.add_node(
node="sub_final_format",
action=sub_final_format,
)
### Add edges ###
sub_answers_initial.add_edge(START, "sub_dummy")
sub_answers_initial.add_edge("sub_dummy", "sub_rewrite")
sub_answers_initial.add_conditional_edges(
source="sub_rewrite",
path=sub_continue_to_retrieval,
)
sub_answers_initial.add_edge(
start_key="sub_custom_retrieve",
end_key="sub_combine_retrieved_docs",
)
sub_answers_initial.add_conditional_edges(
source="sub_combine_retrieved_docs",
path=sub_continue_to_verifier,
path_map=["sub_verifier"],
)
sub_answers_initial.add_edge(
start_key="sub_verifier",
end_key="sub_generate",
)
sub_answers_initial.add_edge(
start_key="sub_generate",
end_key="sub_qa_check",
)
sub_answers_initial.add_edge(
start_key="sub_qa_check",
end_key="sub_final_format",
)
sub_answers_initial.add_edge(
start_key="sub_final_format",
end_key=END,
)
# sub_answers_graph = sub_answers_initial.compile()
return sub_answers_initial
if __name__ == "__main__":
# q = "Whose music is kind of hard to easily enjoy?"
# q = "What is voice leading?"
# q = "What are the types of motions in music?"
# q = "What are key elements of music theory?"
# q = "How can I best understand music theory using voice leading?"
q = "What makes good music?"
# q = "types of motions in music"
# q = "What is the relationship between music and physics?"
# q = "Can you compare various grunge styles?"
# q = "Why is quantum gravity so hard?"
inputs = CoreQAInputState(
original_question=q,
sub_question_str=q,
)
sub_answers_graph = build_core_qa_graph()
compiled_sub_answers = sub_answers_graph.compile()
output = compiled_sub_answers.invoke(inputs)
print("\nOUTPUT:")
print(output.keys())
for key, value in output.items():
if key in [
"sub_question_answer",
"sub_question_str",
"sub_qas",
"initial_sub_qas",
"sub_question_answer",
]:
print(f"{key}: {value}")

View File

@@ -1,36 +0,0 @@
from datetime import datetime
from typing import Any
from danswer.agent_search.core_qa_graph.states import BaseQAState
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
from danswer.context.search.models import InferenceSection
def sub_combine_retrieved_docs(state: BaseQAState) -> dict[str, Any]:
"""
Dedupe the retrieved docs.
"""
node_start_time = datetime.now()
sub_question_base_retrieval_docs = state["sub_question_base_retrieval_docs"]
print(f"Number of docs from steps: {len(sub_question_base_retrieval_docs)}")
dedupe_docs: list[InferenceSection] = []
for base_retrieval_doc in sub_question_base_retrieval_docs:
if not any(
base_retrieval_doc.center_chunk.chunk_id == doc.center_chunk.chunk_id
for doc in dedupe_docs
):
dedupe_docs.append(base_retrieval_doc)
print(f"Number of deduped docs: {len(dedupe_docs)}")
return {
"sub_question_deduped_retrieval_docs": dedupe_docs,
"log_messages": generate_log_message(
message="sub - combine_retrieved_docs (dedupe)",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -1,66 +0,0 @@
import datetime
from typing import Any
from danswer.agent_search.primary_graph.states import RetrieverState
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
from danswer.context.search.models import InferenceSection
from danswer.context.search.models import SearchRequest
from danswer.context.search.pipeline import SearchPipeline
from danswer.db.engine import get_session_context_manager
from danswer.llm.factory import get_default_llms
def sub_custom_retrieve(state: RetrieverState) -> dict[str, Any]:
"""
Retrieve documents
Args:
state (dict): The current graph state
Returns:
state (dict): New key added to state, documents, that contains retrieved documents
"""
print("---RETRIEVE SUB---")
node_start_time = datetime.datetime.now()
rewritten_query = state["rewritten_query"]
# Retrieval
# TODO: add the actual retrieval, probably from search_tool.run()
documents: list[InferenceSection] = []
llm, fast_llm = get_default_llms()
with get_session_context_manager() as db_session:
documents = SearchPipeline(
search_request=SearchRequest(
query=rewritten_query,
),
user=None,
llm=llm,
fast_llm=fast_llm,
db_session=db_session,
)
reranked_docs = documents.reranked_sections
# initial metric to measure fit TODO: implement metric properly
top_1_score = reranked_docs[0].center_chunk.score
top_5_score = sum([doc.center_chunk.score for doc in reranked_docs[:5]]) / 5
top_10_score = sum([doc.center_chunk.score for doc in reranked_docs[:10]]) / 10
fit_score = 1/3 * (top_1_score + top_5_score + top_10_score)
chunk_ids = {'query': rewritten_query,
'chunk_ids': [doc.center_chunk.chunk_id for doc in reranked_docs]}
return {
"sub_question_base_retrieval_docs": reranked_docs,
"sub_chunk_ids": [chunk_ids],
"log_messages": generate_log_message(
message=f"sub - custom_retrieve, fit_score: {fit_score}",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -1,24 +0,0 @@
import datetime
from typing import Any
from danswer.agent_search.core_qa_graph.states import BaseQAState
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
def sub_dummy(state: BaseQAState) -> dict[str, Any]:
"""
Dummy step
"""
print("---Sub Dummy---")
node_start_time = datetime.datetime.now()
return {
"graph_start_time": node_start_time,
"log_messages": generate_log_message(
message="sub - dummy",
node_start_time=node_start_time,
graph_start_time=node_start_time,
),
}

View File

@@ -1,22 +0,0 @@
from typing import Any
from danswer.agent_search.core_qa_graph.states import BaseQAState
def sub_final_format(state: BaseQAState) -> dict[str, Any]:
"""
Create the final output for the QA subgraph
"""
print("---BASE FINAL FORMAT---")
return {
"sub_qas": [
{
"sub_question": state["sub_question_str"],
"sub_answer": state["sub_question_answer"],
"sub_answer_check": state["sub_question_answer_check"],
}
],
"log_messages": state["log_messages"],
}

View File

@@ -1,91 +0,0 @@
from datetime import datetime
from typing import Any
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs
from danswer.agent_search.core_qa_graph.states import BaseQAState
from danswer.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT
from danswer.agent_search.shared_graph_utils.utils import format_docs
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
from danswer.llm.factory import get_default_llms
def sub_generate(state: BaseQAState) -> dict[str, Any]:
"""
Generate answer
Args:
state (messages): The current state
Returns:
dict: The updated state with re-phrased question
"""
print("---GENERATE---")
# Create sub-query results
verified_chunks = [chunk.center_chunk.chunk_id for chunk in state["sub_question_verified_retrieval_docs"]]
result_dict = {}
chunk_id_dicts = state["sub_chunk_ids"]
expanded_chunks = []
original_chunks = []
for chunk_id_dict in chunk_id_dicts:
sub_question = chunk_id_dict['query']
verified_sq_chunks = [chunk_id for chunk_id in chunk_id_dict['chunk_ids'] if chunk_id in verified_chunks]
if sub_question != state["original_question"]:
expanded_chunks += verified_sq_chunks
else:
result_dict['ORIGINAL'] = len(verified_sq_chunks)
original_chunks += verified_sq_chunks
result_dict[sub_question[:30]] = len(verified_sq_chunks)
expansion_chunks = set(expanded_chunks)
num_expansion_chunks = sum([1 for chunk_id in expansion_chunks if chunk_id in verified_chunks])
num_original_relevant_chunks = len(original_chunks)
num_missed_relevant_chunks = sum([1 for chunk_id in original_chunks if chunk_id not in expansion_chunks])
num_gained_relevant_chunks = sum([1 for chunk_id in expansion_chunks if chunk_id not in original_chunks])
result_dict['expansion_chunks'] = num_expansion_chunks
print(result_dict)
node_start_time = datetime.now()
question = state["sub_question_str"]
docs = state["sub_question_verified_retrieval_docs"]
print(f"Number of verified retrieval docs: {len(docs)}")
# Only take the top 10 docs.
# TODO: Make this dynamic or use config param?
top_10_docs = docs[-10:]
msg = [
HumanMessage(
content=BASE_RAG_PROMPT.format(question=question, context=format_docs(top_10_docs))
)
]
# Grader
_, fast_llm = get_default_llms()
response = list(
fast_llm.stream(
prompt=msg,
# structured_response_format=None,
)
)
answer_str = merge_message_runs(response, chunk_separator="")[0].content
return {
"sub_question_answer": answer_str,
"log_messages": generate_log_message(
message="base - generate",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -1,51 +0,0 @@
import datetime
from typing import Any
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs
from danswer.agent_search.core_qa_graph.states import BaseQAState
from danswer.agent_search.shared_graph_utils.prompts import BASE_CHECK_PROMPT
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
from danswer.llm.factory import get_default_llms
def sub_qa_check(state: BaseQAState) -> dict[str, Any]:
"""
Check if the sub-question answer is satisfactory.
Args:
state: The current SubQAState containing the sub-question and its answer
Returns:
dict containing the check result and log message
"""
node_start_time = datetime.datetime.now()
msg = [
HumanMessage(
content=BASE_CHECK_PROMPT.format(
question=state["sub_question_str"],
base_answer=state["sub_question_answer"],
)
)
]
_, fast_llm = get_default_llms()
response = list(
fast_llm.stream(
prompt=msg,
# structured_response_format=None,
)
)
response_str = merge_message_runs(response, chunk_separator="")[0].content
return {
"sub_question_answer_check": response_str,
"base_answer_messages": generate_log_message(
message="sub - qa_check",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -1,74 +0,0 @@
import datetime
from typing import Any
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs
from danswer.agent_search.core_qa_graph.states import BaseQAState
from danswer.agent_search.shared_graph_utils.models import RewrittenQueries
from danswer.agent_search.shared_graph_utils.prompts import (
REWRITE_PROMPT_MULTI_ORIGINAL,
)
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
from danswer.llm.factory import get_default_llms
def sub_rewrite(state: BaseQAState) -> dict[str, Any]:
"""
Transform the initial question into more suitable search queries.
Args:
state (messages): The current state
Returns:
dict: The updated state with re-phrased question
"""
print("---SUB TRANSFORM QUERY---")
node_start_time = datetime.datetime.now()
# messages = state["base_answer_messages"]
question = state["sub_question_str"]
msg = [
HumanMessage(
content=REWRITE_PROMPT_MULTI_ORIGINAL.format(question=question),
)
]
"""
msg = [
HumanMessage(
content=REWRITE_PROMPT_MULTI.format(question=question),
)
]
"""
_, fast_llm = get_default_llms()
llm_response_list = list(
fast_llm.stream(
prompt=msg,
# structured_response_format={"type": "json_object", "schema": RewrittenQueries.model_json_schema()},
# structured_response_format=RewrittenQueries.model_json_schema(),
)
)
llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content
print(f"llm_response: {llm_response}")
rewritten_queries = llm_response.split("--")
# rewritten_queries = [llm_response.split("\n")[0]]
print(f"rewritten_queries: {rewritten_queries}")
rewritten_queries = RewrittenQueries(rewritten_queries=rewritten_queries)
return {
"sub_question_search_queries": rewritten_queries,
"log_messages": generate_log_message(
message="sub - rewrite",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -1,64 +0,0 @@
import datetime
from typing import Any
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs
from danswer.agent_search.primary_graph.states import VerifierState
from danswer.agent_search.shared_graph_utils.models import BinaryDecision
from danswer.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
from danswer.llm.factory import get_default_llms
def sub_verifier(state: VerifierState) -> dict[str, Any]:
"""
Check whether the document is relevant for the original user question
Args:
state (VerifierState): The current state
Returns:
dict: ict: The updated state with the final decision
"""
# print("---VERIFY QUTPUT---")
node_start_time = datetime.datetime.now()
question = state["question"]
document_content = state["document"].combined_content
msg = [
HumanMessage(
content=VERIFIER_PROMPT.format(
question=question, document_content=document_content
)
)
]
# Grader
llm, fast_llm = get_default_llms()
response = list(
llm.stream(
prompt=msg,
# structured_response_format=BinaryDecision.model_json_schema(),
)
)
response_string = merge_message_runs(response, chunk_separator="")[0].content
# Convert string response to proper dictionary format
decision_dict = {"decision": response_string.lower()}
formatted_response = BinaryDecision.model_validate(decision_dict)
print(f"Verification end time: {datetime.datetime.now()}")
return {
"sub_question_verified_retrieval_docs": [state["document"]]
if formatted_response.decision == "yes"
else [],
"log_messages": generate_log_message(
message=f"sub - verifier: {formatted_response.decision}",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -1,90 +0,0 @@
import operator
from collections.abc import Sequence
from datetime import datetime
from typing import Annotated
from typing import TypedDict
from langchain_core.messages import BaseMessage
from langgraph.graph.message import add_messages
from danswer.agent_search.shared_graph_utils.models import RewrittenQueries
from danswer.context.search.models import InferenceSection
from danswer.llm.interfaces import LLM
class SubQuestionRetrieverState(TypedDict):
# The state for the parallel Retrievers. They each need to see only one query
sub_question_rewritten_query: str
class SubQuestionVerifierState(TypedDict):
# The state for the parallel verification step. Each node execution need to see only one question/doc pair
sub_question_document: InferenceSection
sub_question: str
class CoreQAInputState(TypedDict):
sub_question_str: str
original_question: str
class BaseQAState(TypedDict):
# The 'core SubQuestion' state.
original_question: str
graph_start_time: datetime
# start time for parallel initial sub-questionn thread
sub_query_start_time: datetime
sub_question_rewritten_queries: list[str]
sub_question_str: str
sub_question_search_queries: RewrittenQueries
sub_question_nr: int
sub_chunk_ids: Annotated[Sequence[dict], operator.add]
sub_question_base_retrieval_docs: Annotated[
Sequence[InferenceSection], operator.add
]
sub_question_deduped_retrieval_docs: Annotated[
Sequence[InferenceSection], operator.add
]
sub_question_verified_retrieval_docs: Annotated[
Sequence[InferenceSection], operator.add
]
sub_question_reranked_retrieval_docs: Annotated[
Sequence[InferenceSection], operator.add
]
sub_question_top_chunks: Annotated[Sequence[dict], operator.add]
sub_question_answer: str
sub_question_answer_check: str
log_messages: Annotated[Sequence[BaseMessage], add_messages]
sub_qas: Annotated[Sequence[dict], operator.add]
# Answers sent back to core
initial_sub_qas: Annotated[Sequence[dict], operator.add]
primary_llm: LLM
fast_llm: LLM
class BaseQAOutputState(TypedDict):
# The 'SubQuestion' output state. Removes all the intermediate states
sub_question_rewritten_queries: list[str]
sub_question_str: str
sub_question_search_queries: list[str]
sub_question_nr: int
# Answers sent back to core
sub_qas: Annotated[Sequence[dict], operator.add]
# Answers sent back to core
initial_sub_qas: Annotated[Sequence[dict], operator.add]
sub_question_base_retrieval_docs: Annotated[
Sequence[InferenceSection], operator.add
]
sub_question_deduped_retrieval_docs: Annotated[
Sequence[InferenceSection], operator.add
]
sub_question_verified_retrieval_docs: Annotated[
Sequence[InferenceSection], operator.add
]
sub_question_reranked_retrieval_docs: Annotated[
Sequence[InferenceSection], operator.add
]
sub_question_top_chunks: Annotated[Sequence[dict], operator.add]
sub_question_answer: str
sub_question_answer_check: str
log_messages: Annotated[Sequence[BaseMessage], add_messages]

View File

@@ -1,46 +0,0 @@
from collections.abc import Hashable
from typing import Union
from langgraph.types import Send
from danswer.agent_search.deep_qa_graph.states import ResearchQAState
from danswer.agent_search.primary_graph.states import RetrieverState
from danswer.agent_search.primary_graph.states import VerifierState
def sub_continue_to_verifier(state: ResearchQAState) -> Union[Hashable, list[Hashable]]:
# Routes each de-douped retrieved doc to the verifier step - in parallel
# Notice the 'Send()' API that takes care of the parallelization
return [
Send(
"sub_verifier",
VerifierState(
document=doc,
question=state["sub_question"],
primary_llm=state["primary_llm"],
fast_llm=state["fast_llm"],
graph_start_time=state["graph_start_time"],
),
)
for doc in state["sub_question_base_retrieval_docs"]
]
def sub_continue_to_retrieval(
state: ResearchQAState,
) -> Union[Hashable, list[Hashable]]:
# Routes re-written queries to the (parallel) retrieval steps
# Notice the 'Send()' API that takes care of the parallelization
return [
Send(
"sub_custom_retrieve",
RetrieverState(
rewritten_query=query,
primary_llm=state["primary_llm"],
fast_llm=state["fast_llm"],
graph_start_time=state["graph_start_time"],
),
)
for query in state["sub_question_rewritten_queries"]
]

View File

@@ -1,93 +0,0 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from danswer.agent_search.deep_qa_graph.edges import sub_continue_to_retrieval
from danswer.agent_search.deep_qa_graph.edges import sub_continue_to_verifier
from danswer.agent_search.deep_qa_graph.nodes.combine_retrieved_docs import (
sub_combine_retrieved_docs,
)
from danswer.agent_search.deep_qa_graph.nodes.custom_retrieve import sub_custom_retrieve
from danswer.agent_search.deep_qa_graph.nodes.dummy import sub_dummy
from danswer.agent_search.deep_qa_graph.nodes.final_format import sub_final_format
from danswer.agent_search.deep_qa_graph.nodes.generate import sub_generate
from danswer.agent_search.deep_qa_graph.nodes.qa_check import sub_qa_check
from danswer.agent_search.deep_qa_graph.nodes.verifier import sub_verifier
from danswer.agent_search.deep_qa_graph.states import ResearchQAOutputState
from danswer.agent_search.deep_qa_graph.states import ResearchQAState
def build_deep_qa_graph() -> StateGraph:
# Define the nodes we will cycle between
sub_answers = StateGraph(state_schema=ResearchQAState, output=ResearchQAOutputState)
### Add Nodes ###
# Dummy node for initial processing
sub_answers.add_node(node="sub_dummy", action=sub_dummy)
# The retrieval step
sub_answers.add_node(node="sub_custom_retrieve", action=sub_custom_retrieve)
# The dedupe step
sub_answers.add_node(
node="sub_combine_retrieved_docs", action=sub_combine_retrieved_docs
)
# Verifying retrieved information
sub_answers.add_node(node="sub_verifier", action=sub_verifier)
# Generating the response
sub_answers.add_node(node="sub_generate", action=sub_generate)
# Checking the quality of the answer
sub_answers.add_node(node="sub_qa_check", action=sub_qa_check)
# Final formatting of the response
sub_answers.add_node(node="sub_final_format", action=sub_final_format)
### Add Edges ###
# Generate multiple sub-questions
sub_answers.add_edge(start_key=START, end_key="sub_rewrite")
# For each sub-question, perform a retrieval in parallel
sub_answers.add_conditional_edges(
source="sub_rewrite",
path=sub_continue_to_retrieval,
path_map=["sub_custom_retrieve"],
)
# Combine the retrieved docs for each sub-question from the parallel retrievals
sub_answers.add_edge(
start_key="sub_custom_retrieve", end_key="sub_combine_retrieved_docs"
)
# Go over all of the combined retrieved docs and verify them against the original question
sub_answers.add_conditional_edges(
source="sub_combine_retrieved_docs",
path=sub_continue_to_verifier,
path_map=["sub_verifier"],
)
# Generate an answer for each verified retrieved doc
sub_answers.add_edge(start_key="sub_verifier", end_key="sub_generate")
# Check the quality of the answer
sub_answers.add_edge(start_key="sub_generate", end_key="sub_qa_check")
sub_answers.add_edge(start_key="sub_qa_check", end_key="sub_final_format")
sub_answers.add_edge(start_key="sub_final_format", end_key=END)
return sub_answers
if __name__ == "__main__":
# TODO: add the actual question
inputs = {"sub_question": "Whose music is kind of hard to easily enjoy?"}
sub_answers_graph = build_deep_qa_graph()
compiled_sub_answers = sub_answers_graph.compile()
output = compiled_sub_answers.invoke(inputs)
print("\nOUTPUT:")
print(output)

View File

@@ -1,31 +0,0 @@
from datetime import datetime
from typing import Any
from danswer.agent_search.deep_qa_graph.states import ResearchQAState
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
def sub_combine_retrieved_docs(state: ResearchQAState) -> dict[str, Any]:
"""
Dedupe the retrieved docs.
"""
node_start_time = datetime.now()
sub_question_base_retrieval_docs = state["sub_question_base_retrieval_docs"]
print(f"Number of docs from steps: {len(sub_question_base_retrieval_docs)}")
dedupe_docs = []
for base_retrieval_doc in sub_question_base_retrieval_docs:
if base_retrieval_doc not in dedupe_docs:
dedupe_docs.append(base_retrieval_doc)
print(f"Number of deduped docs: {len(dedupe_docs)}")
return {
"sub_question_deduped_retrieval_docs": dedupe_docs,
"log_messages": generate_log_message(
message="sub - combine_retrieved_docs (dedupe)",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -1,33 +0,0 @@
from datetime import datetime
from typing import Any
from danswer.agent_search.primary_graph.states import RetrieverState
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
from danswer.context.search.models import InferenceSection
def sub_custom_retrieve(state: RetrieverState) -> dict[str, Any]:
"""
Retrieve documents
Args:
state (dict): The current graph state
Returns:
state (dict): New key added to state, documents, that contains retrieved documents
"""
print("---RETRIEVE SUB---")
node_start_time = datetime.now()
# Retrieval
# TODO: add the actual retrieval, probably from search_tool.run()
documents: list[InferenceSection] = []
return {
"sub_question_base_retrieval_docs": documents,
"log_messages": generate_log_message(
message="sub - custom_retrieve",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -1,21 +0,0 @@
from datetime import datetime
from typing import Any
from danswer.agent_search.core_qa_graph.states import BaseQAState
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
def sub_dummy(state: BaseQAState) -> dict[str, Any]:
"""
Dummy step
"""
print("---Sub Dummy---")
return {
"log_messages": generate_log_message(
message="sub - dummy",
node_start_time=datetime.now(),
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -1,31 +0,0 @@
from datetime import datetime
from typing import Any
from danswer.agent_search.deep_qa_graph.states import ResearchQAState
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
def sub_final_format(state: ResearchQAState) -> dict[str, Any]:
"""
Create the final output for the QA subgraph
"""
print("---SUB FINAL FORMAT---")
node_start_time = datetime.now()
return {
# TODO: Type this
"sub_qas": [
{
"sub_question": state["sub_question"],
"sub_answer": state["sub_question_answer"],
"sub_question_nr": state["sub_question_nr"],
"sub_answer_check": state["sub_question_answer_check"],
}
],
"log_messages": generate_log_message(
message="sub - final format",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -1,56 +0,0 @@
from datetime import datetime
from typing import Any
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs
from danswer.agent_search.deep_qa_graph.states import ResearchQAState
from danswer.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT
from danswer.agent_search.shared_graph_utils.utils import format_docs
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
def sub_generate(state: ResearchQAState) -> dict[str, Any]:
"""
Generate answer
Args:
state (messages): The current state
Returns:
dict: The updated state with re-phrased question
"""
print("---SUB GENERATE---")
node_start_time = datetime.now()
question = state["sub_question"]
docs = state["sub_question_verified_retrieval_docs"]
print(f"Number of verified retrieval docs for sub-question: {len(docs)}")
msg = [
HumanMessage(
content=BASE_RAG_PROMPT.format(question=question, context=format_docs(docs))
)
]
# Grader
if len(docs) > 0:
model = state["fast_llm"]
response = list(
model.stream(
prompt=msg,
)
)
response_str = merge_message_runs(response, chunk_separator="")[0].content
else:
response_str = ""
return {
"sub_question_answer": response_str,
"log_messages": generate_log_message(
message="sub - generate",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -1,57 +0,0 @@
import json
from datetime import datetime
from typing import Any
from langchain_core.messages import HumanMessage
from danswer.agent_search.deep_qa_graph.prompts import SUB_CHECK_PROMPT
from danswer.agent_search.deep_qa_graph.states import ResearchQAState
from danswer.agent_search.shared_graph_utils.models import BinaryDecision
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
def sub_qa_check(state: ResearchQAState) -> dict[str, Any]:
"""
Check whether the final output satisfies the original user question
Args:
state (messages): The current state
Returns:
dict: The updated state with the final decision
"""
print("---CHECK SUB QUTPUT---")
node_start_time = datetime.now()
sub_answer = state["sub_question_answer"]
sub_question = state["sub_question"]
msg = [
HumanMessage(
content=SUB_CHECK_PROMPT.format(
sub_question=sub_question, sub_answer=sub_answer
)
)
]
# Grader
model = state["fast_llm"]
response = list(
model.stream(
prompt=msg,
structured_response_format=BinaryDecision.model_json_schema(),
)
)
raw_response = json.loads(response[0].pretty_repr())
formatted_response = BinaryDecision.model_validate(raw_response)
return {
"sub_question_answer_check": formatted_response.decision,
"log_messages": generate_log_message(
message=f"sub - qa check: {formatted_response.decision}",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -1,64 +0,0 @@
import json
from datetime import datetime
from typing import Any
from langchain_core.messages import HumanMessage
from danswer.agent_search.deep_qa_graph.states import ResearchQAState
from danswer.agent_search.shared_graph_utils.models import RewrittenQueries
from danswer.agent_search.shared_graph_utils.prompts import REWRITE_PROMPT_MULTI
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
from danswer.llm.interfaces import LLM
def sub_rewrite(state: ResearchQAState) -> dict[str, Any]:
"""
Transform the initial question into more suitable search queries.
Args:
state (messages): The current state
Returns:
dict: The updated state with re-phrased question
"""
print("---SUB TRANSFORM QUERY---")
node_start_time = datetime.now()
question = state["sub_question"]
msg = [
HumanMessage(
content=REWRITE_PROMPT_MULTI.format(question=question),
)
]
fast_llm: LLM = state["fast_llm"]
llm_response = list(
fast_llm.stream(
prompt=msg,
structured_response_format=RewrittenQueries.model_json_schema(),
)
)
# Get the rewritten queries in a defined format
rewritten_queries: RewrittenQueries = json.loads(llm_response[0].pretty_repr())
print(f"rewritten_queries: {rewritten_queries}")
rewritten_queries = RewrittenQueries(
rewritten_queries=[
"music hard to listen to",
"Music that is not fun or pleasant",
]
)
print(f"hardcoded rewritten_queries: {rewritten_queries}")
return {
"sub_question_rewritten_queries": rewritten_queries,
"log_messages": generate_log_message(
message="sub - rewrite",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -1,59 +0,0 @@
import json
from datetime import datetime
from typing import Any
from langchain_core.messages import HumanMessage
from danswer.agent_search.primary_graph.states import VerifierState
from danswer.agent_search.shared_graph_utils.models import BinaryDecision
from danswer.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
def sub_verifier(state: VerifierState) -> dict[str, Any]:
"""
Check whether the document is relevant for the original user question
Args:
state (VerifierState): The current state
Returns:
dict: ict: The updated state with the final decision
"""
print("---SUB VERIFY QUTPUT---")
node_start_time = datetime.now()
question = state["question"]
document_content = state["document"].combined_content
msg = [
HumanMessage(
content=VERIFIER_PROMPT.format(
question=question, document_content=document_content
)
)
]
# Grader
model = state["fast_llm"]
response = list(
model.stream(
prompt=msg,
structured_response_format=BinaryDecision.model_json_schema(),
)
)
raw_response = json.loads(response[0].pretty_repr())
formatted_response = BinaryDecision.model_validate(raw_response)
return {
"deduped_retrieval_docs": [state["document"]]
if formatted_response.decision == "yes"
else [],
"log_messages": generate_log_message(
message=f"core - verifier: {formatted_response.decision}",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -1,13 +0,0 @@
SUB_CHECK_PROMPT = """ \n
Please check whether the suggested answer seems to address the original question.
Please only answer with 'yes' or 'no' \n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
Here is the proposed answer:
\n ------- \n
{base_answer}
\n ------- \n
Please answer with yes or no:"""

View File

@@ -1,64 +0,0 @@
import operator
from collections.abc import Sequence
from datetime import datetime
from typing import Annotated
from typing import TypedDict
from langchain_core.messages import BaseMessage
from langgraph.graph.message import add_messages
from danswer.context.search.models import InferenceSection
from danswer.llm.interfaces import LLM
class ResearchQAState(TypedDict):
# The 'core SubQuestion' state.
original_question: str
graph_start_time: datetime
sub_question_rewritten_queries: list[str]
sub_question: str
sub_question_nr: int
sub_question_base_retrieval_docs: Annotated[
Sequence[InferenceSection], operator.add
]
sub_question_deduped_retrieval_docs: Annotated[
Sequence[InferenceSection], operator.add
]
sub_question_verified_retrieval_docs: Annotated[
Sequence[InferenceSection], operator.add
]
sub_question_reranked_retrieval_docs: Annotated[
Sequence[InferenceSection], operator.add
]
sub_question_top_chunks: Annotated[Sequence[dict], operator.add]
sub_question_answer: str
sub_question_answer_check: str
log_messages: Annotated[Sequence[BaseMessage], add_messages]
sub_qas: Annotated[Sequence[dict], operator.add]
primary_llm: LLM
fast_llm: LLM
class ResearchQAOutputState(TypedDict):
# The 'SubQuestion' output state. Removes all the intermediate states
sub_question_rewritten_queries: list[str]
sub_question: str
sub_question_nr: int
# Answers sent back to core
sub_qas: Annotated[Sequence[dict], operator.add]
sub_question_base_retrieval_docs: Annotated[
Sequence[InferenceSection], operator.add
]
sub_question_deduped_retrieval_docs: Annotated[
Sequence[InferenceSection], operator.add
]
sub_question_verified_retrieval_docs: Annotated[
Sequence[InferenceSection], operator.add
]
sub_question_reranked_retrieval_docs: Annotated[
Sequence[InferenceSection], operator.add
]
sub_question_top_chunks: Annotated[Sequence[dict], operator.add]
sub_question_answer: str
sub_question_answer_check: str
log_messages: Annotated[Sequence[BaseMessage], add_messages]

View File

@@ -1,75 +0,0 @@
from collections.abc import Hashable
from typing import Union
from langchain_core.messages import HumanMessage
from langgraph.types import Send
from danswer.agent_search.core_qa_graph.states import BaseQAState
from danswer.agent_search.deep_qa_graph.states import ResearchQAState
from danswer.agent_search.primary_graph.states import QAState
from danswer.agent_search.shared_graph_utils.prompts import BASE_CHECK_PROMPT
def continue_to_initial_sub_questions(
state: QAState,
) -> Union[Hashable, list[Hashable]]:
# Routes re-written queries to the (parallel) retrieval steps
# Notice the 'Send()' API that takes care of the parallelization
return [
Send(
"sub_answers_graph_initial",
BaseQAState(
sub_question_str=initial_sub_question["sub_question_str"],
sub_question_search_queries=initial_sub_question[
"sub_question_search_queries"
],
sub_question_nr=initial_sub_question["sub_question_nr"],
primary_llm=state["primary_llm"],
fast_llm=state["fast_llm"],
graph_start_time=state["graph_start_time"],
),
)
for initial_sub_question in state["initial_sub_questions"]
]
def continue_to_answer_sub_questions(state: QAState) -> Union[Hashable, list[Hashable]]:
# Routes re-written queries to the (parallel) retrieval steps
# Notice the 'Send()' API that takes care of the parallelization
return [
Send(
"sub_answers_graph",
ResearchQAState(
sub_question=sub_question["sub_question_str"],
sub_question_nr=sub_question["sub_question_nr"],
graph_start_time=state["graph_start_time"],
primary_llm=state["primary_llm"],
fast_llm=state["fast_llm"],
),
)
for sub_question in state["sub_questions"]
]
def continue_to_deep_answer(state: QAState) -> Union[Hashable, list[Hashable]]:
print("---GO TO DEEP ANSWER OR END---")
base_answer = state["base_answer"]
question = state["original_question"]
BASE_CHECK_MESSAGE = [
HumanMessage(
content=BASE_CHECK_PROMPT.format(question=question, base_answer=base_answer)
)
]
model = state["fast_llm"]
response = model.invoke(BASE_CHECK_MESSAGE)
print(f"CAN WE CONTINUE W/O GENERATING A DEEP ANSWER? - {response.pretty_repr()}")
if response.pretty_repr() == "no":
return "decompose"
else:
return "end"

View File

@@ -1,171 +0,0 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from danswer.agent_search.core_qa_graph.graph_builder import build_core_qa_graph
from danswer.agent_search.deep_qa_graph.graph_builder import build_deep_qa_graph
from danswer.agent_search.primary_graph.edges import continue_to_answer_sub_questions
from danswer.agent_search.primary_graph.edges import continue_to_deep_answer
from danswer.agent_search.primary_graph.edges import continue_to_initial_sub_questions
from danswer.agent_search.primary_graph.nodes.base_wait import base_wait
from danswer.agent_search.primary_graph.nodes.combine_retrieved_docs import (
combine_retrieved_docs,
)
from danswer.agent_search.primary_graph.nodes.custom_retrieve import custom_retrieve
from danswer.agent_search.primary_graph.nodes.decompose import decompose
from danswer.agent_search.primary_graph.nodes.deep_answer_generation import (
deep_answer_generation,
)
from danswer.agent_search.primary_graph.nodes.dummy_start import dummy_start
from danswer.agent_search.primary_graph.nodes.entity_term_extraction import (
entity_term_extraction,
)
from danswer.agent_search.primary_graph.nodes.final_stuff import final_stuff
from danswer.agent_search.primary_graph.nodes.generate_initial import generate_initial
from danswer.agent_search.primary_graph.nodes.main_decomp_base import main_decomp_base
from danswer.agent_search.primary_graph.nodes.rewrite import rewrite
from danswer.agent_search.primary_graph.nodes.sub_qa_level_aggregator import (
sub_qa_level_aggregator,
)
from danswer.agent_search.primary_graph.nodes.sub_qa_manager import sub_qa_manager
from danswer.agent_search.primary_graph.nodes.verifier import verifier
from danswer.agent_search.primary_graph.states import QAState
def build_core_graph() -> StateGraph:
# Define the nodes we will cycle between
core_answer_graph = StateGraph(state_schema=QAState)
### Add Nodes ###
core_answer_graph.add_node(node="dummy_start",
action=dummy_start)
# Re-writing the question
core_answer_graph.add_node(node="rewrite",
action=rewrite)
# The retrieval step
core_answer_graph.add_node(node="custom_retrieve",
action=custom_retrieve)
# Combine and dedupe retrieved docs.
core_answer_graph.add_node(
node="combine_retrieved_docs",
action=combine_retrieved_docs
)
# Extract entities, terms and relationships
core_answer_graph.add_node(
node="entity_term_extraction",
action=entity_term_extraction
)
# Verifying that a retrieved doc is relevant
core_answer_graph.add_node(node="verifier",
action=verifier)
# Initial question decomposition
core_answer_graph.add_node(node="main_decomp_base",
action=main_decomp_base)
# Build the base QA sub-graph and compile it
compiled_core_qa_graph = build_core_qa_graph().compile()
# Add the compiled base QA sub-graph as a node to the core graph
core_answer_graph.add_node(
node="sub_answers_graph_initial",
action=compiled_core_qa_graph
)
# Checking whether the initial answer is in the ballpark
core_answer_graph.add_node(node="base_wait",
action=base_wait)
# Decompose the question into sub-questions
core_answer_graph.add_node(node="decompose",
action=decompose)
# Manage the sub-questions
core_answer_graph.add_node(node="sub_qa_manager",
action=sub_qa_manager)
# Build the research QA sub-graph and compile it
compiled_deep_qa_graph = build_deep_qa_graph().compile()
# Add the compiled research QA sub-graph as a node to the core graph
core_answer_graph.add_node(node="sub_answers_graph",
action=compiled_deep_qa_graph)
# Aggregate the sub-questions
core_answer_graph.add_node(
node="sub_qa_level_aggregator",
action=sub_qa_level_aggregator
)
# aggregate sub questions and answers
core_answer_graph.add_node(
node="deep_answer_generation",
action=deep_answer_generation
)
# A final clean-up step
core_answer_graph.add_node(node="final_stuff",
action=final_stuff)
# Generating a response after we know the documents are relevant
core_answer_graph.add_node(node="generate_initial",
action=generate_initial)
### Add Edges ###
# start the initial sub-question decomposition
core_answer_graph.add_edge(start_key=START,
end_key="main_decomp_base")
core_answer_graph.add_conditional_edges(
source="main_decomp_base",
path=continue_to_initial_sub_questions,
)
# use the retrieved information to generate the answer
core_answer_graph.add_edge(
start_key=["verifier", "sub_answers_graph_initial"],
end_key="generate_initial"
)
core_answer_graph.add_edge(start_key="generate_initial",
end_key="base_wait")
core_answer_graph.add_conditional_edges(
source="base_wait",
path=continue_to_deep_answer,
path_map={"decompose": "entity_term_extraction", "end": "final_stuff"},
)
core_answer_graph.add_edge(start_key="entity_term_extraction", end_key="decompose")
core_answer_graph.add_edge(start_key="decompose",
end_key="sub_qa_manager")
core_answer_graph.add_conditional_edges(
source="sub_qa_manager",
path=continue_to_answer_sub_questions,
)
core_answer_graph.add_edge(
start_key="sub_answers_graph",
end_key="sub_qa_level_aggregator"
)
core_answer_graph.add_edge(
start_key="sub_qa_level_aggregator",
end_key="deep_answer_generation"
)
core_answer_graph.add_edge(
start_key="deep_answer_generation",
end_key="final_stuff"
)
core_answer_graph.add_edge(start_key="final_stuff",
end_key=END)
core_answer_graph.compile()
return core_answer_graph

View File

@@ -1,27 +0,0 @@
from datetime import datetime
from typing import Any
from danswer.agent_search.primary_graph.states import QAState
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
def base_wait(state: QAState) -> dict[str, Any]:
"""
Ensures that all required steps are completed before proceeding to the next step
Args:
state (messages): The current state
Returns:
dict: {} (no operation, just logging)
"""
print("---Base Wait ---")
node_start_time = datetime.now()
return {
"log_messages": generate_log_message(
message="core - base_wait",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -1,36 +0,0 @@
from collections.abc import Sequence
from datetime import datetime
from typing import Any
from danswer.agent_search.primary_graph.states import QAState
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
from danswer.context.search.models import InferenceSection
def combine_retrieved_docs(state: QAState) -> dict[str, Any]:
"""
Dedupe the retrieved docs.
"""
node_start_time = datetime.now()
base_retrieval_docs: Sequence[InferenceSection] = state["base_retrieval_docs"]
print(f"Number of docs from steps: {len(base_retrieval_docs)}")
dedupe_docs: list[InferenceSection] = []
for base_retrieval_doc in base_retrieval_docs:
if not any(
base_retrieval_doc.center_chunk.document_id == doc.center_chunk.document_id
for doc in dedupe_docs
):
dedupe_docs.append(base_retrieval_doc)
print(f"Number of deduped docs: {len(dedupe_docs)}")
return {
"deduped_retrieval_docs": dedupe_docs,
"log_messages": generate_log_message(
message="core - combine_retrieved_docs (dedupe)",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -1,52 +0,0 @@
from datetime import datetime
from typing import Any
from danswer.agent_search.primary_graph.states import RetrieverState
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
from danswer.context.search.models import InferenceSection
from danswer.context.search.models import SearchRequest
from danswer.context.search.pipeline import SearchPipeline
from danswer.db.engine import get_session_context_manager
from danswer.llm.factory import get_default_llms
def custom_retrieve(state: RetrieverState) -> dict[str, Any]:
"""
Retrieve documents
Args:
retriever_state (dict): The current graph state
Returns:
state (dict): New key added to state, documents, that contains retrieved documents
"""
print("---RETRIEVE---")
node_start_time = datetime.now()
query = state["rewritten_query"]
# Retrieval
# TODO: add the actual retrieval, probably from search_tool.run()
llm, fast_llm = get_default_llms()
with get_session_context_manager() as db_session:
top_sections = SearchPipeline(
search_request=SearchRequest(
query=query,
),
user=None,
llm=llm,
fast_llm=fast_llm,
db_session=db_session,
).reranked_sections
print(len(top_sections))
documents: list[InferenceSection] = []
return {
"base_retrieval_docs": documents,
"log_messages": generate_log_message(
message="core - custom_retrieve",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -1,78 +0,0 @@
import json
import re
from datetime import datetime
from typing import Any
from langchain_core.messages import HumanMessage
from danswer.agent_search.primary_graph.states import QAState
from danswer.agent_search.shared_graph_utils.prompts import DEEP_DECOMPOSE_PROMPT
from danswer.agent_search.shared_graph_utils.utils import format_entity_term_extraction
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
def decompose(state: QAState) -> dict[str, Any]:
""" """
node_start_time = datetime.now()
question = state["original_question"]
base_answer = state["base_answer"]
# get the entity term extraction dict and properly format it
entity_term_extraction_dict = state["retrieved_entities_relationships"][
"retrieved_entities_relationships"
]
entity_term_extraction_str = format_entity_term_extraction(
entity_term_extraction_dict
)
initial_question_answers = state["initial_sub_qas"]
addressed_question_list = [
x["sub_question"]
for x in initial_question_answers
if x["sub_answer_check"] == "yes"
]
failed_question_list = [
x["sub_question"]
for x in initial_question_answers
if x["sub_answer_check"] == "no"
]
msg = [
HumanMessage(
content=DEEP_DECOMPOSE_PROMPT.format(
question=question,
entity_term_extraction_str=entity_term_extraction_str,
base_answer=base_answer,
answered_sub_questions="\n - ".join(addressed_question_list),
failed_sub_questions="\n - ".join(failed_question_list),
),
)
]
# Grader
model = state["fast_llm"]
response = model.invoke(msg)
cleaned_response = re.sub(r"```json\n|\n```", "", response.pretty_repr())
parsed_response = json.loads(cleaned_response)
sub_questions_dict = {}
for sub_question_nr, sub_question_dict in enumerate(
parsed_response["sub_questions"]
):
sub_question_dict["answered"] = False
sub_question_dict["verified"] = False
sub_questions_dict[sub_question_nr] = sub_question_dict
return {
"decomposed_sub_questions_dict": sub_questions_dict,
"log_messages": generate_log_message(
message="deep - decompose",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -1,61 +0,0 @@
from datetime import datetime
from typing import Any
from langchain_core.messages import HumanMessage
from danswer.agent_search.primary_graph.states import QAState
from danswer.agent_search.shared_graph_utils.prompts import COMBINED_CONTEXT
from danswer.agent_search.shared_graph_utils.prompts import MODIFIED_RAG_PROMPT
from danswer.agent_search.shared_graph_utils.utils import format_docs
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
from danswer.agent_search.shared_graph_utils.utils import normalize_whitespace
# aggregate sub questions and answers
def deep_answer_generation(state: QAState) -> dict[str, Any]:
"""
Generate answer
Args:
state (messages): The current state
Returns:
dict: The updated state with re-phrased question
"""
print("---DEEP GENERATE---")
node_start_time = datetime.now()
question = state["original_question"]
docs = state["deduped_retrieval_docs"]
deep_answer_context = state["core_answer_dynamic_context"]
print(f"Number of verified retrieval docs - deep: {len(docs)}")
combined_context = normalize_whitespace(
COMBINED_CONTEXT.format(
deep_answer_context=deep_answer_context, formated_docs=format_docs(docs)
)
)
msg = [
HumanMessage(
content=MODIFIED_RAG_PROMPT.format(
question=question, combined_context=combined_context
)
)
]
# Grader
model = state["fast_llm"]
response = model.invoke(msg)
return {
"deep_answer": response.content,
"log_messages": generate_log_message(
message="deep - deep answer generation",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -1,11 +0,0 @@
from datetime import datetime
from typing import Any
from danswer.agent_search.primary_graph.states import QAState
def dummy_start(state: QAState) -> dict[str, Any]:
"""
Dummy node to set the start time
"""
return {"start_time": datetime.now()}

View File

@@ -1,51 +0,0 @@
import json
import re
from datetime import datetime
from typing import Any
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs
from danswer.agent_search.primary_graph.prompts import ENTITY_TERM_PROMPT
from danswer.agent_search.primary_graph.states import QAState
from danswer.agent_search.shared_graph_utils.utils import format_docs
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
from danswer.llm.factory import get_default_llms
def entity_term_extraction(state: QAState) -> dict[str, Any]:
"""Extract entities and terms from the question and context"""
node_start_time = datetime.now()
question = state["original_question"]
docs = state["deduped_retrieval_docs"]
doc_context = format_docs(docs)
msg = [
HumanMessage(
content=ENTITY_TERM_PROMPT.format(question=question, context=doc_context),
)
]
_, fast_llm = get_default_llms()
# Grader
llm_response_list = list(
fast_llm.stream(
prompt=msg,
# structured_response_format={"type": "json_object", "schema": RewrittenQueries.model_json_schema()},
# structured_response_format=RewrittenQueries.model_json_schema(),
)
)
llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content
cleaned_response = re.sub(r"```json\n|\n```", "", llm_response)
parsed_response = json.loads(cleaned_response)
return {
"retrieved_entities_relationships": parsed_response,
"log_messages": generate_log_message(
message="deep - entity term extraction",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -1,85 +0,0 @@
from datetime import datetime
from typing import Any
from danswer.agent_search.primary_graph.states import QAState
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
def final_stuff(state: QAState) -> dict[str, Any]:
"""
Invokes the agent model to generate a response based on the current state. Given
the question, it will decide to retrieve using the retriever tool, or simply end.
Args:
state (messages): The current state
Returns:
dict: The updated state with the agent response appended to messages
"""
print("---FINAL---")
node_start_time = datetime.now()
messages = state["log_messages"]
time_ordered_messages = [x.pretty_repr() for x in messages]
time_ordered_messages.sort()
print("Message Log:")
print("\n".join(time_ordered_messages))
initial_sub_qas = state["initial_sub_qas"]
initial_sub_qa_list = []
for initial_sub_qa in initial_sub_qas:
if initial_sub_qa["sub_answer_check"] == "yes":
initial_sub_qa_list.append(
f' Question:\n {initial_sub_qa["sub_question"]}\n --\n Answer:\n {initial_sub_qa["sub_answer"]}\n -----'
)
initial_sub_qa_context = "\n".join(initial_sub_qa_list)
log_message = generate_log_message(
message="all - final_stuff",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
)
print(log_message)
print("--------------------------------")
base_answer = state["base_answer"]
print(f"Final Base Answer:\n{base_answer}")
print("--------------------------------")
print(f"Initial Answered Sub Questions:\n{initial_sub_qa_context}")
print("--------------------------------")
if not state.get("deep_answer"):
print("No Deep Answer was required")
return {
"log_messages": log_message,
}
deep_answer = state["deep_answer"]
sub_qas = state["sub_qas"]
sub_qa_list = []
for sub_qa in sub_qas:
if sub_qa["sub_answer_check"] == "yes":
sub_qa_list.append(
f' Question:\n {sub_qa["sub_question"]}\n --\n Answer:\n {sub_qa["sub_answer"]}\n -----'
)
sub_qa_context = "\n".join(sub_qa_list)
print(f"Final Base Answer:\n{base_answer}")
print("--------------------------------")
print(f"Final Deep Answer:\n{deep_answer}")
print("--------------------------------")
print("Sub Questions and Answers:")
print(sub_qa_context)
return {
"log_messages": generate_log_message(
message="all - final_stuff",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -1,52 +0,0 @@
from datetime import datetime
from typing import Any
from langchain_core.messages import HumanMessage
from danswer.agent_search.primary_graph.states import QAState
from danswer.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT
from danswer.agent_search.shared_graph_utils.utils import format_docs
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
def generate(state: QAState) -> dict[str, Any]:
"""
Generate answer
Args:
state (messages): The current state
Returns:
dict: The updated state with re-phrased question
"""
print("---GENERATE---")
node_start_time = datetime.now()
question = state["original_question"]
docs = state["deduped_retrieval_docs"]
print(f"Number of verified retrieval docs: {len(docs)}")
msg = [
HumanMessage(
content=BASE_RAG_PROMPT.format(question=question, context=format_docs(docs))
)
]
# Grader
llm = state["fast_llm"]
response = list(
llm.stream(
prompt=msg,
structured_response_format=None,
)
)
return {
"base_answer": response[0].pretty_repr(),
"log_messages": generate_log_message(
message="core - generate",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -1,72 +0,0 @@
from datetime import datetime
from typing import Any
from langchain_core.messages import HumanMessage
from danswer.agent_search.primary_graph.prompts import INITIAL_RAG_PROMPT
from danswer.agent_search.primary_graph.states import QAState
from danswer.agent_search.shared_graph_utils.utils import format_docs
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
def generate_initial(state: QAState) -> dict[str, Any]:
"""
Generate answer
Args:
state (messages): The current state
Returns:
dict: The updated state with re-phrased question
"""
print("---GENERATE INITIAL---")
node_start_time = datetime.now()
question = state["original_question"]
docs = state["deduped_retrieval_docs"]
print(f"Number of verified retrieval docs - base: {len(docs)}")
sub_question_answers = state["initial_sub_qas"]
sub_question_answers_list = []
_SUB_QUESTION_ANSWER_TEMPLATE = """
Sub-Question:\n - {sub_question}\n --\nAnswer:\n - {sub_answer}\n\n
"""
for sub_question_answer_dict in sub_question_answers:
if (
sub_question_answer_dict["sub_answer_check"] == "yes"
and len(sub_question_answer_dict["sub_answer"]) > 0
and sub_question_answer_dict["sub_answer"] != "I don't know"
):
sub_question_answers_list.append(
_SUB_QUESTION_ANSWER_TEMPLATE.format(
sub_question=sub_question_answer_dict["sub_question"],
sub_answer=sub_question_answer_dict["sub_answer"],
)
)
sub_question_answer_str = "\n\n------\n\n".join(sub_question_answers_list)
msg = [
HumanMessage(
content=INITIAL_RAG_PROMPT.format(
question=question,
context=format_docs(docs),
answered_sub_questions=sub_question_answer_str,
)
)
]
# Grader
model = state["fast_llm"]
response = model.invoke(msg)
return {
"base_answer": response.pretty_repr(),
"log_messages": generate_log_message(
message="core - generate initial",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -1,64 +0,0 @@
from datetime import datetime
from typing import Any
from langchain_core.messages import HumanMessage
from danswer.agent_search.primary_graph.prompts import INITIAL_DECOMPOSITION_PROMPT
from danswer.agent_search.primary_graph.states import QAState
from danswer.agent_search.shared_graph_utils.utils import clean_and_parse_list_string
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
def main_decomp_base(state: QAState) -> dict[str, Any]:
"""
Perform an initial question decomposition, incl. one search term
Args:
state (messages): The current state
Returns:
dict: The updated state with initial decomposition
"""
print("---INITIAL DECOMP---")
node_start_time = datetime.now()
question = state["original_question"]
msg = [
HumanMessage(
content=INITIAL_DECOMPOSITION_PROMPT.format(question=question),
)
]
# Get the rewritten queries in a defined format
model = state["fast_llm"]
response = model.invoke(msg)
content = response.pretty_repr()
list_of_subquestions = clean_and_parse_list_string(content)
decomp_list = []
for sub_question_nr, sub_question in enumerate(list_of_subquestions):
sub_question_str = sub_question["sub_question"].strip()
# temporarily
sub_question_search_queries = [sub_question["search_term"]]
decomp_list.append(
{
"sub_question_str": sub_question_str,
"sub_question_search_queries": sub_question_search_queries,
"sub_question_nr": sub_question_nr,
}
)
return {
"initial_sub_questions": decomp_list,
"sub_query_start_time": node_start_time,
"log_messages": generate_log_message(
message="core - initial decomp",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -1,55 +0,0 @@
import json
from datetime import datetime
from typing import Any
from langchain_core.messages import HumanMessage
from danswer.agent_search.primary_graph.states import QAState
from danswer.agent_search.shared_graph_utils.models import RewrittenQueries
from danswer.agent_search.shared_graph_utils.prompts import REWRITE_PROMPT_MULTI
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
def rewrite(state: QAState) -> dict[str, Any]:
"""
Transform the initial question into more suitable search queries.
Args:
qa_state (messages): The current state
Returns:
dict: The updated state with re-phrased question
"""
print("---STARTING GRAPH---")
graph_start_time = datetime.now()
print("---TRANSFORM QUERY---")
node_start_time = datetime.now()
question = state["original_question"]
msg = [
HumanMessage(
content=REWRITE_PROMPT_MULTI.format(question=question),
)
]
# Get the rewritten queries in a defined format
fast_llm = state["fast_llm"]
llm_response = list(
fast_llm.stream(
prompt=msg,
structured_response_format=RewrittenQueries.model_json_schema(),
)
)
formatted_response: RewrittenQueries = json.loads(llm_response[0].pretty_repr())
return {
"rewritten_queries": formatted_response.rewritten_queries,
"log_messages": generate_log_message(
message="core - rewrite",
node_start_time=node_start_time,
graph_start_time=graph_start_time,
),
}

View File

@@ -1,39 +0,0 @@
from datetime import datetime
from typing import Any
from danswer.agent_search.primary_graph.states import QAState
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
# aggregate sub questions and answers
def sub_qa_level_aggregator(state: QAState) -> dict[str, Any]:
sub_qas = state["sub_qas"]
node_start_time = datetime.now()
dynamic_context_list = [
"Below you will find useful information to answer the original question:"
]
checked_sub_qas = []
for core_answer_sub_qa in sub_qas:
question = core_answer_sub_qa["sub_question"]
answer = core_answer_sub_qa["sub_answer"]
verified = core_answer_sub_qa["sub_answer_check"]
if verified == "yes":
dynamic_context_list.append(
f"Question:\n{question}\n\nAnswer:\n{answer}\n\n---\n\n"
)
checked_sub_qas.append({"sub_question": question, "sub_answer": answer})
dynamic_context = "\n".join(dynamic_context_list)
return {
"core_answer_dynamic_context": dynamic_context,
"checked_sub_qas": checked_sub_qas,
"log_messages": generate_log_message(
message="deep - sub qa level aggregator",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -1,28 +0,0 @@
from datetime import datetime
from typing import Any
from danswer.agent_search.primary_graph.states import QAState
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
def sub_qa_manager(state: QAState) -> dict[str, Any]:
""" """
node_start_time = datetime.now()
sub_questions_dict = state["decomposed_sub_questions_dict"]
sub_questions = {}
for sub_question_nr, sub_question_dict in sub_questions_dict.items():
sub_questions[sub_question_nr] = sub_question_dict["sub_question"]
return {
"sub_questions": sub_questions,
"num_new_question_iterations": 0,
"log_messages": generate_log_message(
message="deep - sub qa manager",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -1,59 +0,0 @@
import json
from datetime import datetime
from typing import Any
from langchain_core.messages import HumanMessage
from danswer.agent_search.primary_graph.states import VerifierState
from danswer.agent_search.shared_graph_utils.models import BinaryDecision
from danswer.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
def verifier(state: VerifierState) -> dict[str, Any]:
"""
Check whether the document is relevant for the original user question
Args:
state (VerifierState): The current state
Returns:
dict: ict: The updated state with the final decision
"""
print("---VERIFY QUTPUT---")
node_start_time = datetime.now()
question = state["question"]
document_content = state["document"].combined_content
msg = [
HumanMessage(
content=VERIFIER_PROMPT.format(
question=question, document_content=document_content
)
)
]
# Grader
llm = state["fast_llm"]
response = list(
llm.stream(
prompt=msg,
structured_response_format=BinaryDecision.model_json_schema(),
)
)
raw_response = json.loads(response[0].pretty_repr())
formatted_response = BinaryDecision.model_validate(raw_response)
return {
"deduped_retrieval_docs": [state["document"]]
if formatted_response.decision == "yes"
else [],
"log_messages": generate_log_message(
message=f"core - verifier: {formatted_response.decision}",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -1,86 +0,0 @@
INITIAL_DECOMPOSITION_PROMPT = """ \n
Please decompose an initial user question into not more than 4 appropriate sub-questions that help to
answer the original question. The purpose for this decomposition is to isolate individulal entities
(i.e., 'compare sales of company A and company B' -> 'what are sales for company A' + 'what are sales
for company B'), split ambiguous terms (i.e., 'what is our success with company A' -> 'what are our
sales with company A' + 'what is our market share with company A' + 'is company A a reference customer
for us'), etc. Each sub-question should be realistically be answerable by a good RAG system. \n
For each sub-question, please also create one search term that can be used to retrieve relevant
documents from a document store.
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
Please formulate your answer as a list of json objects with the following format:
[{{"sub_question": <sub-question>, "search_term": <search term>}}, ...]
Answer:
"""
INITIAL_RAG_PROMPT = """ \n
You are an assistant for question-answering tasks. Use the information provided below - and only the
provided information - to answer the provided question.
The information provided below consists of:
1) a number of answered sub-questions - these are very important(!) and definitely should be
considered to answer the question.
2) a number of documents that were also deemed relevant for the question.
If you don't know the answer or if the provided information is empty or insufficient, just say
"I don't know". Do not use your internal knowledge!
Again, only use the provided informationand do not use your internal knowledge! It is a matter of life
and death that you do NOT use your internal knowledge, just the provided information!
Try to keep your answer concise.
And here is the question and the provided information:
\n
\nQuestion:\n {question}
\nAnswered Sub-questions:\n {answered_sub_questions}
\nContext:\n {context} \n\n
\n\n
Answer:"""
ENTITY_TERM_PROMPT = """ \n
Based on the original question and the context retieved from a dataset, please generate a list of
entities (e.g. companies, organizations, industries, products, locations, etc.), terms and concepts
(e.g. sales, revenue, etc.) that are relevant for the question, plus their relations to each other.
\n\n
Here is the original question:
\n ------- \n
{question}
\n ------- \n
And here is the context retrieved:
\n ------- \n
{context}
\n ------- \n
Please format your answer as a json object in the following format:
{{"retrieved_entities_relationships": {{
"entities": [{{
"entity_name": <assign a name for the entity>,
"entity_type": <specify a short type name for the entity, such as 'company', 'location',...>
}}],
"relationships": [{{
"name": <assign a name for the relationship>,
"type": <specify a short type name for the relationship, such as 'sales_to', 'is_location_of',...>,
"entities": [<related entity name 1>, <related entity name 2>]
}}],
"terms": [{{
"term_name": <assign a name for the term>,
"term_type": <specify a short type name for the term, such as 'revenue', 'market_share',...>,
"similar_to": <list terms that are similar to this term>
}}]
}}
}}
"""

View File

@@ -1,73 +0,0 @@
import operator
from collections.abc import Sequence
from datetime import datetime
from typing import Annotated
from typing import TypedDict
from langchain_core.messages import BaseMessage
from langgraph.graph.message import add_messages
from danswer.agent_search.shared_graph_utils.models import RewrittenQueries
from danswer.context.search.models import InferenceSection
class QAState(TypedDict):
# The 'main' state of the answer graph
original_question: str
graph_start_time: datetime
# start time for parallel initial sub-questionn thread
sub_query_start_time: datetime
log_messages: Annotated[Sequence[BaseMessage], add_messages]
rewritten_queries: RewrittenQueries
sub_questions: list[dict]
initial_sub_questions: list[dict]
ranked_subquestion_ids: list[int]
decomposed_sub_questions_dict: dict
rejected_sub_questions: Annotated[list[str], operator.add]
rejected_sub_questions_handled: bool
sub_qas: Annotated[Sequence[dict], operator.add]
initial_sub_qas: Annotated[Sequence[dict], operator.add]
checked_sub_qas: Annotated[Sequence[dict], operator.add]
base_retrieval_docs: Annotated[Sequence[InferenceSection], operator.add]
deduped_retrieval_docs: Annotated[Sequence[InferenceSection], operator.add]
reranked_retrieval_docs: Annotated[Sequence[InferenceSection], operator.add]
retrieved_entities_relationships: dict
questions_context: list[dict]
qa_level: int
top_chunks: list[InferenceSection]
sub_question_top_chunks: Annotated[Sequence[dict], operator.add]
num_new_question_iterations: int
core_answer_dynamic_context: str
dynamic_context: str
initial_base_answer: str
base_answer: str
deep_answer: str
class QAOuputState(TypedDict):
# The 'main' output state of the answer graph. Removes all the intermediate states
original_question: str
log_messages: Annotated[Sequence[BaseMessage], add_messages]
sub_questions: list[dict]
sub_qas: Annotated[Sequence[dict], operator.add]
initial_sub_qas: Annotated[Sequence[dict], operator.add]
checked_sub_qas: Annotated[Sequence[dict], operator.add]
reranked_retrieval_docs: Annotated[Sequence[InferenceSection], operator.add]
retrieved_entities_relationships: dict
top_chunks: list[InferenceSection]
sub_question_top_chunks: Annotated[Sequence[dict], operator.add]
base_answer: str
deep_answer: str
class RetrieverState(TypedDict):
# The state for the parallel Retrievers. They each need to see only one query
rewritten_query: str
graph_start_time: datetime
class VerifierState(TypedDict):
# The state for the parallel verification step. Each node execution need to see only one question/doc pair
document: InferenceSection
question: str
graph_start_time: datetime

View File

@@ -1,22 +0,0 @@
from danswer.agent_search.primary_graph.graph_builder import build_core_graph
from danswer.llm.answering.answer import AnswerStream
from danswer.llm.interfaces import LLM
from danswer.tools.tool import Tool
def run_graph(
query: str,
llm: LLM,
tools: list[Tool],
) -> AnswerStream:
graph = build_core_graph()
inputs = {
"original_question": query,
"messages": [],
"tools": tools,
"llm": llm,
}
compiled_graph = graph.compile()
output = compiled_graph.invoke(input=inputs)
yield from output

View File

@@ -1,16 +0,0 @@
from typing import Literal
from pydantic import BaseModel
# Pydantic models for structured outputs
class RewrittenQueries(BaseModel):
rewritten_queries: list[str]
class BinaryDecision(BaseModel):
decision: Literal["yes", "no"]
class SubQuestions(BaseModel):
sub_questions: list[str]

View File

@@ -1,342 +0,0 @@
REWRITE_PROMPT_MULTI_ORIGINAL = """ \n
Please convert an initial user question into a 2-3 more appropriate short and pointed search queries for retrievel from a
document store. Particularly, try to think about resolving ambiguities and make the search queries more specific,
enabling the system to search more broadly.
Also, try to make the search queries not redundant, i.e. not too similar! \n\n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
Formulate the queries separated by '--' (Do not say 'Query 1: ...', just write the querytext): """
REWRITE_PROMPT_MULTI = """ \n
Please create a list of 2-3 sample documents that could answer an original question. Each document
should be about as long as the original question. \n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
Formulate the sample documents separated by '--' (Do not say 'Document 1: ...', just write the text): """
BASE_RAG_PROMPT = """ \n
You are an assistant for question-answering tasks. Use the context provided below - and only the
provided context - to answer the question. If you don't know the answer or if the provided context is
empty, just say "I don't know". Do not use your internal knowledge!
Again, only use the provided context and do not use your internal knowledge! If you cannot answer the
question based on the context, say "I don't know". It is a matter of life and death that you do NOT
use your internal knowledge, just the provided information!
Use three sentences maximum and keep the answer concise.
answer concise.\nQuestion:\n {question} \nContext:\n {context} \n\n
\n\n
Answer:"""
BASE_CHECK_PROMPT = """ \n
Please check whether 1) the suggested answer seems to fully address the original question AND 2)the
original question requests a simple, factual answer, and there are no ambiguities, judgements,
aggregations, or any other complications that may require extra context. (I.e., if the question is
somewhat addressed, but the answer would benefit from more context, then answer with 'no'.)
Please only answer with 'yes' or 'no' \n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
Here is the proposed answer:
\n ------- \n
{base_answer}
\n ------- \n
Please answer with yes or no:"""
VERIFIER_PROMPT = """ \n
Please check whether the document seems to be relevant for the answer of the original question. Please
only answer with 'yes' or 'no' \n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
Here is the document text:
\n ------- \n
{document_content}
\n ------- \n
Please answer with yes or no:"""
INITIAL_DECOMPOSITION_PROMPT_BASIC = """ \n
Please decompose an initial user question into not more than 4 appropriate sub-questions that help to
answer the original question. The purpose for this decomposition is to isolate individulal entities
(i.e., 'compare sales of company A and company B' -> 'what are sales for company A' + 'what are sales
for company B'), split ambiguous terms (i.e., 'what is our success with company A' -> 'what are our
sales with company A' + 'what is our market share with company A' + 'is company A a reference customer
for us'), etc. Each sub-question should be realistically be answerable by a good RAG system. \n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
Please formulate your answer as a list of subquestions:
Answer:
"""
REWRITE_PROMPT_SINGLE = """ \n
Please convert an initial user question into a more appropriate search query for retrievel from a
document store. \n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
Formulate the query: """
MODIFIED_RAG_PROMPT = """You are an assistant for question-answering tasks. Use the context provided below
- and only this context - to answer the question. If you don't know the answer, just say "I don't know".
Use three sentences maximum and keep the answer concise.
Pay also particular attention to the sub-questions and their answers, at least it may enrich the answer.
Again, only use the provided context and do not use your internal knowledge! If you cannot answer the
question based on the context, say "I don't know". It is a matter of life and death that you do NOT
use your internal knowledge, just the provided information!
\nQuestion: {question}
\nContext: {combined_context} \n
Answer:"""
ORIG_DEEP_DECOMPOSE_PROMPT = """ \n
An initial user question needs to be answered. An initial answer has been provided but it wasn't quite
good enough. Also, some sub-questions had been answered and this information has been used to provide
the initial answer. Some other subquestions may have been suggested based on little knowledge, but they
were not directly answerable. Also, some entities, relationships and terms are givenm to you so that
you have an idea of how the avaiolable data looks like.
Your role is to generate 3-5 new sub-questions that would help to answer the initial question,
considering:
1) The initial question
2) The initial answer that was found to be unsatisfactory
3) The sub-questions that were answered
4) The sub-questions that were suggested but not answered
5) The entities, relationships and terms that were extracted from the context
The individual questions should be answerable by a good RAG system.
So a good idea would be to use the sub-questions to resolve ambiguities and/or to separate the
question for different entities that may be involved in the original question, but in a way that does
not duplicate questions that were already tried.
Additional Guidelines:
- The sub-questions should be specific to the question and provide richer context for the question,
resolve ambiguities, or address shortcoming of the initial answer
- Each sub-question - when answered - should be relevant for the answer to the original question
- The sub-questions should be free from comparisions, ambiguities,judgements, aggregations, or any
other complications that may require extra context.
- The sub-questions MUST have the full context of the original question so that it can be executed by
a RAG system independently without the original question available
(Example:
- initial question: "What is the capital of France?"
- bad sub-question: "What is the name of the river there?"
- good sub-question: "What is the name of the river that flows through Paris?"
- For each sub-question, please provide a short explanation for why it is a good sub-question. So
generate a list of dictionaries with the following format:
[{{"sub_question": <sub-question>, "explanation": <explanation>, "search_term": <rewrite the
sub-question using as a search phrase for the document store>}}, ...]
\n\n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
Here is the initial sub-optimal answer:
\n ------- \n
{base_answer}
\n ------- \n
Here are the sub-questions that were answered:
\n ------- \n
{answered_sub_questions}
\n ------- \n
Here are the sub-questions that were suggested but not answered:
\n ------- \n
{failed_sub_questions}
\n ------- \n
And here are the entities, relationships and terms extracted from the context:
\n ------- \n
{entity_term_extraction_str}
\n ------- \n
Please generate the list of good, fully contextualized sub-questions that would help to address the
main question. Again, please find questions that are NOT overlapping too much with the already answered
sub-questions or those that already were suggested and failed.
In other words - what can we try in addition to what has been tried so far?
Please think through it step by step and then generate the list of json dictionaries with the following
format:
{{"sub_questions": [{{"sub_question": <sub-question>,
"explanation": <explanation>,
"search_term": <rewrite the sub-question using as a search phrase for the document store>}},
...]}} """
DEEP_DECOMPOSE_PROMPT = """ \n
An initial user question needs to be answered. An initial answer has been provided but it wasn't quite
good enough. Also, some sub-questions had been answered and this information has been used to provide
the initial answer. Some other subquestions may have been suggested based on little knowledge, but they
were not directly answerable. Also, some entities, relationships and terms are givenm to you so that
you have an idea of how the avaiolable data looks like.
Your role is to generate 4-6 new sub-questions that would help to answer the initial question,
considering:
1) The initial question
2) The initial answer that was found to be unsatisfactory
3) The sub-questions that were answered
4) The sub-questions that were suggested but not answered
5) The entities, relationships and terms that were extracted from the context
The individual questions should be answerable by a good RAG system.
So a good idea would be to use the sub-questions to resolve ambiguities and/or to separate the
question for different entities that may be involved in the original question, but in a way that does
not duplicate questions that were already tried.
Additional Guidelines:
- The sub-questions should be specific to the question and provide richer context for the question,
resolve ambiguities, or address shortcoming of the initial answer
- Each sub-question - when answered - should be relevant for the answer to the original question
- The sub-questions should be free from comparisions, ambiguities,judgements, aggregations, or any
other complications that may require extra context.
- The sub-questions MUST have the full context of the original question so that it can be executed by
a RAG system independently without the original question available
(Example:
- initial question: "What is the capital of France?"
- bad sub-question: "What is the name of the river there?"
- good sub-question: "What is the name of the river that flows through Paris?"
- For each sub-question, please also provide a search term that can be used to retrieve relevant
documents from a document store.
\n\n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
Here is the initial sub-optimal answer:
\n ------- \n
{base_answer}
\n ------- \n
Here are the sub-questions that were answered:
\n ------- \n
{answered_sub_questions}
\n ------- \n
Here are the sub-questions that were suggested but not answered:
\n ------- \n
{failed_sub_questions}
\n ------- \n
And here are the entities, relationships and terms extracted from the context:
\n ------- \n
{entity_term_extraction_str}
\n ------- \n
Please generate the list of good, fully contextualized sub-questions that would help to address the
main question. Again, please find questions that are NOT overlapping too much with the already answered
sub-questions or those that already were suggested and failed.
In other words - what can we try in addition to what has been tried so far?
Generate the list of json dictionaries with the following format:
{{"sub_questions": [{{"sub_question": <sub-question>,
"search_term": <rewrite the sub-question using as a search phrase for the document store>}},
...]}} """
DECOMPOSE_PROMPT = """ \n
For an initial user question, please generate at 5-10 individual sub-questions whose answers would help
\n to answer the initial question. The individual questions should be answerable by a good RAG system.
So a good idea would be to \n use the sub-questions to resolve ambiguities and/or to separate the
question for different entities that may be involved in the original question.
In order to arrive at meaningful sub-questions, please also consider the context retrieved from the
document store, expressed as entities, relationships and terms. You can also think about the types
mentioned in brackets
Guidelines:
- The sub-questions should be specific to the question and provide richer context for the question,
and or resolve ambiguities
- Each sub-question - when answered - should be relevant for the answer to the original question
- The sub-questions should be free from comparisions, ambiguities,judgements, aggregations, or any
other complications that may require extra context.
- The sub-questions MUST have the full context of the original question so that it can be executed by
a RAG system independently without the original question available
(Example:
- initial question: "What is the capital of France?"
- bad sub-question: "What is the name of the river there?"
- good sub-question: "What is the name of the river that flows through Paris?"
- For each sub-question, please provide a short explanation for why it is a good sub-question. So
generate a list of dictionaries with the following format:
[{{"sub_question": <sub-question>, "explanation": <explanation>, "search_term": <rewrite the
sub-question using as a search phrase for the document store>}}, ...]
\n\n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
And here are the entities, relationships and terms extracted from the context:
\n ------- \n
{entity_term_extraction_str}
\n ------- \n
Please generate the list of good, fully contextualized sub-questions that would help to address the
main question. Don't be too specific unless the original question is specific.
Please think through it step by step and then generate the list of json dictionaries with the following
format:
{{"sub_questions": [{{"sub_question": <sub-question>,
"explanation": <explanation>,
"search_term": <rewrite the sub-question using as a search phrase for the document store>}},
...]}} """
#### Consolidations
COMBINED_CONTEXT = """-------
Below you will find useful information to answer the original question. First, you see a number of
sub-questions with their answers. This information should be considered to be more focussed and
somewhat more specific to the original question as it tries to contextualized facts.
After that will see the documents that were considered to be relevant to answer the original question.
Here are the sub-questions and their answers:
\n\n {deep_answer_context} \n\n
\n\n Here are the documents that were considered to be relevant to answer the original question:
\n\n {formated_docs} \n\n
----------------
"""
SUB_QUESTION_EXPLANATION_RANKER_PROMPT = """-------
Below you will find a question that we ultimately want to answer (the original question) and a list of
motivations in arbitrary order for generated sub-questions that are supposed to help us answering the
original question. The motivations are formatted as <motivation number>: <motivation explanation>.
(Again, the numbering is arbitrary and does not necessarily mean that 1 is the most relevant
motivation and 2 is less relevant.)
Please rank the motivations in order of relevance for answering the original question. Also, try to
ensure that the top questions do not duplicate too much, i.e. that they are not too similar.
Ultimately, create a list with the motivation numbers where the number of the most relevant
motivations comes first.
Here is the original question:
\n\n {original_question} \n\n
\n\n Here is the list of sub-question motivations:
\n\n {sub_question_explanations} \n\n
----------------
Please think step by step and then generate the ranked list of motivations.
Please format your answer as a json object in the following format:
{{"reasonning": <explain your reasoning for the ranking>,
"ranked_motivations": <ranked list of motivation numbers>}}
"""

View File

@@ -1,91 +0,0 @@
import ast
import json
import re
from collections.abc import Sequence
from datetime import datetime
from datetime import timedelta
from typing import Any
from danswer.context.search.models import InferenceSection
def normalize_whitespace(text: str) -> str:
"""Normalize whitespace in text to single spaces and strip leading/trailing whitespace."""
import re
return re.sub(r"\s+", " ", text.strip())
# Post-processing
def format_docs(docs: Sequence[InferenceSection]) -> str:
return "\n\n".join(doc.combined_content for doc in docs)
def clean_and_parse_list_string(json_string: str) -> list[dict]:
# Remove markdown code block markers and any newline prefixes
cleaned_string = re.sub(r"```json\n|\n```", "", json_string)
cleaned_string = cleaned_string.replace("\\n", " ").replace("\n", " ")
cleaned_string = " ".join(cleaned_string.split())
# Parse the cleaned string into a Python dictionary
return ast.literal_eval(cleaned_string)
def clean_and_parse_json_string(json_string: str) -> dict[str, Any]:
# Remove markdown code block markers and any newline prefixes
cleaned_string = re.sub(r"```json\n|\n```", "", json_string)
cleaned_string = cleaned_string.replace("\\n", " ").replace("\n", " ")
cleaned_string = " ".join(cleaned_string.split())
# Parse the cleaned string into a Python dictionary
return json.loads(cleaned_string)
def format_entity_term_extraction(entity_term_extraction_dict: dict[str, Any]) -> str:
entities = entity_term_extraction_dict["entities"]
terms = entity_term_extraction_dict["terms"]
relationships = entity_term_extraction_dict["relationships"]
entity_strs = ["\nEntities:\n"]
for entity in entities:
entity_str = f"{entity['entity_name']} ({entity['entity_type']})"
entity_strs.append(entity_str)
entity_str = "\n - ".join(entity_strs)
relationship_strs = ["\n\nRelationships:\n"]
for relationship in relationships:
relationship_str = f"{relationship['name']} ({relationship['type']}): {relationship['entities']}"
relationship_strs.append(relationship_str)
relationship_str = "\n - ".join(relationship_strs)
term_strs = ["\n\nTerms:\n"]
for term in terms:
term_str = f"{term['term_name']} ({term['term_type']}): similar to {term['similar_to']}"
term_strs.append(term_str)
term_str = "\n - ".join(term_strs)
return "\n".join(entity_strs + relationship_strs + term_strs)
def _format_time_delta(time: timedelta) -> str:
seconds_from_start = f"{((time).seconds):03d}"
microseconds_from_start = f"{((time).microseconds):06d}"
return f"{seconds_from_start}.{microseconds_from_start}"
def generate_log_message(
message: str,
node_start_time: datetime,
graph_start_time: datetime | None = None,
) -> str:
current_time = datetime.now()
if graph_start_time is not None:
graph_time_str = _format_time_delta(current_time - graph_start_time)
else:
graph_time_str = "N/A"
node_time_str = _format_time_delta(current_time - node_start_time)
return f"{graph_time_str} ({node_time_str} s): {message}"

View File

@@ -1,89 +0,0 @@
import secrets
import uuid
from urllib.parse import quote
from urllib.parse import unquote
from fastapi import Request
from passlib.hash import sha256_crypt
from pydantic import BaseModel
from danswer.auth.schemas import UserRole
from danswer.configs.app_configs import API_KEY_HASH_ROUNDS
_API_KEY_HEADER_NAME = "Authorization"
# NOTE for others who are curious: In the context of a header, "X-" often refers
# to non-standard, experimental, or custom headers in HTTP or other protocols. It
# indicates that the header is not part of the official standards defined by
# organizations like the Internet Engineering Task Force (IETF).
_API_KEY_HEADER_ALTERNATIVE_NAME = "X-Danswer-Authorization"
_BEARER_PREFIX = "Bearer "
_API_KEY_PREFIX = "dn_"
_API_KEY_LEN = 192
class ApiKeyDescriptor(BaseModel):
api_key_id: int
api_key_display: str
api_key: str | None = None # only present on initial creation
api_key_name: str | None = None
api_key_role: UserRole
user_id: uuid.UUID
def generate_api_key(tenant_id: str | None = None) -> str:
# For backwards compatibility, if no tenant_id, generate old style key
if not tenant_id:
return _API_KEY_PREFIX + secrets.token_urlsafe(_API_KEY_LEN)
encoded_tenant = quote(tenant_id) # URL encode the tenant ID
return f"{_API_KEY_PREFIX}{encoded_tenant}.{secrets.token_urlsafe(_API_KEY_LEN)}"
def extract_tenant_from_api_key_header(request: Request) -> str | None:
"""Extract tenant ID from request. Returns None if auth is disabled or invalid format."""
raw_api_key_header = request.headers.get(
_API_KEY_HEADER_ALTERNATIVE_NAME
) or request.headers.get(_API_KEY_HEADER_NAME)
if not raw_api_key_header or not raw_api_key_header.startswith(_BEARER_PREFIX):
return None
api_key = raw_api_key_header[len(_BEARER_PREFIX) :].strip()
if not api_key.startswith(_API_KEY_PREFIX):
return None
parts = api_key[len(_API_KEY_PREFIX) :].split(".", 1)
if len(parts) != 2:
return None
tenant_id = parts[0]
return unquote(tenant_id) if tenant_id else None
def hash_api_key(api_key: str) -> str:
# NOTE: no salt is needed, as the API key is randomly generated
# and overlaps are impossible
return sha256_crypt.hash(api_key, salt="", rounds=API_KEY_HASH_ROUNDS)
def build_displayable_api_key(api_key: str) -> str:
if api_key.startswith(_API_KEY_PREFIX):
api_key = api_key[len(_API_KEY_PREFIX) :]
return _API_KEY_PREFIX + api_key[:4] + "********" + api_key[-4:]
def get_hashed_api_key_from_request(request: Request) -> str | None:
raw_api_key_header = request.headers.get(
_API_KEY_HEADER_ALTERNATIVE_NAME
) or request.headers.get(_API_KEY_HEADER_NAME)
if raw_api_key_header is None:
return None
if raw_api_key_header.startswith(_BEARER_PREFIX):
raw_api_key_header = raw_api_key_header[len(_BEARER_PREFIX) :].strip()
return hash_api_key(raw_api_key_header)

View File

@@ -2,14 +2,13 @@ from typing import cast
from danswer.configs.constants import KV_USER_STORE_KEY
from danswer.key_value_store.factory import get_kv_store
from danswer.key_value_store.interface import JSON_ro
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.utils.special_types import JSON_ro
def get_invited_users() -> list[str]:
try:
store = get_kv_store()
return cast(list, store.load(KV_USER_STORE_KEY))
except KvKeyNotFoundError:
return list()

View File

@@ -23,9 +23,7 @@ def load_no_auth_user_preferences(store: KeyValueStore) -> UserPreferences:
)
return UserPreferences(**preferences_data)
except KvKeyNotFoundError:
return UserPreferences(
chosen_assistants=None, default_model=None, auto_scroll=True
)
return UserPreferences(chosen_assistants=None, default_model=None)
def fetch_no_auth_user(store: KeyValueStore) -> UserInfo:

View File

@@ -13,24 +13,12 @@ class UserRole(str, Enum):
groups they are curators of
- Global Curator can perform admin actions
for all groups they are a member of
- Limited can access a limited set of basic api endpoints
- Slack are users that have used danswer via slack but dont have a web login
- External permissioned users that have been picked up during the external permissions sync process but don't have a web login
"""
LIMITED = "limited"
BASIC = "basic"
ADMIN = "admin"
CURATOR = "curator"
GLOBAL_CURATOR = "global_curator"
SLACK_USER = "slack_user"
EXT_PERM_USER = "ext_perm_user"
def is_web_login(self) -> bool:
return self not in [
UserRole.SLACK_USER,
UserRole.EXT_PERM_USER,
]
class UserStatus(str, Enum):
@@ -45,8 +33,10 @@ class UserRead(schemas.BaseUser[uuid.UUID]):
class UserCreate(schemas.BaseUserCreate):
role: UserRole = UserRole.BASIC
has_web_login: bool | None = True
tenant_id: str | None = None
class UserUpdate(schemas.BaseUserUpdate):
role: UserRole
has_web_login: bool | None = True

View File

@@ -48,19 +48,20 @@ from httpx_oauth.integrations.fastapi import OAuth2AuthorizeCallback
from httpx_oauth.oauth2 import BaseOAuth2
from httpx_oauth.oauth2 import OAuth2Token
from pydantic import BaseModel
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from sqlalchemy.orm import attributes
from sqlalchemy.orm import Session
from danswer.auth.api_key import get_hashed_api_key_from_request
from danswer.auth.invited_users import get_invited_users
from danswer.auth.schemas import UserCreate
from danswer.auth.schemas import UserRole
from danswer.auth.schemas import UserUpdate
from danswer.configs.app_configs import AUTH_TYPE
from danswer.configs.app_configs import DISABLE_AUTH
from danswer.configs.app_configs import DISABLE_VERIFICATION
from danswer.configs.app_configs import EMAIL_FROM
from danswer.configs.app_configs import MULTI_TENANT
from danswer.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
from danswer.configs.app_configs import SECRET_JWT_KEY
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
from danswer.configs.app_configs import SMTP_PASS
from danswer.configs.app_configs import SMTP_PORT
@@ -74,28 +75,25 @@ from danswer.configs.constants import AuthType
from danswer.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
from danswer.configs.constants import DANSWER_API_KEY_PREFIX
from danswer.configs.constants import UNNAMED_KEY_PLACEHOLDER
from danswer.db.api_key import fetch_user_for_api_key
from danswer.db.auth import get_access_token_db
from danswer.db.auth import get_default_admin_user_emails
from danswer.db.auth import get_user_count
from danswer.db.auth import get_user_db
from danswer.db.auth import SQLAlchemyUserAdminDB
from danswer.db.engine import get_async_session
from danswer.db.engine import get_async_session_with_tenant
from danswer.db.engine import get_session
from danswer.db.engine import get_session_with_tenant
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import AccessToken
from danswer.db.models import OAuthAccount
from danswer.db.models import User
from danswer.db.models import UserTenantMapping
from danswer.db.users import get_user_by_email
from danswer.server.utils import BasicAuthenticationError
from danswer.utils.logger import setup_logger
from danswer.utils.telemetry import optional_telemetry
from danswer.utils.telemetry import RecordType
from danswer.utils.variable_functionality import fetch_ee_implementation_or_noop
from danswer.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import async_return_default_schema
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.configs import current_tenant_id
logger = setup_logger()
@@ -134,9 +132,7 @@ def get_display_email(email: str | None, space_less: bool = False) -> str:
def user_needs_to_be_verified() -> bool:
# all other auth types besides basic should require users to be
# verified
return not DISABLE_VERIFICATION and (
AUTH_TYPE != AuthType.BASIC or REQUIRE_EMAIL_VERIFICATION
)
return AUTH_TYPE != AuthType.BASIC or REQUIRE_EMAIL_VERIFICATION
def verify_email_is_invited(email: str) -> None:
@@ -189,6 +185,20 @@ def verify_email_domain(email: str) -> None:
)
def get_tenant_id_for_email(email: str) -> str:
if not MULTI_TENANT:
return "public"
# Implement logic to get tenant_id from the mapping table
with Session(get_sqlalchemy_engine()) as db_session:
result = db_session.execute(
select(UserTenantMapping.tenant_id).where(UserTenantMapping.email == email)
)
tenant_id = result.scalar_one_or_none()
if tenant_id is None:
raise exceptions.UserNotExists()
return tenant_id
def send_user_verification_email(
user_email: str,
token: str,
@@ -217,36 +227,31 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
reset_password_token_secret = USER_AUTH_SECRET
verification_token_secret = USER_AUTH_SECRET
user_db: SQLAlchemyUserDatabase[User, uuid.UUID]
async def create(
self,
user_create: schemas.UC | UserCreate,
safe: bool = False,
request: Optional[Request] = None,
) -> User:
referral_source = None
if request is not None:
referral_source = request.cookies.get("referral_source", None)
try:
tenant_id = (
get_tenant_id_for_email(user_create.email) if MULTI_TENANT else "public"
)
except exceptions.UserNotExists:
raise HTTPException(status_code=401, detail="User not found")
tenant_id = await fetch_ee_implementation_or_noop(
"danswer.server.tenants.provisioning",
"get_or_create_tenant_id",
async_return_default_schema,
)(
email=user_create.email,
referral_source=referral_source,
)
if not tenant_id:
raise HTTPException(
status_code=401, detail="User does not belong to an organization"
)
async with get_async_session_with_tenant(tenant_id) as db_session:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
token = current_tenant_id.set(tenant_id)
verify_email_is_invited(user_create.email)
verify_email_domain(user_create.email)
if MULTI_TENANT:
tenant_user_db = SQLAlchemyUserAdminDB[User, uuid.UUID](
db_session, User, OAuthAccount
)
tenant_user_db = SQLAlchemyUserAdminDB(db_session, User, OAuthAccount)
self.user_db = tenant_user_db
self.database = tenant_user_db
@@ -259,15 +264,20 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
user_create.role = UserRole.ADMIN
else:
user_create.role = UserRole.BASIC
user = None
try:
user = await super().create(user_create, safe=safe, request=request) # type: ignore
except exceptions.UserAlreadyExists:
user = await self.get_by_email(user_create.email)
# Handle case where user has used product outside of web and is now creating an account through web
if not user.role.is_web_login() and user_create.role.is_web_login():
if (
not user.has_web_login
and hasattr(user_create, "has_web_login")
and user_create.has_web_login
):
user_update = UserUpdate(
password=user_create.password,
has_web_login=True,
role=user_create.role,
is_verified=user_create.is_verified,
)
@@ -275,13 +285,34 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
else:
raise exceptions.UserAlreadyExists()
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
current_tenant_id.reset(token)
return user
async def oauth_callback(
async def on_after_login(
self,
user: User,
request: Request | None = None,
response: Response | None = None,
) -> None:
if response is None or not MULTI_TENANT:
return
tenant_id = get_tenant_id_for_email(user.email)
tenant_token = jwt.encode(
{"tenant_id": tenant_id}, SECRET_JWT_KEY, algorithm="HS256"
)
response.set_cookie(
key="tenant_details",
value=tenant_token,
httponly=True,
secure=WEB_DOMAIN.startswith("https"),
samesite="lax",
)
async def oauth_callback(
self: "BaseUserManager[models.UOAP, models.ID]",
oauth_name: str,
access_token: str,
account_id: str,
@@ -292,37 +323,28 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
*,
associate_by_email: bool = False,
is_verified_by_default: bool = False,
) -> User:
referral_source = None
if request:
referral_source = getattr(request.state, "referral_source", None)
tenant_id = await fetch_ee_implementation_or_noop(
"danswer.server.tenants.provisioning",
"get_or_create_tenant_id",
async_return_default_schema,
)(
email=account_email,
referral_source=referral_source,
)
) -> models.UOAP:
# Get tenant_id from mapping table
try:
tenant_id = (
get_tenant_id_for_email(account_email) if MULTI_TENANT else "public"
)
except exceptions.UserNotExists:
raise HTTPException(status_code=401, detail="User not found")
if not tenant_id:
raise HTTPException(status_code=401, detail="User not found")
# Proceed with the tenant context
token = None
async with get_async_session_with_tenant(tenant_id) as db_session:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
token = current_tenant_id.set(tenant_id)
verify_email_in_whitelist(account_email, tenant_id)
verify_email_domain(account_email)
if MULTI_TENANT:
tenant_user_db = SQLAlchemyUserAdminDB[User, uuid.UUID](
db_session, User, OAuthAccount
)
tenant_user_db = SQLAlchemyUserAdminDB(db_session, User, OAuthAccount)
self.user_db = tenant_user_db
self.database = tenant_user_db
self.database = tenant_user_db # type: ignore
oauth_account_dict = {
"oauth_name": oauth_name,
@@ -358,13 +380,9 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
}
user = await self.user_db.create(user_dict)
# Explicitly set the Postgres schema for this session to ensure
# OAuth account creation happens in the correct tenant schema
await db_session.execute(text(f'SET search_path = "{tenant_id}"'))
# Add OAuth account
await self.user_db.add_oauth_account(user, oauth_account_dict)
user = await self.user_db.add_oauth_account(
user, oauth_account_dict
)
await self.on_after_register(user, request)
else:
@@ -374,11 +392,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
and existing_oauth_account.oauth_name == oauth_name
):
user = await self.user_db.update_oauth_account(
user,
# NOTE: OAuthAccount DOES implement the OAuthAccountProtocol
# but the type checker doesn't know that :(
existing_oauth_account, # type: ignore
oauth_account_dict,
user, existing_oauth_account, oauth_account_dict
)
# NOTE: Most IdPs have very short expiry times, and we don't want to force the user to
@@ -391,15 +405,16 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
)
# Handle case where user has used product outside of web and is now creating an account through web
if not user.role.is_web_login():
if not user.has_web_login: # type: ignore
await self.user_db.update(
user,
{
"is_verified": is_verified_by_default,
"role": UserRole.BASIC,
"has_web_login": True,
},
)
user.is_verified = is_verified_by_default
user.has_web_login = True # type: ignore
# this is needed if an organization goes from `TRACK_EXTERNAL_IDP_EXPIRY=true` to `false`
# otherwise, the oidc expiry will always be old, and the user will never be able to login
@@ -411,7 +426,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
user.oidc_expiry = None # type: ignore
if token:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
current_tenant_id.reset(token)
return user
@@ -447,13 +462,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
email = credentials.username
# Get tenant_id from mapping table
tenant_id = await fetch_ee_implementation_or_noop(
"danswer.server.tenants.provisioning",
"get_or_create_tenant_id",
async_return_default_schema,
)(
email=email,
)
tenant_id = get_tenant_id_for_email(email)
if not tenant_id:
# User not found in mapping
self.password_helper.hash(credentials.password)
@@ -474,8 +483,11 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
self.password_helper.hash(credentials.password)
return None
if not user.role.is_web_login():
raise BasicAuthenticationError(
has_web_login = attributes.get_attribute(user, "has_web_login")
if not has_web_login:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="NO_WEB_LOGIN_AND_HAS_NO_PASSWORD",
)
@@ -505,33 +517,8 @@ cookie_transport = CookieTransport(
)
# This strategy is used to add tenant_id to the JWT token
class TenantAwareJWTStrategy(JWTStrategy):
async def _create_token_data(self, user: User, impersonate: bool = False) -> dict:
tenant_id = await fetch_ee_implementation_or_noop(
"danswer.server.tenants.provisioning",
"get_or_create_tenant_id",
async_return_default_schema,
)(
email=user.email,
)
data = {
"sub": str(user.id),
"aud": self.token_audience,
"tenant_id": tenant_id,
}
return data
async def write_token(self, user: User) -> str:
data = await self._create_token_data(user)
return generate_jwt(
data, self.encode_key, self.lifetime_seconds, algorithm=self.algorithm
)
def get_jwt_strategy() -> TenantAwareJWTStrategy:
return TenantAwareJWTStrategy(
def get_jwt_strategy() -> JWTStrategy:
return JWTStrategy(
secret=USER_AUTH_SECRET,
lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS,
)
@@ -605,7 +592,7 @@ optional_fastapi_current_user = fastapi_users.current_user(active=True, optional
async def optional_user_(
request: Request,
user: User | None,
async_db_session: AsyncSession,
db_session: Session,
) -> User | None:
"""NOTE: `request` and `db_session` are not used here, but are included
for the EE version of this function."""
@@ -614,21 +601,13 @@ async def optional_user_(
async def optional_user(
request: Request,
async_db_session: AsyncSession = Depends(get_async_session),
db_session: Session = Depends(get_session),
user: User | None = Depends(optional_fastapi_current_user),
) -> User | None:
versioned_fetch_user = fetch_versioned_implementation(
"danswer.auth.users", "optional_user_"
)
user = await versioned_fetch_user(request, user, async_db_session)
# check if an API key is present
if user is None:
hashed_api_key = get_hashed_api_key_from_request(request)
if hashed_api_key:
user = await fetch_user_for_api_key(hashed_api_key, async_db_session)
return user
return await versioned_fetch_user(request, user, db_session)
async def double_check_user(
@@ -640,12 +619,14 @@ async def double_check_user(
return None
if user is None:
raise BasicAuthenticationError(
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied. User is not authenticated.",
)
if user_needs_to_be_verified() and not user.is_verified:
raise BasicAuthenticationError(
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied. User is not verified.",
)
@@ -654,7 +635,8 @@ async def double_check_user(
and user.oidc_expiry < datetime.now(timezone.utc)
and not include_expired
):
raise BasicAuthenticationError(
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied. User's OIDC token has expired.",
)
@@ -667,24 +649,10 @@ async def current_user_with_expired_token(
return await double_check_user(user, include_expired=True)
async def current_limited_user(
user: User | None = Depends(optional_user),
) -> User | None:
return await double_check_user(user)
async def current_user(
user: User | None = Depends(optional_user),
) -> User | None:
user = await double_check_user(user)
if not user:
return None
if user.role == UserRole.LIMITED:
raise BasicAuthenticationError(
detail="Access denied. User role is LIMITED. BASIC or higher permissions are required.",
)
return user
return await double_check_user(user)
async def current_curator_or_admin_user(
@@ -694,13 +662,15 @@ async def current_curator_or_admin_user(
return None
if not user or not hasattr(user, "role"):
raise BasicAuthenticationError(
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied. User is not authenticated or lacks role information.",
)
allowed_roles = {UserRole.GLOBAL_CURATOR, UserRole.CURATOR, UserRole.ADMIN}
if user.role not in allowed_roles:
raise BasicAuthenticationError(
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied. User is not a curator or admin.",
)
@@ -712,7 +682,8 @@ async def current_admin_user(user: User | None = Depends(current_user)) -> User
return None
if not user or not hasattr(user, "role") or user.role != UserRole.ADMIN:
raise BasicAuthenticationError(
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied. User must be an admin to perform this action.",
)
@@ -740,6 +711,8 @@ def generate_state_token(
# refer to https://github.com/fastapi-users/fastapi-users/blob/42ddc241b965475390e2bce887b084152ae1a2cd/fastapi_users/fastapi_users.py#L91
def create_danswer_oauth_router(
oauth_client: BaseOAuth2,
backend: AuthenticationBackend,
@@ -789,22 +762,15 @@ def get_oauth_router(
response_model=OAuth2AuthorizeResponse,
)
async def authorize(
request: Request,
scopes: List[str] = Query(None),
request: Request, scopes: List[str] = Query(None)
) -> OAuth2AuthorizeResponse:
referral_source = request.cookies.get("referral_source", None)
if redirect_url is not None:
authorize_redirect_url = redirect_url
else:
authorize_redirect_url = str(request.url_for(callback_route_name))
next_url = request.query_params.get("next", "/")
state_data: Dict[str, str] = {
"next_url": next_url,
"referral_source": referral_source or "default_referral",
}
state_data: Dict[str, str] = {"next_url": next_url}
state = generate_state_token(state_data, state_secret)
authorization_url = await oauth_client.get_authorization_url(
authorize_redirect_url,
@@ -863,11 +829,8 @@ def get_oauth_router(
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
next_url = state_data.get("next_url", "/")
referral_source = state_data.get("referral_source", None)
request.state.referral_source = referral_source
# Proceed to authenticate or create the user
# Authenticate user
try:
user = await user_manager.oauth_callback(
oauth_client.name,
@@ -909,25 +872,7 @@ def get_oauth_router(
redirect_response.status_code = response.status_code
if hasattr(response, "media_type"):
redirect_response.media_type = response.media_type
return redirect_response
return router
async def api_key_dep(
request: Request, async_db_session: AsyncSession = Depends(get_async_session)
) -> User | None:
if AUTH_TYPE == AuthType.DISABLED:
return None
hashed_api_key = get_hashed_api_key_from_request(request)
if not hashed_api_key:
raise HTTPException(status_code=401, detail="Missing API key")
if hashed_api_key:
user = await fetch_user_for_api_key(hashed_api_key, async_db_session)
if user is None:
raise HTTPException(status_code=401, detail="Invalid API key")
return user

View File

@@ -1,403 +0,0 @@
import logging
import multiprocessing
import time
from typing import Any
import requests
import sentry_sdk
from celery import Task
from celery.app import trace
from celery.exceptions import WorkerShutdown
from celery.states import READY_STATES
from celery.utils.log import get_task_logger
from celery.worker import strategy # type: ignore
from redis.lock import Lock as RedisLock
from sentry_sdk.integrations.celery import CeleryIntegration
from sqlalchemy import text
from sqlalchemy.orm import Session
from danswer.background.celery.apps.task_formatters import CeleryTaskColoredFormatter
from danswer.background.celery.apps.task_formatters import CeleryTaskPlainFormatter
from danswer.background.celery.celery_utils import celery_is_worker_primary
from danswer.configs.constants import DanswerRedisLocks
from danswer.db.engine import get_sqlalchemy_engine
from danswer.document_index.vespa_constants import VESPA_CONFIG_SERVER_URL
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
from danswer.redis.redis_connector_delete import RedisConnectorDelete
from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
from danswer.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync
from danswer.redis.redis_connector_prune import RedisConnectorPrune
from danswer.redis.redis_document_set import RedisDocumentSet
from danswer.redis.redis_pool import get_redis_client
from danswer.redis.redis_usergroup import RedisUserGroup
from danswer.utils.logger import ColoredFormatter
from danswer.utils.logger import PlainFormatter
from danswer.utils.logger import setup_logger
from shared_configs.configs import SENTRY_DSN
logger = setup_logger()
task_logger = get_task_logger(__name__)
if SENTRY_DSN:
sentry_sdk.init(
dsn=SENTRY_DSN,
integrations=[CeleryIntegration()],
traces_sample_rate=0.1,
)
logger.info("Sentry initialized")
else:
logger.debug("Sentry DSN not provided, skipping Sentry initialization")
def on_task_prerun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
kwargs: dict | None = None,
**kwds: Any,
) -> None:
pass
def on_task_postrun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
kwargs: dict[str, Any] | None = None,
retval: Any | None = None,
state: str | None = None,
**kwds: Any,
) -> None:
"""We handle this signal in order to remove completed tasks
from their respective tasksets. This allows us to track the progress of document set
and user group syncs.
This function runs after any task completes (both success and failure)
Note that this signal does not fire on a task that failed to complete and is going
to be retried.
This also does not fire if a worker with acks_late=False crashes (which all of our
long running workers are)
"""
if not task:
return
task_logger.debug(f"Task {task.name} (ID: {task_id}) completed with state: {state}")
if state not in READY_STATES:
return
if not task_id:
return
# Get tenant_id directly from kwargs- each celery task has a tenant_id kwarg
if not kwargs:
logger.error(f"Task {task.name} (ID: {task_id}) is missing kwargs")
tenant_id = None
else:
tenant_id = kwargs.get("tenant_id")
task_logger.debug(
f"Task {task.name} (ID: {task_id}) completed with state: {state} "
f"{f'for tenant_id={tenant_id}' if tenant_id else ''}"
)
r = get_redis_client(tenant_id=tenant_id)
if task_id.startswith(RedisConnectorCredentialPair.PREFIX):
r.srem(RedisConnectorCredentialPair.get_taskset_key(), task_id)
return
if task_id.startswith(RedisDocumentSet.PREFIX):
document_set_id = RedisDocumentSet.get_id_from_task_id(task_id)
if document_set_id is not None:
rds = RedisDocumentSet(tenant_id, int(document_set_id))
r.srem(rds.taskset_key, task_id)
return
if task_id.startswith(RedisUserGroup.PREFIX):
usergroup_id = RedisUserGroup.get_id_from_task_id(task_id)
if usergroup_id is not None:
rug = RedisUserGroup(tenant_id, int(usergroup_id))
r.srem(rug.taskset_key, task_id)
return
if task_id.startswith(RedisConnectorDelete.PREFIX):
cc_pair_id = RedisConnector.get_id_from_task_id(task_id)
if cc_pair_id is not None:
RedisConnectorDelete.remove_from_taskset(int(cc_pair_id), task_id, r)
return
if task_id.startswith(RedisConnectorPrune.SUBTASK_PREFIX):
cc_pair_id = RedisConnector.get_id_from_task_id(task_id)
if cc_pair_id is not None:
RedisConnectorPrune.remove_from_taskset(int(cc_pair_id), task_id, r)
return
if task_id.startswith(RedisConnectorPermissionSync.SUBTASK_PREFIX):
cc_pair_id = RedisConnector.get_id_from_task_id(task_id)
if cc_pair_id is not None:
RedisConnectorPermissionSync.remove_from_taskset(
int(cc_pair_id), task_id, r
)
return
if task_id.startswith(RedisConnectorExternalGroupSync.SUBTASK_PREFIX):
cc_pair_id = RedisConnector.get_id_from_task_id(task_id)
if cc_pair_id is not None:
RedisConnectorExternalGroupSync.remove_from_taskset(
int(cc_pair_id), task_id, r
)
return
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
"""The first signal sent on celery worker startup"""
multiprocessing.set_start_method("spawn") # fork is unsafe, set to spawn
def wait_for_redis(sender: Any, **kwargs: Any) -> None:
"""Waits for redis to become ready subject to a hardcoded timeout.
Will raise WorkerShutdown to kill the celery worker if the timeout is reached."""
r = get_redis_client(tenant_id=None)
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
ready = False
time_start = time.monotonic()
logger.info("Redis: Readiness probe starting.")
while True:
try:
if r.ping():
ready = True
break
except Exception:
pass
time_elapsed = time.monotonic() - time_start
if time_elapsed > WAIT_LIMIT:
break
logger.info(
f"Redis: Readiness probe ongoing. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
)
time.sleep(WAIT_INTERVAL)
if not ready:
msg = (
f"Redis: Readiness probe did not succeed within the timeout "
f"({WAIT_LIMIT} seconds). Exiting..."
)
logger.error(msg)
raise WorkerShutdown(msg)
logger.info("Redis: Readiness probe succeeded. Continuing...")
return
def wait_for_db(sender: Any, **kwargs: Any) -> None:
"""Waits for the db to become ready subject to a hardcoded timeout.
Will raise WorkerShutdown to kill the celery worker if the timeout is reached."""
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
ready = False
time_start = time.monotonic()
logger.info("Database: Readiness probe starting.")
while True:
try:
with Session(get_sqlalchemy_engine()) as db_session:
result = db_session.execute(text("SELECT NOW()")).scalar()
if result:
ready = True
break
except Exception:
pass
time_elapsed = time.monotonic() - time_start
if time_elapsed > WAIT_LIMIT:
break
logger.info(
f"Database: Readiness probe ongoing. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
)
time.sleep(WAIT_INTERVAL)
if not ready:
msg = (
f"Database: Readiness probe did not succeed within the timeout "
f"({WAIT_LIMIT} seconds). Exiting..."
)
logger.error(msg)
raise WorkerShutdown(msg)
logger.info("Database: Readiness probe succeeded. Continuing...")
return
def wait_for_vespa(sender: Any, **kwargs: Any) -> None:
"""Waits for Vespa to become ready subject to a hardcoded timeout.
Will raise WorkerShutdown to kill the celery worker if the timeout is reached."""
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
ready = False
time_start = time.monotonic()
logger.info("Vespa: Readiness probe starting.")
while True:
try:
response = requests.get(f"{VESPA_CONFIG_SERVER_URL}/state/v1/health")
response.raise_for_status()
response_dict = response.json()
if response_dict["status"]["code"] == "up":
ready = True
break
except Exception:
pass
time_elapsed = time.monotonic() - time_start
if time_elapsed > WAIT_LIMIT:
break
logger.info(
f"Vespa: Readiness probe ongoing. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
)
time.sleep(WAIT_INTERVAL)
if not ready:
msg = (
f"Vespa: Readiness probe did not succeed within the timeout "
f"({WAIT_LIMIT} seconds). Exiting..."
)
logger.error(msg)
raise WorkerShutdown(msg)
logger.info("Vespa: Readiness probe succeeded. Continuing...")
return
def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
logger.info("Running as a secondary celery worker.")
# Set up variables for waiting on primary worker
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
r = get_redis_client(tenant_id=None)
time_start = time.monotonic()
logger.info("Waiting for primary worker to be ready...")
while True:
if r.exists(DanswerRedisLocks.PRIMARY_WORKER):
break
time_elapsed = time.monotonic() - time_start
logger.info(
f"Primary worker is not ready yet. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
)
if time_elapsed > WAIT_LIMIT:
msg = (
f"Primary worker was not ready within the timeout. "
f"({WAIT_LIMIT} seconds). Exiting..."
)
logger.error(msg)
raise WorkerShutdown(msg)
time.sleep(WAIT_INTERVAL)
logger.info("Wait for primary worker completed successfully. Continuing...")
return
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
task_logger.info("worker_ready signal received.")
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
if not celery_is_worker_primary(sender):
return
if not sender.primary_worker_lock:
return
logger.info("Releasing primary worker lock.")
lock: RedisLock = sender.primary_worker_lock
try:
if lock.owned():
try:
lock.release()
sender.primary_worker_lock = None
except Exception:
logger.exception("Failed to release primary worker lock")
except Exception:
logger.exception("Failed to check if primary worker lock is owned")
def on_setup_logging(
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
) -> None:
# TODO: could unhardcode format and colorize and accept these as options from
# celery's config
# reformats the root logger
root_logger = logging.getLogger()
root_handler = logging.StreamHandler() # Set up a handler for the root logger
root_formatter = ColoredFormatter(
"%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
datefmt="%m/%d/%Y %I:%M:%S %p",
)
root_handler.setFormatter(root_formatter)
root_logger.addHandler(root_handler) # Apply the handler to the root logger
if logfile:
root_file_handler = logging.FileHandler(logfile)
root_file_formatter = PlainFormatter(
"%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
datefmt="%m/%d/%Y %I:%M:%S %p",
)
root_file_handler.setFormatter(root_file_formatter)
root_logger.addHandler(root_file_handler)
root_logger.setLevel(loglevel)
# reformats celery's task logger
task_formatter = CeleryTaskColoredFormatter(
"%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
datefmt="%m/%d/%Y %I:%M:%S %p",
)
task_handler = logging.StreamHandler() # Set up a handler for the task logger
task_handler.setFormatter(task_formatter)
task_logger.addHandler(task_handler) # Apply the handler to the task logger
if logfile:
task_file_handler = logging.FileHandler(logfile)
task_file_formatter = CeleryTaskPlainFormatter(
"%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
datefmt="%m/%d/%Y %I:%M:%S %p",
)
task_file_handler.setFormatter(task_file_formatter)
task_logger.addHandler(task_file_handler)
task_logger.setLevel(loglevel)
task_logger.propagate = False
# hide celery task received spam
# e.g. "Task check_for_pruning[a1e96171-0ba8-4e00-887b-9fbf7442eab3] received"
strategy.logger.setLevel(logging.WARNING)
# hide celery task succeeded/failed spam
# e.g. "Task check_for_pruning[a1e96171-0ba8-4e00-887b-9fbf7442eab3] succeeded in 0.03137450001668185s: None"
trace.logger.setLevel(logging.WARNING)

View File

@@ -1,172 +0,0 @@
from datetime import timedelta
from typing import Any
from celery import Celery
from celery import signals
from celery.beat import PersistentScheduler # type: ignore
from celery.signals import beat_init
import danswer.background.celery.apps.app_base as app_base
from danswer.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME
from danswer.db.engine import get_all_tenant_ids
from danswer.db.engine import SqlEngine
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import IGNORED_SYNCING_TENANT_LIST
from shared_configs.configs import MULTI_TENANT
logger = setup_logger(__name__)
celery_app = Celery(__name__)
celery_app.config_from_object("danswer.background.celery.configs.beat")
class DynamicTenantScheduler(PersistentScheduler):
def __init__(self, *args: Any, **kwargs: Any) -> None:
logger.info("Initializing DynamicTenantScheduler")
super().__init__(*args, **kwargs)
self._reload_interval = timedelta(minutes=2)
self._last_reload = self.app.now() - self._reload_interval
# Let the parent class handle store initialization
self.setup_schedule()
self._update_tenant_tasks()
logger.info(f"Set reload interval to {self._reload_interval}")
def setup_schedule(self) -> None:
logger.info("Setting up initial schedule")
super().setup_schedule()
logger.info("Initial schedule setup complete")
def tick(self) -> float:
retval = super().tick()
now = self.app.now()
if (
self._last_reload is None
or (now - self._last_reload) > self._reload_interval
):
logger.info("Reload interval reached, initiating tenant task update")
self._update_tenant_tasks()
self._last_reload = now
logger.info("Tenant task update completed, reset reload timer")
return retval
def _update_tenant_tasks(self) -> None:
logger.info("Starting tenant task update process")
try:
logger.info("Fetching all tenant IDs")
tenant_ids = get_all_tenant_ids()
logger.info(f"Found {len(tenant_ids)} tenants")
logger.info("Fetching tasks to schedule")
tasks_to_schedule = fetch_versioned_implementation(
"danswer.background.celery.tasks.beat_schedule", "get_tasks_to_schedule"
)
new_beat_schedule: dict[str, dict[str, Any]] = {}
current_schedule = self.schedule.items()
existing_tenants = set()
for task_name, _ in current_schedule:
if "-" in task_name:
existing_tenants.add(task_name.split("-")[-1])
logger.info(f"Found {len(existing_tenants)} existing tenants in schedule")
for tenant_id in tenant_ids:
if (
IGNORED_SYNCING_TENANT_LIST
and tenant_id in IGNORED_SYNCING_TENANT_LIST
):
logger.info(
f"Skipping tenant {tenant_id} as it is in the ignored syncing list"
)
continue
if tenant_id not in existing_tenants:
logger.info(f"Processing new tenant: {tenant_id}")
for task in tasks_to_schedule():
task_name = f"{task['name']}-{tenant_id}"
logger.debug(f"Creating task configuration for {task_name}")
new_task = {
"task": task["task"],
"schedule": task["schedule"],
"kwargs": {"tenant_id": tenant_id},
}
if options := task.get("options"):
logger.debug(f"Adding options to task {task_name}: {options}")
new_task["options"] = options
new_beat_schedule[task_name] = new_task
if self._should_update_schedule(current_schedule, new_beat_schedule):
logger.info(
"Schedule update required",
extra={
"new_tasks": len(new_beat_schedule),
"current_tasks": len(current_schedule),
},
)
# Create schedule entries
entries = {}
for name, entry in new_beat_schedule.items():
entries[name] = self.Entry(
name=name,
app=self.app,
task=entry["task"],
schedule=entry["schedule"],
options=entry.get("options", {}),
kwargs=entry.get("kwargs", {}),
)
# Update the schedule using the scheduler's methods
self.schedule.clear()
self.schedule.update(entries)
# Ensure changes are persisted
self.sync()
logger.info("Schedule update completed successfully")
else:
logger.info("Schedule is up to date, no changes needed")
except (AttributeError, KeyError):
logger.exception("Failed to process task configuration")
except Exception:
logger.exception("Unexpected error updating tenant tasks")
def _should_update_schedule(
self, current_schedule: dict, new_schedule: dict
) -> bool:
"""Compare schedules to determine if an update is needed."""
logger.debug("Comparing current and new schedules")
current_tasks = set(name for name, _ in current_schedule)
new_tasks = set(new_schedule.keys())
needs_update = current_tasks != new_tasks
logger.debug(f"Schedule update needed: {needs_update}")
return needs_update
@beat_init.connect
def on_beat_init(sender: Any, **kwargs: Any) -> None:
logger.info("beat_init signal received.")
# Celery beat shouldn't touch the db at all. But just setting a low minimum here.
SqlEngine.set_app_name(POSTGRES_CELERY_BEAT_APP_NAME)
SqlEngine.init_engine(pool_size=2, max_overflow=0)
# Startup checks are not needed in multi-tenant case
if MULTI_TENANT:
return
app_base.wait_for_redis(sender, **kwargs)
@signals.setup_logging.connect
def on_setup_logging(
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
) -> None:
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
celery_app.conf.beat_scheduler = DynamicTenantScheduler

View File

@@ -1,97 +0,0 @@
import multiprocessing
from typing import Any
from celery import Celery
from celery import signals
from celery import Task
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_ready
from celery.signals import worker_shutdown
import danswer.background.celery.apps.app_base as app_base
from danswer.configs.constants import POSTGRES_CELERY_WORKER_HEAVY_APP_NAME
from danswer.db.engine import SqlEngine
from danswer.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
celery_app = Celery(__name__)
celery_app.config_from_object("danswer.background.celery.configs.heavy")
@signals.task_prerun.connect
def on_task_prerun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
kwargs: dict | None = None,
**kwds: Any,
) -> None:
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
@signals.task_postrun.connect
def on_task_postrun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
kwargs: dict | None = None,
retval: Any | None = None,
state: str | None = None,
**kwds: Any,
) -> None:
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
@celeryd_init.connect
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
app_base.on_celeryd_init(sender, conf, **kwargs)
@worker_init.connect
def on_worker_init(sender: Any, **kwargs: Any) -> None:
logger.info("worker_init signal received.")
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_HEAVY_APP_NAME)
SqlEngine.init_engine(pool_size=4, max_overflow=12)
# Startup checks are not needed in multi-tenant case
if MULTI_TENANT:
return
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa(sender, **kwargs)
app_base.on_secondary_worker_init(sender, **kwargs)
@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
app_base.on_worker_ready(sender, **kwargs)
@worker_shutdown.connect
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
app_base.on_worker_shutdown(sender, **kwargs)
@signals.setup_logging.connect
def on_setup_logging(
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
) -> None:
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
celery_app.autodiscover_tasks(
[
"danswer.background.celery.tasks.pruning",
"danswer.background.celery.tasks.doc_permission_syncing",
"danswer.background.celery.tasks.external_group_syncing",
]
)

View File

@@ -1,101 +0,0 @@
import multiprocessing
from typing import Any
from celery import Celery
from celery import signals
from celery import Task
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_process_init
from celery.signals import worker_ready
from celery.signals import worker_shutdown
import danswer.background.celery.apps.app_base as app_base
from danswer.configs.constants import POSTGRES_CELERY_WORKER_INDEXING_APP_NAME
from danswer.db.engine import SqlEngine
from danswer.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
celery_app = Celery(__name__)
celery_app.config_from_object("danswer.background.celery.configs.indexing")
@signals.task_prerun.connect
def on_task_prerun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
kwargs: dict | None = None,
**kwds: Any,
) -> None:
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
@signals.task_postrun.connect
def on_task_postrun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
kwargs: dict | None = None,
retval: Any | None = None,
state: str | None = None,
**kwds: Any,
) -> None:
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
@celeryd_init.connect
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
app_base.on_celeryd_init(sender, conf, **kwargs)
@worker_init.connect
def on_worker_init(sender: Any, **kwargs: Any) -> None:
logger.info("worker_init signal received.")
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME)
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=sender.concurrency)
# Startup checks are not needed in multi-tenant case
if MULTI_TENANT:
return
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa(sender, **kwargs)
app_base.on_secondary_worker_init(sender, **kwargs)
@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
app_base.on_worker_ready(sender, **kwargs)
@worker_shutdown.connect
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
app_base.on_worker_shutdown(sender, **kwargs)
@worker_process_init.connect
def init_worker(**kwargs: Any) -> None:
SqlEngine.reset_engine()
@signals.setup_logging.connect
def on_setup_logging(
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
) -> None:
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
celery_app.autodiscover_tasks(
[
"danswer.background.celery.tasks.indexing",
]
)

View File

@@ -1,97 +0,0 @@
import multiprocessing
from typing import Any
from celery import Celery
from celery import signals
from celery import Task
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_ready
from celery.signals import worker_shutdown
import danswer.background.celery.apps.app_base as app_base
from danswer.configs.constants import POSTGRES_CELERY_WORKER_LIGHT_APP_NAME
from danswer.db.engine import SqlEngine
from danswer.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
celery_app = Celery(__name__)
celery_app.config_from_object("danswer.background.celery.configs.light")
@signals.task_prerun.connect
def on_task_prerun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
kwargs: dict | None = None,
**kwds: Any,
) -> None:
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
@signals.task_postrun.connect
def on_task_postrun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
kwargs: dict | None = None,
retval: Any | None = None,
state: str | None = None,
**kwds: Any,
) -> None:
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
@celeryd_init.connect
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
app_base.on_celeryd_init(sender, conf, **kwargs)
@worker_init.connect
def on_worker_init(sender: Any, **kwargs: Any) -> None:
logger.info("worker_init signal received.")
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_LIGHT_APP_NAME)
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8)
# Startup checks are not needed in multi-tenant case
if MULTI_TENANT:
return
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa(sender, **kwargs)
app_base.on_secondary_worker_init(sender, **kwargs)
@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
app_base.on_worker_ready(sender, **kwargs)
@worker_shutdown.connect
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
app_base.on_worker_shutdown(sender, **kwargs)
@signals.setup_logging.connect
def on_setup_logging(
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
) -> None:
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
celery_app.autodiscover_tasks(
[
"danswer.background.celery.tasks.shared",
"danswer.background.celery.tasks.vespa",
"danswer.background.celery.tasks.connector_deletion",
"danswer.background.celery.tasks.doc_permission_syncing",
]
)

Some files were not shown because too many files have changed in this diff Show More