Compare commits

..

1 Commits

1037 changed files with 25955 additions and 88617 deletions

View File

@@ -1,76 +0,0 @@
name: 'Build and Push Docker Image with Retry'
description: 'Attempts to build and push a Docker image, with a retry on failure'
inputs:
context:
description: 'Build context'
required: true
file:
description: 'Dockerfile location'
required: true
platforms:
description: 'Target platforms'
required: true
pull:
description: 'Always attempt to pull a newer version of the image'
required: false
default: 'true'
push:
description: 'Push the image to registry'
required: false
default: 'true'
load:
description: 'Load the image into Docker daemon'
required: false
default: 'true'
tags:
description: 'Image tags'
required: true
cache-from:
description: 'Cache sources'
required: false
cache-to:
description: 'Cache destinations'
required: false
retry-wait-time:
description: 'Time to wait before retry in seconds'
required: false
default: '5'
runs:
using: "composite"
steps:
- name: Build and push Docker image (First Attempt)
id: buildx1
uses: docker/build-push-action@v5
continue-on-error: true
with:
context: ${{ inputs.context }}
file: ${{ inputs.file }}
platforms: ${{ inputs.platforms }}
pull: ${{ inputs.pull }}
push: ${{ inputs.push }}
load: ${{ inputs.load }}
tags: ${{ inputs.tags }}
cache-from: ${{ inputs.cache-from }}
cache-to: ${{ inputs.cache-to }}
- name: Wait to retry
if: steps.buildx1.outcome != 'success'
run: |
echo "First attempt failed. Waiting ${{ inputs.retry-wait-time }} seconds before retry..."
sleep ${{ inputs.retry-wait-time }}
shell: bash
- name: Build and push Docker image (Retry Attempt)
if: steps.buildx1.outcome != 'success'
uses: docker/build-push-action@v5
with:
context: ${{ inputs.context }}
file: ${{ inputs.file }}
platforms: ${{ inputs.platforms }}
pull: ${{ inputs.pull }}
push: ${{ inputs.push }}
load: ${{ inputs.load }}
tags: ${{ inputs.tags }}
cache-from: ${{ inputs.cache-from }}
cache-to: ${{ inputs.cache-to }}

View File

@@ -1,25 +0,0 @@
## Description
[Provide a brief description of the changes in this PR]
## How Has This Been Tested?
[Describe the tests you ran to verify your changes]
## Accepted Risk
[Any know risks or failure modes to point out to reviewers]
## Related Issue(s)
[If applicable, link to the issue(s) this PR addresses]
## Checklist:
- [ ] All of the automated tests pass
- [ ] All PR comments are addressed and marked resolved
- [ ] If there are migrations, they have been rebased to latest main
- [ ] If there are new dependencies, they are added to the requirements
- [ ] If there are new environment variables, they are added to all of the deployment methods
- [ ] If there are new APIs that don't require auth, they are added to PUBLIC_ENDPOINT_SPECS
- [ ] Docker images build and basic functionalities work
- [ ] Author has done a final read through of the PR right before merge

View File

@@ -1,23 +0,0 @@
name: Check Backend Changes
on:
workflow_call:
outputs:
run-tests:
description: "Whether to run tests based on backend changes"
value: ${{ jobs.check-run-needed.outputs.run-tests }}
jobs:
check-run-needed:
runs-on: ubuntu-latest
outputs:
run-tests: ${{ steps.check.outputs.run-tests }}
steps:
- uses: actions/checkout@v4
- id: check
run: |
if git diff --name-only ${{ github.event.before }} ${{ github.sha }} | grep -q '^backend/'; then
echo "run-tests=true" >> $GITHUB_OUTPUT
else
echo "run-tests=false" >> $GITHUB_OUTPUT
fi

View File

@@ -5,38 +5,33 @@ on:
tags:
- '*'
env:
REGISTRY_IMAGE: danswer/danswer-backend
jobs:
build-and-push:
# TODO: make this a matrix build like the web containers
runs-on:
group: amd64-image-builders
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v2
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@v1
- name: Login to Docker Hub
uses: docker/login-action@v3
uses: docker/login-action@v1
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Backend Image Docker Build and Push
uses: docker/build-push-action@v5
uses: docker/build-push-action@v2
with:
context: ./backend
file: ./backend/Dockerfile
platforms: linux/amd64,linux/arm64
push: true
tags: |
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
${{ env.REGISTRY_IMAGE }}:latest
danswer/danswer-backend:${{ github.ref_name }}
danswer/danswer-backend:latest
build-args: |
DANSWER_VERSION=${{ github.ref_name }}
@@ -44,6 +39,6 @@ jobs:
uses: aquasecurity/trivy-action@master
with:
# To run locally: trivy image --severity HIGH,CRITICAL danswer/danswer-backend
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
image-ref: docker.io/danswer/danswer-backend:${{ github.ref_name }}
severity: 'CRITICAL,HIGH'
trivyignores: ./backend/.trivyignore

View File

@@ -7,24 +7,23 @@ on:
jobs:
build-and-push:
runs-on:
group: amd64-image-builders
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v2
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@v1
- name: Login to Docker Hub
uses: docker/login-action@v3
uses: docker/login-action@v1
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Model Server Image Docker Build and Push
uses: docker/build-push-action@v5
uses: docker/build-push-action@v2
with:
context: ./backend
file: ./backend/Dockerfile.model_server

View File

@@ -5,115 +5,38 @@ on:
tags:
- '*'
env:
REGISTRY_IMAGE: danswer/danswer-web-server
jobs:
build:
runs-on:
group: ${{ matrix.platform == 'linux/amd64' && 'amd64-image-builders' || 'arm64-image-builders' }}
strategy:
fail-fast: false
matrix:
platform:
- linux/amd64
- linux/arm64
steps:
- name: Prepare
run: |
platform=${{ matrix.platform }}
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
- name: Checkout
uses: actions/checkout@v4
- name: Docker meta
id: meta
uses: docker/metadata-action@v5
with:
images: ${{ env.REGISTRY_IMAGE }}
tags: |
type=raw,value=${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
type=raw,value=${{ env.REGISTRY_IMAGE }}:latest
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Build and push by digest
id: build
uses: docker/build-push-action@v5
with:
context: ./web
file: ./web/Dockerfile
platforms: ${{ matrix.platform }}
push: true
build-args: |
DANSWER_VERSION=${{ github.ref_name }}
# needed due to weird interactions with the builds for different platforms
no-cache: true
labels: ${{ steps.meta.outputs.labels }}
outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
- name: Export digest
run: |
mkdir -p /tmp/digests
digest="${{ steps.build.outputs.digest }}"
touch "/tmp/digests/${digest#sha256:}"
- name: Upload digest
uses: actions/upload-artifact@v4
with:
name: digests-${{ env.PLATFORM_PAIR }}
path: /tmp/digests/*
if-no-files-found: error
retention-days: 1
merge:
build-and-push:
runs-on: ubuntu-latest
needs:
- build
steps:
- name: Download digests
uses: actions/download-artifact@v4
with:
path: /tmp/digests
pattern: digests-*
merge-multiple: true
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Docker meta
id: meta
uses: docker/metadata-action@v5
with:
images: ${{ env.REGISTRY_IMAGE }}
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Create manifest list and push
working-directory: /tmp/digests
run: |
docker buildx imagetools create $(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \
$(printf '${{ env.REGISTRY_IMAGE }}@sha256:%s ' *)
- name: Inspect image
run: |
docker buildx imagetools inspect ${{ env.REGISTRY_IMAGE }}:${{ steps.meta.outputs.version }}
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
with:
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
severity: 'CRITICAL,HIGH'
steps:
- name: Checkout code
uses: actions/checkout@v2
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v1
- name: Login to Docker Hub
uses: docker/login-action@v1
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Web Image Docker Build and Push
uses: docker/build-push-action@v2
with:
context: ./web
file: ./web/Dockerfile
platforms: linux/amd64,linux/arm64
push: true
tags: |
danswer/danswer-web-server:${{ github.ref_name }}
danswer/danswer-web-server:latest
build-args: |
DANSWER_VERSION=${{ github.ref_name }}
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
with:
image-ref: docker.io/danswer/danswer-web-server:${{ github.ref_name }}
severity: 'CRITICAL,HIGH'

View File

@@ -1,67 +0,0 @@
# This workflow is intentionally disabled while we're still working on it
# It's close to ready, but a race condition needs to be fixed with
# API server and Vespa startup, and it needs to have a way to build/test against
# local containers
name: Helm - Lint and Test Charts
on:
merge_group:
pull_request:
branches: [ main ]
jobs:
lint-test:
runs-on: Amd64
# fetch-depth 0 is required for helm/chart-testing-action
steps:
- name: Checkout code
uses: actions/checkout@v3
with:
fetch-depth: 0
- name: Set up Helm
uses: azure/setup-helm@v4.2.0
with:
version: v3.14.4
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.11'
cache: 'pip'
cache-dependency-path: |
backend/requirements/default.txt
backend/requirements/dev.txt
backend/requirements/model_server.txt
- run: |
python -m pip install --upgrade pip
pip install -r backend/requirements/default.txt
pip install -r backend/requirements/dev.txt
pip install -r backend/requirements/model_server.txt
- name: Set up chart-testing
uses: helm/chart-testing-action@v2.6.1
- name: Run chart-testing (list-changed)
id: list-changed
run: |
changed=$(ct list-changed --target-branch ${{ github.event.repository.default_branch }})
if [[ -n "$changed" ]]; then
echo "changed=true" >> "$GITHUB_OUTPUT"
fi
- name: Run chart-testing (lint)
# if: steps.list-changed.outputs.changed == 'true'
run: ct lint --all --config ct.yaml --target-branch ${{ github.event.repository.default_branch }}
- name: Create kind cluster
# if: steps.list-changed.outputs.changed == 'true'
uses: helm/kind-action@v1.10.0
- name: Run chart-testing (install)
# if: steps.list-changed.outputs.changed == 'true'
run: ct install --all --config ct.yaml
# run: ct install --target-branch ${{ github.event.repository.default_branch }}

View File

@@ -1,17 +1,11 @@
name: Python Checks
on:
merge_group:
pull_request:
branches: [ main ]
jobs:
check-changes:
uses: ./.github/workflows/check-backend-changes.yml
mypy-check:
needs: check-changes
if: needs.check-changes.outputs.run-tests == 'true'
runs-on: ubuntu-latest
steps:
@@ -52,10 +46,3 @@ jobs:
run: |
cd backend
black --check .
skip-tests:
needs: check-changes
if: needs.check-changes.outputs.run-tests == 'false'
runs-on: ubuntu-latest
steps:
- run: echo "No changes in backend, skipping this test."

View File

@@ -1,69 +0,0 @@
name: Connector Tests
on:
pull_request:
branches: [main]
schedule:
# This cron expression runs the job daily at 16:00 UTC (9am PT)
- cron: "0 16 * * *"
env:
# Confluence
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_TEST_SPACE: ${{ secrets.CONFLUENCE_TEST_SPACE }}
CONFLUENCE_IS_CLOUD: ${{ secrets.CONFLUENCE_IS_CLOUD }}
CONFLUENCE_TEST_PAGE_ID: ${{ secrets.CONFLUENCE_TEST_PAGE_ID }}
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
jobs:
check-changes:
uses: ./.github/workflows/check-backend-changes.yml
connectors-check:
needs: check-changes
if: needs.check-changes.outputs.run-tests == 'true'
runs-on: ubuntu-latest
env:
PYTHONPATH: ./backend
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.11"
cache: "pip"
cache-dependency-path: |
backend/requirements/default.txt
backend/requirements/dev.txt
- name: Install Dependencies
run: |
python -m pip install --upgrade pip
pip install -r backend/requirements/default.txt
pip install -r backend/requirements/dev.txt
- name: Run Tests
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
run: py.test -o junit_family=xunit2 -xv --ff backend/tests/daily/connectors
- name: Alert on Failure
if: failure() && github.event_name == 'schedule'
env:
SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK }}
run: |
curl -X POST \
-H 'Content-type: application/json' \
--data '{"text":"Scheduled Connector Tests failed! Check the run at: https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}"}' \
$SLACK_WEBHOOK
skip-tests:
needs: check-changes
if: needs.check-changes.outputs.run-tests == 'false'
runs-on: ubuntu-latest
steps:
- run: echo "No changes in backend, skipping this test."

View File

@@ -1,17 +1,11 @@
name: Python Unit Tests
on:
merge_group:
pull_request:
branches: [ main ]
jobs:
check-changes:
uses: ./.github/workflows/check-backend-changes.yml
backend-check:
needs: check-changes
if: needs.check-changes.outputs.run-tests == 'true'
runs-on: ubuntu-latest
env:
@@ -39,11 +33,3 @@ jobs:
- name: Run Tests
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
run: py.test -o junit_family=xunit2 -xv --ff backend/tests/unit
skip-tests:
needs: check-changes
if: needs.check-changes.outputs.run-tests == 'false'
runs-on: ubuntu-latest
steps:
- run: echo "No changes in backend, skipping this test."

View File

@@ -4,19 +4,18 @@ concurrency:
cancel-in-progress: true
on:
merge_group:
pull_request: null
jobs:
quality-checks:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- uses: actions/setup-python@v5
with:
python-version: "3.11"
- uses: pre-commit/action@v3.0.0
with:
extra_args: ${{ github.event_name == 'pull_request' && format('--from-ref {0} --to-ref {1}', github.event.pull_request.base.sha, github.event.pull_request.head.sha) || '' }}
- uses: actions/checkout@v4
with:
fetch-depth: 0
- uses: actions/setup-python@v5
with:
python-version: '3.11'
- uses: pre-commit/action@v3.0.0
with:
extra_args: --from-ref ${{ github.event.pull_request.base.sha }} --to-ref ${{ github.event.pull_request.head.sha }}

View File

@@ -1,172 +0,0 @@
name: Run Integration Tests
concurrency:
group: Run-Integration-Tests-${{ github.head_ref }}
cancel-in-progress: true
on:
merge_group:
pull_request:
branches: [ main ]
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
jobs:
check-changes:
uses: ./.github/workflows/check-backend-changes.yml
integration-tests:
needs: check-changes
if: needs.check-changes.outputs.run-tests == 'true'
runs-on:
group: 'arm64-image-builders'
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
# NOTE: we don't need to build the Web Docker image since it's not used
# during the IT for now. We have a separate action to verify it builds
# succesfully
- name: Pull Web Docker image
run: |
docker pull danswer/danswer-web-server:latest
docker tag danswer/danswer-web-server:latest danswer/danswer-web-server:it
- name: Build Backend Docker image
uses: ./.github/actions/custom-build-and-push
with:
context: ./backend
file: ./backend/Dockerfile
platforms: linux/arm64
tags: danswer/danswer-backend:it
cache-from: type=registry,ref=danswer/danswer-backend:it
cache-to: |
type=registry,ref=danswer/danswer-backend:it,mode=max
type=inline
- name: Build Model Server Docker image
uses: ./.github/actions/custom-build-and-push
with:
context: ./backend
file: ./backend/Dockerfile.model_server
platforms: linux/arm64
tags: danswer/danswer-model-server:it
cache-from: type=registry,ref=danswer/danswer-model-server:it
cache-to: |
type=registry,ref=danswer/danswer-model-server:it,mode=max
type=inline
- name: Build integration test Docker image
uses: ./.github/actions/custom-build-and-push
with:
context: ./backend
file: ./backend/tests/integration/Dockerfile
platforms: linux/arm64
tags: danswer/integration-test-runner:it
cache-from: type=registry,ref=danswer/integration-test-runner:it
cache-to: |
type=registry,ref=danswer/integration-test-runner:it,mode=max
type=inline
- name: Start Docker containers
run: |
cd deployment/docker_compose
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
AUTH_TYPE=basic \
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
IMAGE_TAG=it \
docker compose -f docker-compose.dev.yml -p danswer-stack up -d
id: start_docker
- name: Wait for service to be ready
run: |
echo "Starting wait-for-service script..."
start_time=$(date +%s)
timeout=300 # 5 minutes in seconds
while true; do
current_time=$(date +%s)
elapsed_time=$((current_time - start_time))
if [ $elapsed_time -ge $timeout ]; then
echo "Timeout reached. Service did not become ready in 5 minutes."
exit 1
fi
# Use curl with error handling to ignore specific exit code 56
response=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:8080/health || echo "curl_error")
if [ "$response" = "200" ]; then
echo "Service is ready!"
break
elif [ "$response" = "curl_error" ]; then
echo "Curl encountered an error, possibly exit code 56. Continuing to retry..."
else
echo "Service not ready yet (HTTP status $response). Retrying in 5 seconds..."
fi
sleep 5
done
echo "Finished waiting for service."
- name: Run integration tests
run: |
echo "Running integration tests..."
docker run --rm --network danswer-stack_default \
-e POSTGRES_HOST=relational_db \
-e POSTGRES_USER=postgres \
-e POSTGRES_PASSWORD=password \
-e POSTGRES_DB=postgres \
-e VESPA_HOST=index \
-e REDIS_HOST=cache \
-e API_SERVER_HOST=api_server \
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
danswer/integration-test-runner:it
continue-on-error: true
id: run_tests
- name: Check test results
run: |
if [ ${{ steps.run_tests.outcome }} == 'failure' ]; then
echo "Integration tests failed. Exiting with error."
exit 1
else
echo "All integration tests passed successfully."
fi
- name: Save Docker logs
if: success() || failure()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack logs > docker-compose.log
mv docker-compose.log ${{ github.workspace }}/docker-compose.log
- name: Upload logs
if: success() || failure()
uses: actions/upload-artifact@v3
with:
name: docker-logs
path: ${{ github.workspace }}/docker-compose.log
- name: Stop Docker containers
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
skip-tests:
needs: check-changes
if: needs.check-changes.outputs.run-tests == 'false'
runs-on: ubuntu-latest
steps:
- run: echo "No changes in backend, skipping this test."

4
.gitignore vendored
View File

@@ -4,6 +4,4 @@
.mypy_cache
.idea
/deployment/data/nginx/app.conf
.vscode/
*.sw?
/backend/tests/regression/answer_quality/search_test_config.yaml
.vscode/launch.json

View File

@@ -1,51 +0,0 @@
# Copy this file to .env in the .vscode folder
# Fill in the <REPLACE THIS> values as needed, it is recommended to set the GEN_AI_API_KEY value to avoid having to set up an LLM in the UI
# Also check out danswer/backend/scripts/restart_containers.sh for a script to restart the containers which Danswer relies on outside of VSCode/Cursor processes
# For local dev, often user Authentication is not needed
AUTH_TYPE=disabled
# Always keep these on for Dev
# Logs all model prompts to stdout
LOG_DANSWER_MODEL_INTERACTIONS=True
# More verbose logging
LOG_LEVEL=debug
# This passes top N results to LLM an additional time for reranking prior to answer generation
# This step is quite heavy on token usage so we disable it for dev generally
DISABLE_LLM_DOC_RELEVANCE=False
# Useful if you want to toggle auth on/off (google_oauth/OIDC specifically)
OAUTH_CLIENT_ID=<REPLACE THIS>
OAUTH_CLIENT_SECRET=<REPLACE THIS>
# Generally not useful for dev, we don't generally want to set up an SMTP server for dev
REQUIRE_EMAIL_VERIFICATION=False
# Set these so if you wipe the DB, you don't end up having to go through the UI every time
GEN_AI_API_KEY=<REPLACE THIS>
# If answer quality isn't important for dev, use gpt-4o-mini since it's cheaper
GEN_AI_MODEL_VERSION=gpt-4o
FAST_GEN_AI_MODEL_VERSION=gpt-4o
# For Danswer Slack Bot, overrides the UI values so no need to set this up via UI every time
# Only needed if using DanswerBot
#DANSWER_BOT_SLACK_APP_TOKEN=<REPLACE THIS>
#DANSWER_BOT_SLACK_BOT_TOKEN=<REPLACE THIS>
# Python stuff
PYTHONPATH=../backend
PYTHONUNBUFFERED=1
# Internet Search
BING_API_KEY=<REPLACE THIS>
# Enable the full set of Danswer Enterprise Edition features
# NOTE: DO NOT ENABLE THIS UNLESS YOU HAVE A PAID ENTERPRISE LICENSE (or if you are using this for local testing/development)
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=False

View File

@@ -1,23 +1,15 @@
/* Copy this file into '.vscode/launch.json' or merge its contents into your existing configurations. */
/*
Copy this file into '.vscode/launch.json' or merge its
contents into your existing configurations.
*/
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"compounds": [
{
"name": "Run All Danswer Services",
"configurations": [
"Web Server",
"Model Server",
"API Server",
"Indexing",
"Background Jobs",
"Slack Bot"
]
}
],
"configurations": [
{
"name": "Web Server",
@@ -25,7 +17,6 @@
"request": "launch",
"cwd": "${workspaceRoot}/web",
"runtimeExecutable": "npm",
"envFile": "${workspaceFolder}/.vscode/.env",
"runtimeArgs": [
"run", "dev"
],
@@ -33,12 +24,10 @@
},
{
"name": "Model Server",
"consoleName": "Model Server",
"type": "debugpy",
"type": "python",
"request": "launch",
"module": "uvicorn",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1"
@@ -52,14 +41,12 @@
},
{
"name": "API Server",
"consoleName": "API Server",
"type": "debugpy",
"type": "python",
"request": "launch",
"module": "uvicorn",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
"LOG_ALL_MODEL_INTERACTIONS": "True",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1"
},
@@ -72,14 +59,12 @@
},
{
"name": "Indexing",
"consoleName": "Indexing",
"type": "debugpy",
"type": "python",
"request": "launch",
"program": "danswer/background/update.py",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"ENABLE_MULTIPASS_INDEXING": "false",
"ENABLE_MINI_CHUNK": "false",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
@@ -88,14 +73,11 @@
// Celery and all async jobs, usually would include indexing as well but this is handled separately above for dev
{
"name": "Background Jobs",
"consoleName": "Background Jobs",
"type": "debugpy",
"type": "python",
"request": "launch",
"program": "scripts/dev_run_background_jobs.py",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
@@ -108,46 +90,16 @@
// DANSWER_BOT_SLACK_APP_TOKEN & DANSWER_BOT_SLACK_BOT_TOKEN need to be set in .env file located in the root of the project
{
"name": "Slack Bot",
"consoleName": "Slack Bot",
"type": "debugpy",
"type": "python",
"request": "launch",
"program": "danswer/danswerbot/slack/listener.py",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"envFile": "${workspaceFolder}/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
}
},
{
"name": "Pytest",
"consoleName": "Pytest",
"type": "debugpy",
"request": "launch",
"module": "pytest",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-v"
// Specify a sepcific module/test to run or provide nothing to run all tests
//"tests/unit/danswer/llm/answering/test_prune_and_merge.py"
]
},
{
"name": "Clear and Restart External Volumes and Containers",
"type": "node",
"request": "launch",
"runtimeExecutable": "bash",
"runtimeArgs": ["${workspaceFolder}/backend/scripts/restart_containers.sh"],
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"stopOnEntry": true
}
]
}
}

View File

@@ -48,24 +48,20 @@ We would love to see you there!
## Get Started 🚀
Danswer being a fully functional app, relies on some external software, specifically:
Danswer being a fully functional app, relies on some external pieces of software, specifically:
- [Postgres](https://www.postgresql.org/) (Relational DB)
- [Vespa](https://vespa.ai/) (Vector DB/Search Engine)
- [Redis](https://redis.io/) (Cache)
- [Nginx](https://nginx.org/) (Not needed for development flows generally)
> **Note:**
> This guide provides instructions to set up the Danswer specific services outside of Docker because it's easier for
> development purposes. However, you can also use the containers and update with local changes by providing the
> `--build` flag.
This guide provides instructions to set up the Danswer specific services outside of Docker because it's easier for
development purposes but also feel free to just use the containers and update with local changes by providing the
`--build` flag.
### Local Set Up
Be sure to use Python version 3.11.
It is recommended to use Python version 3.11
If using a lower version, modifications will have to be made to the code.
If using a higher version, sometimes some libraries will not be available (i.e. we had problems with Tensorflow in the past with higher versions of python).
If using a higher version, the version of Tensorflow we use may not be available for your platform.
#### Installing Requirements
@@ -76,11 +72,6 @@ For convenience here's a command for it:
python -m venv .venv
source .venv/bin/activate
```
> **Note:**
> This virtual environment MUST NOT be set up WITHIN the danswer directory if you plan on using mypy within certain IDEs.
> For simplicity, we recommend setting up the virtual environment outside of the danswer directory.
_For Windows, activate the virtual environment using Command Prompt:_
```bash
.venv\Scripts\activate
@@ -94,22 +85,19 @@ Install the required python dependencies:
```bash
pip install -r danswer/backend/requirements/default.txt
pip install -r danswer/backend/requirements/dev.txt
pip install -r danswer/backend/requirements/ee.txt
pip install -r danswer/backend/requirements/model_server.txt
```
Install [Node.js and npm](https://docs.npmjs.com/downloading-and-installing-node-js-and-npm) for the frontend.
Once the above is done, navigate to `danswer/web` run:
```bash
npm i
```
Install Playwright (headless browser required by the Web Connector)
Install Playwright (required by the Web Connector)
> **Note:**
> If you have just run the pip install, open a new terminal and source the python virtual-env again.
> This will pull the updated PATH to include playwright
> Note: If you have just done the pip install, open a new terminal and source the python virtual-env again.
This will update the path to include playwright
Then install Playwright by running:
```bash
@@ -118,14 +106,11 @@ playwright install
#### Dependent Docker Containers
You will need Docker installed to run these containers.
First navigate to `danswer/deployment/docker_compose`, then start up Postgres/Vespa/Redis with:
First navigate to `danswer/deployment/docker_compose`, then start up Vespa and Postgres with:
```bash
docker compose -f docker-compose.dev.yml -p danswer-stack up -d index relational_db cache
docker compose -f docker-compose.dev.yml -p danswer-stack up -d index relational_db
```
(index refers to Vespa, relational_db refers to Postgres, and cache refers to Redis)
(index refers to Vespa and relational_db refers to Postgres)
#### Running Danswer
To start the frontend, navigate to `danswer/web` and run:
@@ -138,10 +123,11 @@ Navigate to `danswer/backend` and run:
```bash
uvicorn model_server.main:app --reload --port 9000
```
_For Windows (for compatibility with both PowerShell and Command Prompt):_
```bash
powershell -Command "uvicorn model_server.main:app --reload --port 9000"
powershell -Command "
uvicorn model_server.main:app --reload --port 9000
"
```
The first time running Danswer, you will need to run the DB migrations for Postgres.
@@ -164,7 +150,6 @@ To run the backend API server, navigate back to `danswer/backend` and run:
```bash
AUTH_TYPE=disabled uvicorn danswer.main:app --reload --port 8080
```
_For Windows (for compatibility with both PowerShell and Command Prompt):_
```bash
powershell -Command "
@@ -173,28 +158,20 @@ powershell -Command "
"
```
> **Note:**
> If you need finer logging, add the additional environment variable `LOG_LEVEL=DEBUG` to the relevant services.
Note: if you need finer logging, add the additional environment variable `LOG_LEVEL=DEBUG` to the relevant services.
### Formatting and Linting
#### Backend
For the backend, you'll need to setup pre-commit hooks (black / reorder-python-imports).
First, install pre-commit (if you don't have it already) following the instructions
[here](https://pre-commit.com/#installation).
With the virtual environment active, install the pre-commit library with:
```bash
pip install pre-commit
```
Then, from the `danswer/backend` directory, run:
```bash
pre-commit install
```
Additionally, we use `mypy` for static type checking.
Danswer is fully type-annotated, and we want to keep it that way!
Danswer is fully type-annotated, and we would like to keep it that way!
To run the mypy checks manually, run `python -m mypy .` from the `danswer/backend` directory.
@@ -205,7 +182,6 @@ Please double check that prettier passes before creating a pull request.
### Release Process
Danswer loosely follows the SemVer versioning standard.
Major changes are released with a "minor" version bump. Currently we use patch release versions to indicate small feature changes.
Danswer follows the semver versioning standard.
A set of Docker containers will be pushed automatically to DockerHub with every tag.
You can see the containers [here](https://hub.docker.com/search?q=danswer%2F).

View File

@@ -1,31 +0,0 @@
## Some additional notes for Mac Users
The base instructions to set up the development environment are located in [CONTRIBUTING.md](https://github.com/danswer-ai/danswer/blob/main/CONTRIBUTING.md).
### Setting up Python
Ensure [Homebrew](https://brew.sh/) is already set up.
Then install python 3.11.
```bash
brew install python@3.11
```
Add python 3.11 to your path: add the following line to ~/.zshrc
```
export PATH="$(brew --prefix)/opt/python@3.11/libexec/bin:$PATH"
```
> **Note:**
> You will need to open a new terminal for the path change above to take effect.
### Setting up Docker
On macOS, you will need to install [Docker Desktop](https://www.docker.com/products/docker-desktop/) and
ensure it is running before continuing with the docker commands.
### Formatting and Linting
MacOS will likely require you to remove some quarantine attributes on some of the hooks for them to execute properly.
After installing pre-commit, run the following command:
```bash
sudo xattr -r -d com.apple.quarantine ~/.cache/pre-commit
```

View File

@@ -1,10 +1,6 @@
Copyright (c) 2023-present DanswerAI, Inc.
MIT License
Portions of this software are licensed as follows:
* All content that resides under "ee" directories of this repository, if that directory exists, is licensed under the license defined in "backend/ee/LICENSE". Specifically all content under "backend/ee" and "web/src/app/ee" is licensed under the license defined in "backend/ee/LICENSE".
* All third party components incorporated into the Danswer Software are licensed under the original license provided by the owner of the applicable component.
* Content outside of the above mentioned directories or restrictions above is available under the "MIT Expat" license as defined below.
Copyright (c) 2023 Yuhong Sun, Chris Weaver
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal

View File

@@ -11,7 +11,7 @@
<a href="https://docs.danswer.dev/" target="_blank">
<img src="https://img.shields.io/badge/docs-view-blue" alt="Documentation">
</a>
<a href="https://join.slack.com/t/danswer/shared_invite/zt-2lcmqw703-071hBuZBfNEOGUsLa5PXvQ" target="_blank">
<a href="https://join.slack.com/t/danswer/shared_invite/zt-2afut44lv-Rw3kSWu6_OmdAXRpCv80DQ" target="_blank">
<img src="https://img.shields.io/badge/slack-join-blue.svg?logo=slack" alt="Slack">
</a>
<a href="https://discord.gg/TDJ59cGV2X" target="_blank">
@@ -105,25 +105,5 @@ Efficiently pulls the latest changes from:
* Websites
* And more ...
## 📚 Editions
There are two editions of Danswer:
* Danswer Community Edition (CE) is available freely under the MIT Expat license. This version has ALL the core features discussed above. This is the version of Danswer you will get if you follow the Deployment guide above.
* Danswer Enterprise Edition (EE) includes extra features that are primarily useful for larger organizations. Specifically, this includes:
* Single Sign-On (SSO), with support for both SAML and OIDC
* Role-based access control
* Document permission inheritance from connected sources
* Usage analytics and query history accessible to admins
* Whitelabeling
* API key authentication
* Encryption of secrets
* Any many more! Checkout [our website](https://www.danswer.ai/) for the latest.
To try the Danswer Enterprise Edition:
1. Checkout our [Cloud product](https://app.danswer.ai/signup).
2. For self-hosting, contact us at [founders@danswer.ai](mailto:founders@danswer.ai) or book a call with us on our [Cal](https://cal.com/team/danswer/founders).
## 💡 Contributing
Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md) for more details.

2
backend/.gitignore vendored
View File

@@ -5,7 +5,7 @@ site_crawls/
.ipynb_checkpoints/
api_keys.py
*ipynb
.env*
.env
vespa-app.zip
dynamic_config_storage/
celerybeat-schedule*

View File

@@ -1,17 +1,15 @@
FROM python:3.11.7-slim-bookworm
LABEL com.danswer.maintainer="founders@danswer.ai"
LABEL com.danswer.description="This image is the web/frontend container of Danswer which \
contains code for both the Community and Enterprise editions of Danswer. If you do not \
have a contract or agreement with DanswerAI, you are not permitted to use the Enterprise \
Edition features outside of personal development or testing purposes. Please reach out to \
founders@danswer.ai for more information. Please visit https://github.com/danswer-ai/danswer"
LABEL com.danswer.description="This image is for the backend of Danswer. It is MIT Licensed and \
free for all to use. You can find it at https://hub.docker.com/r/danswer/danswer-backend. For \
more details, visit https://github.com/danswer-ai/danswer."
# Default DANSWER_VERSION, typically overriden during builds by GitHub Actions.
ARG DANSWER_VERSION=0.3-dev
ENV DANSWER_VERSION=${DANSWER_VERSION}
RUN echo "DANSWER_VERSION: ${DANSWER_VERSION}"
# Install system dependencies
# cmake needed for psycopg (postgres)
# libpq-dev needed for psycopg (postgres)
@@ -19,32 +17,18 @@ RUN echo "DANSWER_VERSION: ${DANSWER_VERSION}"
# zip for Vespa step futher down
# ca-certificates for HTTPS
RUN apt-get update && \
apt-get install -y \
cmake \
curl \
zip \
ca-certificates \
libgnutls30=3.7.9-2+deb12u3 \
libblkid1=2.38.1-5+deb12u1 \
libmount1=2.38.1-5+deb12u1 \
libsmartcols1=2.38.1-5+deb12u1 \
libuuid1=2.38.1-5+deb12u1 \
libxmlsec1-dev \
pkg-config \
gcc && \
apt-get install -y cmake curl zip ca-certificates libgnutls30=3.7.9-2+deb12u2 \
libblkid1=2.38.1-5+deb12u1 libmount1=2.38.1-5+deb12u1 libsmartcols1=2.38.1-5+deb12u1 \
libuuid1=2.38.1-5+deb12u1 && \
rm -rf /var/lib/apt/lists/* && \
apt-get clean
# Install Python dependencies
# Remove py which is pulled in by retry, py is not needed and is a CVE
COPY ./requirements/default.txt /tmp/requirements.txt
COPY ./requirements/ee.txt /tmp/ee-requirements.txt
RUN pip install --no-cache-dir --upgrade \
-r /tmp/requirements.txt \
-r /tmp/ee-requirements.txt && \
RUN pip install --no-cache-dir --upgrade -r /tmp/requirements.txt && \
pip uninstall -y py && \
playwright install chromium && \
playwright install-deps chromium && \
playwright install chromium && playwright install-deps chromium && \
ln -s /usr/local/bin/supervisord /usr/bin/supervisord
# Cleanup for CVEs and size reduction
@@ -52,52 +36,29 @@ RUN pip install --no-cache-dir --upgrade \
# xserver-common and xvfb included by playwright installation but not needed after
# perl-base is part of the base Python Debian image but not needed for Danswer functionality
# perl-base could only be removed with --allow-remove-essential
RUN apt-get update && \
apt-get remove -y --allow-remove-essential \
perl-base \
xserver-common \
xvfb \
cmake \
libldap-2.5-0 \
libxmlsec1-dev \
pkg-config \
gcc && \
apt-get install -y libxmlsec1-openssl && \
RUN apt-get remove -y --allow-remove-essential perl-base xserver-common xvfb cmake \
libldap-2.5-0 libldap-2.5-0 && \
apt-get autoremove -y && \
rm -rf /var/lib/apt/lists/* && \
rm -f /usr/local/lib/python3.11/site-packages/tornado/test/test.key
rm /usr/local/lib/python3.11/site-packages/tornado/test/test.key
# Pre-downloading models for setups with limited egress
RUN python -c "from tokenizers import Tokenizer; \
Tokenizer.from_pretrained('nomic-ai/nomic-embed-text-v1')"
RUN python -c "from transformers import AutoTokenizer; AutoTokenizer.from_pretrained('intfloat/e5-base-v2')"
# Pre-downloading NLTK for setups with limited egress
RUN python -c "import nltk; \
nltk.download('stopwords', quiet=True); \
nltk.download('wordnet', quiet=True); \
nltk.download('punkt', quiet=True);"
# nltk.download('wordnet', quiet=True); introduce this back if lemmatization is needed
# Set up application files
WORKDIR /app
# Enterprise Version Files
COPY ./ee /app/ee
COPY supervisord.conf /etc/supervisor/conf.d/supervisord.conf
# Set up application files
COPY ./danswer /app/danswer
COPY ./shared_configs /app/shared_configs
COPY ./alembic /app/alembic
COPY ./alembic.ini /app/alembic.ini
COPY supervisord.conf /usr/etc/supervisord.conf
# Escape hatch
COPY ./scripts/force_delete_connector_by_id.py /app/scripts/force_delete_connector_by_id.py
# Put logo in assets
COPY ./assets /app/assets
ENV PYTHONPATH /app
# Default command which does nothing

View File

@@ -18,22 +18,14 @@ RUN apt-get remove -y --allow-remove-essential perl-base && \
apt-get autoremove -y
# Pre-downloading models for setups with limited egress
# Download tokenizers, distilbert for the Danswer model
# Download model weights
# Run Nomic to pull in the custom architecture and have it cached locally
RUN python -c "from transformers import AutoTokenizer; \
AutoTokenizer.from_pretrained('distilbert-base-uncased'); \
AutoTokenizer.from_pretrained('mixedbread-ai/mxbai-rerank-xsmall-v1'); \
RUN python -c "from transformers import AutoModel, AutoTokenizer, TFDistilBertForSequenceClassification; \
from huggingface_hub import snapshot_download; \
snapshot_download(repo_id='danswer/hybrid-intent-token-classifier', revision='v1.0.3'); \
snapshot_download('nomic-ai/nomic-embed-text-v1'); \
snapshot_download('mixedbread-ai/mxbai-rerank-xsmall-v1'); \
from sentence_transformers import SentenceTransformer; \
SentenceTransformer(model_name_or_path='nomic-ai/nomic-embed-text-v1', trust_remote_code=True);"
# In case the user has volumes mounted to /root/.cache/huggingface that they've downloaded while
# running Danswer, don't overwrite it with the built in cache folder
RUN mv /root/.cache/huggingface /root/.cache/temp_huggingface
AutoTokenizer.from_pretrained('danswer/intent-model'); \
AutoTokenizer.from_pretrained('intfloat/e5-base-v2'); \
AutoTokenizer.from_pretrained('mixedbread-ai/mxbai-rerank-xsmall-v1'); \
snapshot_download('danswer/intent-model'); \
snapshot_download('intfloat/e5-base-v2'); \
snapshot_download('mixedbread-ai/mxbai-rerank-xsmall-v1')"
WORKDIR /app

View File

@@ -8,7 +8,6 @@ from sqlalchemy import pool
from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import create_async_engine
from celery.backends.database.session import ResultModelBase # type: ignore
from sqlalchemy.schema import SchemaItem
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
@@ -16,9 +15,7 @@ config = context.config
# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None and config.attributes.get(
"configure_logger", True
):
if config.config_file_name is not None:
fileConfig(config.config_file_name)
# add your model's MetaData object here
@@ -32,20 +29,6 @@ target_metadata = [Base.metadata, ResultModelBase.metadata]
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
def include_object(
object: SchemaItem,
name: str,
type_: str,
reflected: bool,
compare_to: SchemaItem | None,
) -> bool:
if type_ == "table" and name in EXCLUDE_TABLES:
return False
return True
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.
@@ -72,11 +55,7 @@ def run_migrations_offline() -> None:
def do_run_migrations(connection: Connection) -> None:
context.configure(
connection=connection,
target_metadata=target_metadata, # type: ignore
include_object=include_object,
) # type: ignore
context.configure(connection=connection, target_metadata=target_metadata) # type: ignore
with context.begin_transaction():
context.run_migrations()

View File

@@ -1,27 +0,0 @@
"""Add thread specific model selection
Revision ID: 0568ccf46a6b
Revises: e209dc5a8156
Create Date: 2024-06-19 14:25:36.376046
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "0568ccf46a6b"
down_revision = "e209dc5a8156"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.add_column(
"chat_session",
sa.Column("current_alternate_model", sa.String(), nullable=True),
)
def downgrade() -> None:
op.drop_column("chat_session", "current_alternate_model")

View File

@@ -1,32 +0,0 @@
"""add search doc relevance details
Revision ID: 05c07bf07c00
Revises: b896bbd0d5a7
Create Date: 2024-07-10 17:48:15.886653
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "05c07bf07c00"
down_revision = "b896bbd0d5a7"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.add_column(
"search_doc",
sa.Column("is_relevant", sa.Boolean(), nullable=True),
)
op.add_column(
"search_doc",
sa.Column("relevance_explanation", sa.String(), nullable=True),
)
def downgrade() -> None:
op.drop_column("search_doc", "relevance_explanation")
op.drop_column("search_doc", "is_relevant")

View File

@@ -1,26 +0,0 @@
"""add_indexing_start_to_connector
Revision ID: 08a1eda20fe1
Revises: 8a87bd6ec550
Create Date: 2024-07-23 11:12:39.462397
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "08a1eda20fe1"
down_revision = "8a87bd6ec550"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.add_column(
"connector", sa.Column("indexing_start", sa.DateTime(), nullable=True)
)
def downgrade() -> None:
op.drop_column("connector", "indexing_start")

View File

@@ -1,135 +0,0 @@
"""embedding model -> search settings
Revision ID: 1f60f60c3401
Revises: f17bf3b0d9f1
Create Date: 2024-08-25 12:39:51.731632
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from danswer.configs.chat_configs import NUM_POSTPROCESSED_RESULTS
# revision identifiers, used by Alembic.
revision = "1f60f60c3401"
down_revision = "f17bf3b0d9f1"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.drop_constraint(
"index_attempt__embedding_model_fk", "index_attempt", type_="foreignkey"
)
# Rename the table
op.rename_table("embedding_model", "search_settings")
# Add new columns
op.add_column(
"search_settings",
sa.Column(
"multipass_indexing", sa.Boolean(), nullable=False, server_default="true"
),
)
op.add_column(
"search_settings",
sa.Column(
"multilingual_expansion",
postgresql.ARRAY(sa.String()),
nullable=False,
server_default="{}",
),
)
op.add_column(
"search_settings",
sa.Column(
"disable_rerank_for_streaming",
sa.Boolean(),
nullable=False,
server_default="false",
),
)
op.add_column(
"search_settings", sa.Column("rerank_model_name", sa.String(), nullable=True)
)
op.add_column(
"search_settings", sa.Column("rerank_provider_type", sa.String(), nullable=True)
)
op.add_column(
"search_settings", sa.Column("rerank_api_key", sa.String(), nullable=True)
)
op.add_column(
"search_settings",
sa.Column(
"num_rerank",
sa.Integer(),
nullable=False,
server_default=str(NUM_POSTPROCESSED_RESULTS),
),
)
# Add the new column as nullable initially
op.add_column(
"index_attempt", sa.Column("search_settings_id", sa.Integer(), nullable=True)
)
# Populate the new column with data from the existing embedding_model_id
op.execute("UPDATE index_attempt SET search_settings_id = embedding_model_id")
# Create the foreign key constraint
op.create_foreign_key(
"fk_index_attempt_search_settings",
"index_attempt",
"search_settings",
["search_settings_id"],
["id"],
)
# Make the new column non-nullable
op.alter_column("index_attempt", "search_settings_id", nullable=False)
# Drop the old embedding_model_id column
op.drop_column("index_attempt", "embedding_model_id")
def downgrade() -> None:
# Add back the embedding_model_id column
op.add_column(
"index_attempt", sa.Column("embedding_model_id", sa.Integer(), nullable=True)
)
# Populate the old column with data from search_settings_id
op.execute("UPDATE index_attempt SET embedding_model_id = search_settings_id")
# Make the old column non-nullable
op.alter_column("index_attempt", "embedding_model_id", nullable=False)
# Drop the foreign key constraint
op.drop_constraint(
"fk_index_attempt_search_settings", "index_attempt", type_="foreignkey"
)
# Drop the new search_settings_id column
op.drop_column("index_attempt", "search_settings_id")
# Rename the table back
op.rename_table("search_settings", "embedding_model")
# Remove added columns
op.drop_column("embedding_model", "num_rerank")
op.drop_column("embedding_model", "rerank_api_key")
op.drop_column("embedding_model", "rerank_provider_type")
op.drop_column("embedding_model", "rerank_model_name")
op.drop_column("embedding_model", "disable_rerank_for_streaming")
op.drop_column("embedding_model", "multilingual_expansion")
op.drop_column("embedding_model", "multipass_indexing")
op.create_foreign_key(
"index_attempt__embedding_model_fk",
"index_attempt",
"embedding_model",
["embedding_model_id"],
["id"],
)

View File

@@ -1,44 +0,0 @@
"""notifications
Revision ID: 213fd978c6d8
Revises: 5fc1f54cc252
Create Date: 2024-08-10 11:13:36.070790
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "213fd978c6d8"
down_revision = "5fc1f54cc252"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.create_table(
"notification",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column(
"notif_type",
sa.String(),
nullable=False,
),
sa.Column(
"user_id",
sa.UUID(),
nullable=True,
),
sa.Column("dismissed", sa.Boolean(), nullable=False),
sa.Column("last_shown", sa.DateTime(timezone=True), nullable=False),
sa.Column("first_shown", sa.DateTime(timezone=True), nullable=False),
sa.ForeignKeyConstraint(
["user_id"],
["user.id"],
),
sa.PrimaryKeyConstraint("id"),
)
def downgrade() -> None:
op.drop_table("notification")

View File

@@ -1,86 +0,0 @@
"""remove-feedback-foreignkey-constraint
Revision ID: 23957775e5f5
Revises: bc9771dccadf
Create Date: 2024-06-27 16:04:51.480437
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "23957775e5f5"
down_revision = "bc9771dccadf"
branch_labels = None # type: ignore
depends_on = None # type: ignore
def upgrade() -> None:
op.drop_constraint(
"chat_feedback__chat_message_fk", "chat_feedback", type_="foreignkey"
)
op.create_foreign_key(
"chat_feedback__chat_message_fk",
"chat_feedback",
"chat_message",
["chat_message_id"],
["id"],
ondelete="SET NULL",
)
op.alter_column(
"chat_feedback", "chat_message_id", existing_type=sa.Integer(), nullable=True
)
op.drop_constraint(
"document_retrieval_feedback__chat_message_fk",
"document_retrieval_feedback",
type_="foreignkey",
)
op.create_foreign_key(
"document_retrieval_feedback__chat_message_fk",
"document_retrieval_feedback",
"chat_message",
["chat_message_id"],
["id"],
ondelete="SET NULL",
)
op.alter_column(
"document_retrieval_feedback",
"chat_message_id",
existing_type=sa.Integer(),
nullable=True,
)
def downgrade() -> None:
op.alter_column(
"chat_feedback", "chat_message_id", existing_type=sa.Integer(), nullable=False
)
op.drop_constraint(
"chat_feedback__chat_message_fk", "chat_feedback", type_="foreignkey"
)
op.create_foreign_key(
"chat_feedback__chat_message_fk",
"chat_feedback",
"chat_message",
["chat_message_id"],
["id"],
)
op.alter_column(
"document_retrieval_feedback",
"chat_message_id",
existing_type=sa.Integer(),
nullable=False,
)
op.drop_constraint(
"document_retrieval_feedback__chat_message_fk",
"document_retrieval_feedback",
type_="foreignkey",
)
op.create_foreign_key(
"document_retrieval_feedback__chat_message_fk",
"document_retrieval_feedback",
"chat_message",
["chat_message_id"],
["id"],
)

View File

@@ -160,28 +160,12 @@ def downgrade() -> None:
nullable=False,
),
)
# Check if the constraint exists before dropping
conn = op.get_bind()
inspector = sa.inspect(conn)
constraints = inspector.get_foreign_keys("index_attempt")
if any(
constraint["name"] == "fk_index_attempt_credential_id"
for constraint in constraints
):
op.drop_constraint(
"fk_index_attempt_credential_id", "index_attempt", type_="foreignkey"
)
if any(
constraint["name"] == "fk_index_attempt_connector_id"
for constraint in constraints
):
op.drop_constraint(
"fk_index_attempt_connector_id", "index_attempt", type_="foreignkey"
)
op.drop_constraint(
"fk_index_attempt_credential_id", "index_attempt", type_="foreignkey"
)
op.drop_constraint(
"fk_index_attempt_connector_id", "index_attempt", type_="foreignkey"
)
op.drop_column("index_attempt", "credential_id")
op.drop_column("index_attempt", "connector_id")
op.drop_table("connector_credential_pair")

View File

@@ -1,32 +0,0 @@
"""Add Above Below to Persona
Revision ID: 2d2304e27d8c
Revises: 4b08d97e175a
Create Date: 2024-08-21 19:15:15.762948
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "2d2304e27d8c"
down_revision = "4b08d97e175a"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.add_column("persona", sa.Column("chunks_above", sa.Integer(), nullable=True))
op.add_column("persona", sa.Column("chunks_below", sa.Integer(), nullable=True))
op.execute(
"UPDATE persona SET chunks_above = 1, chunks_below = 1 WHERE chunks_above IS NULL AND chunks_below IS NULL"
)
op.alter_column("persona", "chunks_above", nullable=False)
op.alter_column("persona", "chunks_below", nullable=False)
def downgrade() -> None:
op.drop_column("persona", "chunks_below")
op.drop_column("persona", "chunks_above")

View File

@@ -1,70 +0,0 @@
"""Add icon_color and icon_shape to Persona
Revision ID: 325975216eb3
Revises: 91ffac7e65b3
Create Date: 2024-07-24 21:29:31.784562
"""
import random
from alembic import op
import sqlalchemy as sa
from sqlalchemy.sql import table, column, select
# revision identifiers, used by Alembic.
revision = "325975216eb3"
down_revision = "91ffac7e65b3"
branch_labels: None = None
depends_on: None = None
colorOptions = [
"#FF6FBF",
"#6FB1FF",
"#B76FFF",
"#FFB56F",
"#6FFF8D",
"#FF6F6F",
"#6FFFFF",
]
# Function to generate a random shape ensuring at least 3 of the middle 4 squares are filled
def generate_random_shape() -> int:
center_squares = [12, 10, 6, 14, 13, 11, 7, 15]
center_fill = random.choice(center_squares)
remaining_squares = [i for i in range(16) if not (center_fill & (1 << i))]
random.shuffle(remaining_squares)
for i in range(10 - bin(center_fill).count("1")):
center_fill |= 1 << remaining_squares[i]
return center_fill
def upgrade() -> None:
op.add_column("persona", sa.Column("icon_color", sa.String(), nullable=True))
op.add_column("persona", sa.Column("icon_shape", sa.Integer(), nullable=True))
op.add_column("persona", sa.Column("uploaded_image_id", sa.String(), nullable=True))
persona = table(
"persona",
column("id", sa.Integer),
column("icon_color", sa.String),
column("icon_shape", sa.Integer),
)
conn = op.get_bind()
personas = conn.execute(select(persona.c.id))
for persona_id in personas:
random_color = random.choice(colorOptions)
random_shape = generate_random_shape()
conn.execute(
persona.update()
.where(persona.c.id == persona_id[0])
.values(icon_color=random_color, icon_shape=random_shape)
)
def downgrade() -> None:
op.drop_column("persona", "icon_shape")
op.drop_column("persona", "uploaded_image_id")
op.drop_column("persona", "icon_color")

View File

@@ -1,90 +0,0 @@
"""Add curator fields
Revision ID: 351faebd379d
Revises: ee3f4b47fad5
Create Date: 2024-08-15 22:37:08.397052
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "351faebd379d"
down_revision = "ee3f4b47fad5"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
# Add is_curator column to User__UserGroup table
op.add_column(
"user__user_group",
sa.Column("is_curator", sa.Boolean(), nullable=False, server_default="false"),
)
# Use batch mode to modify the enum type
with op.batch_alter_table("user", schema=None) as batch_op:
batch_op.alter_column( # type: ignore[attr-defined]
"role",
type_=sa.Enum(
"BASIC",
"ADMIN",
"CURATOR",
"GLOBAL_CURATOR",
name="userrole",
native_enum=False,
),
existing_type=sa.Enum("BASIC", "ADMIN", name="userrole", native_enum=False),
existing_nullable=False,
)
# Create the association table
op.create_table(
"credential__user_group",
sa.Column("credential_id", sa.Integer(), nullable=False),
sa.Column("user_group_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["credential_id"],
["credential.id"],
),
sa.ForeignKeyConstraint(
["user_group_id"],
["user_group.id"],
),
sa.PrimaryKeyConstraint("credential_id", "user_group_id"),
)
op.add_column(
"credential",
sa.Column(
"curator_public", sa.Boolean(), nullable=False, server_default="false"
),
)
def downgrade() -> None:
# Update existing records to ensure they fit within the BASIC/ADMIN roles
op.execute(
"UPDATE \"user\" SET role = 'ADMIN' WHERE role IN ('CURATOR', 'GLOBAL_CURATOR')"
)
# Remove is_curator column from User__UserGroup table
op.drop_column("user__user_group", "is_curator")
with op.batch_alter_table("user", schema=None) as batch_op:
batch_op.alter_column( # type: ignore[attr-defined]
"role",
type_=sa.Enum(
"BASIC", "ADMIN", name="userrole", native_enum=False, length=20
),
existing_type=sa.Enum(
"BASIC",
"ADMIN",
"CURATOR",
"GLOBAL_CURATOR",
name="userrole",
native_enum=False,
),
existing_nullable=False,
)
# Drop the association table
op.drop_table("credential__user_group")
op.drop_column("credential", "curator_public")

View File

@@ -11,8 +11,8 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "3879338f8ba1"
down_revision = "f1c6478c3fd8"
branch_labels: None = None
depends_on: None = None
branch_labels = None
depends_on = None
def upgrade() -> None:

View File

@@ -1,35 +0,0 @@
"""add alternate assistant to chat message
Revision ID: 3a7802814195
Revises: 23957775e5f5
Create Date: 2024-06-05 11:18:49.966333
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "3a7802814195"
down_revision = "23957775e5f5"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.add_column(
"chat_message", sa.Column("alternate_assistant_id", sa.Integer(), nullable=True)
)
op.create_foreign_key(
"fk_chat_message_persona",
"chat_message",
"persona",
["alternate_assistant_id"],
["id"],
)
def downgrade() -> None:
op.drop_constraint("fk_chat_message_persona", "chat_message", type_="foreignkey")
op.drop_column("chat_message", "alternate_assistant_id")

View File

@@ -1,42 +0,0 @@
"""Rename index_origin to index_recursively
Revision ID: 1d6ad76d1f37
Revises: e1392f05e840
Create Date: 2024-08-01 12:38:54.466081
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "1d6ad76d1f37"
down_revision = "e1392f05e840"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.execute(
"""
UPDATE connector
SET connector_specific_config = jsonb_set(
connector_specific_config,
'{index_recursively}',
'true'::jsonb
) - 'index_origin'
WHERE connector_specific_config ? 'index_origin'
"""
)
def downgrade() -> None:
op.execute(
"""
UPDATE connector
SET connector_specific_config = jsonb_set(
connector_specific_config,
'{index_origin}',
connector_specific_config->'index_recursively'
) - 'index_recursively'
WHERE connector_specific_config ? 'index_recursively'
"""
)

View File

@@ -1,65 +0,0 @@
"""add cloud embedding model and update embedding_model
Revision ID: 44f856ae2a4a
Revises: d716b0791ddd
Create Date: 2024-06-28 20:01:05.927647
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "44f856ae2a4a"
down_revision = "d716b0791ddd"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
# Create embedding_provider table
op.create_table(
"embedding_provider",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("name", sa.String(), nullable=False),
sa.Column("api_key", sa.LargeBinary(), nullable=True),
sa.Column("default_model_id", sa.Integer(), nullable=True),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("name"),
)
# Add cloud_provider_id to embedding_model table
op.add_column(
"embedding_model", sa.Column("cloud_provider_id", sa.Integer(), nullable=True)
)
# Add foreign key constraints
op.create_foreign_key(
"fk_embedding_model_cloud_provider",
"embedding_model",
"embedding_provider",
["cloud_provider_id"],
["id"],
)
op.create_foreign_key(
"fk_embedding_provider_default_model",
"embedding_provider",
"embedding_model",
["default_model_id"],
["id"],
)
def downgrade() -> None:
# Remove foreign key constraints
op.drop_constraint(
"fk_embedding_model_cloud_provider", "embedding_model", type_="foreignkey"
)
op.drop_constraint(
"fk_embedding_provider_default_model", "embedding_provider", type_="foreignkey"
)
# Remove cloud_provider_id column
op.drop_column("embedding_model", "cloud_provider_id")
# Drop embedding_provider table
op.drop_table("embedding_provider")

View File

@@ -1,23 +0,0 @@
"""added is_internet to DBDoc
Revision ID: 4505fd7302e1
Revises: c18cdf4b497e
Create Date: 2024-06-18 20:46:09.095034
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "4505fd7302e1"
down_revision = "c18cdf4b497e"
def upgrade() -> None:
op.add_column("search_doc", sa.Column("is_internet", sa.Boolean(), nullable=True))
op.add_column("tool", sa.Column("display_name", sa.String(), nullable=True))
def downgrade() -> None:
op.drop_column("tool", "display_name")
op.drop_column("search_doc", "is_internet")

View File

@@ -1,49 +0,0 @@
"""Add display_model_names to llm_provider
Revision ID: 473a1a7ca408
Revises: 325975216eb3
Create Date: 2024-07-25 14:31:02.002917
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "473a1a7ca408"
down_revision = "325975216eb3"
branch_labels: None = None
depends_on: None = None
default_models_by_provider = {
"openai": ["gpt-4", "gpt-4o", "gpt-4o-mini"],
"bedrock": [
"meta.llama3-1-70b-instruct-v1:0",
"meta.llama3-1-8b-instruct-v1:0",
"anthropic.claude-3-opus-20240229-v1:0",
"mistral.mistral-large-2402-v1:0",
"anthropic.claude-3-5-sonnet-20240620-v1:0",
],
"anthropic": ["claude-3-opus-20240229", "claude-3-5-sonnet-20240620"],
}
def upgrade() -> None:
op.add_column(
"llm_provider",
sa.Column("display_model_names", postgresql.ARRAY(sa.String()), nullable=True),
)
connection = op.get_bind()
for provider, models in default_models_by_provider.items():
connection.execute(
sa.text(
"UPDATE llm_provider SET display_model_names = :models WHERE provider = :provider"
),
{"models": models, "provider": provider},
)
def downgrade() -> None:
op.drop_column("llm_provider", "display_model_names")

View File

@@ -1,61 +0,0 @@
"""Add support for custom tools
Revision ID: 48d14957fe80
Revises: b85f02ec1308
Create Date: 2024-06-09 14:58:19.946509
"""
from alembic import op
import fastapi_users_db_sqlalchemy
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "48d14957fe80"
down_revision = "b85f02ec1308"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.add_column(
"tool",
sa.Column(
"openapi_schema",
postgresql.JSONB(astext_type=sa.Text()),
nullable=True,
),
)
op.add_column(
"tool",
sa.Column(
"user_id",
fastapi_users_db_sqlalchemy.generics.GUID(),
nullable=True,
),
)
op.create_foreign_key("tool_user_fk", "tool", "user", ["user_id"], ["id"])
op.create_table(
"tool_call",
sa.Column("id", sa.Integer(), primary_key=True),
sa.Column("tool_id", sa.Integer(), nullable=False),
sa.Column("tool_name", sa.String(), nullable=False),
sa.Column(
"tool_arguments", postgresql.JSONB(astext_type=sa.Text()), nullable=False
),
sa.Column(
"tool_result", postgresql.JSONB(astext_type=sa.Text()), nullable=False
),
sa.Column(
"message_id", sa.Integer(), sa.ForeignKey("chat_message.id"), nullable=False
),
)
def downgrade() -> None:
op.drop_table("tool_call")
op.drop_constraint("tool_user_fk", "tool", type_="foreignkey")
op.drop_column("tool", "user_id")
op.drop_column("tool", "openapi_schema")

View File

@@ -1,80 +0,0 @@
"""Moved status to connector credential pair
Revision ID: 4a951134c801
Revises: 7477a5f5d728
Create Date: 2024-08-10 19:20:34.527559
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "4a951134c801"
down_revision = "7477a5f5d728"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.add_column(
"connector_credential_pair",
sa.Column(
"status",
sa.Enum(
"ACTIVE",
"PAUSED",
"DELETING",
name="connectorcredentialpairstatus",
native_enum=False,
),
nullable=True,
),
)
# Update status of connector_credential_pair based on connector's disabled status
op.execute(
"""
UPDATE connector_credential_pair
SET status = CASE
WHEN (
SELECT disabled
FROM connector
WHERE connector.id = connector_credential_pair.connector_id
) = FALSE THEN 'ACTIVE'
ELSE 'PAUSED'
END
"""
)
# Make the status column not nullable after setting values
op.alter_column("connector_credential_pair", "status", nullable=False)
op.drop_column("connector", "disabled")
def downgrade() -> None:
op.add_column(
"connector",
sa.Column("disabled", sa.BOOLEAN(), autoincrement=False, nullable=True),
)
# Update disabled status of connector based on connector_credential_pair's status
op.execute(
"""
UPDATE connector
SET disabled = CASE
WHEN EXISTS (
SELECT 1
FROM connector_credential_pair
WHERE connector_credential_pair.connector_id = connector.id
AND connector_credential_pair.status = 'ACTIVE'
) THEN FALSE
ELSE TRUE
END
"""
)
# Make the disabled column not nullable after setting values
op.alter_column("connector", "disabled", nullable=False)
op.drop_column("connector_credential_pair", "status")

View File

@@ -1,34 +0,0 @@
"""change default prune_freq
Revision ID: 4b08d97e175a
Revises: d9ec13955951
Create Date: 2024-08-20 15:28:52.993827
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "4b08d97e175a"
down_revision = "d9ec13955951"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.execute(
"""
UPDATE connector
SET prune_freq = 2592000
WHERE prune_freq = 86400
"""
)
def downgrade() -> None:
op.execute(
"""
UPDATE connector
SET prune_freq = 86400
WHERE prune_freq = 2592000
"""
)

View File

@@ -1,72 +0,0 @@
"""Add type to credentials
Revision ID: 4ea2c93919c1
Revises: 473a1a7ca408
Create Date: 2024-07-18 13:07:13.655895
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "4ea2c93919c1"
down_revision = "473a1a7ca408"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
# Add the new 'source' column to the 'credential' table
op.add_column(
"credential",
sa.Column(
"source",
sa.String(length=100), # Use String instead of Enum
nullable=True, # Initially allow NULL values
),
)
op.add_column(
"credential",
sa.Column(
"name",
sa.String(),
nullable=True,
),
)
# Create a temporary table that maps each credential to a single connector source.
# This is needed because a credential can be associated with multiple connectors,
# but we want to assign a single source to each credential.
# We use DISTINCT ON to ensure we only get one row per credential_id.
op.execute(
"""
CREATE TEMPORARY TABLE temp_connector_credential AS
SELECT DISTINCT ON (cc.credential_id)
cc.credential_id,
c.source AS connector_source
FROM connector_credential_pair cc
JOIN connector c ON cc.connector_id = c.id
"""
)
# Update the 'source' column in the 'credential' table
op.execute(
"""
UPDATE credential cred
SET source = COALESCE(
(SELECT connector_source
FROM temp_connector_credential temp
WHERE cred.id = temp.credential_id),
'NOT_APPLICABLE'
)
"""
)
# If no exception was raised, alter the column
op.alter_column("credential", "source", nullable=True) # TODO modify
# # ### end Alembic commands ###
def downgrade() -> None:
op.drop_column("credential", "source")
op.drop_column("credential", "name")

View File

@@ -1,66 +0,0 @@
"""Add last synced and last modified to document table
Revision ID: 52a219fb5233
Revises: f17bf3b0d9f1
Create Date: 2024-08-28 17:40:46.077470
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.sql import func
# revision identifiers, used by Alembic.
revision = "52a219fb5233"
down_revision = "f7e58d357687"
branch_labels = None
depends_on = None
def upgrade() -> None:
# last modified represents the last time anything needing syncing to vespa changed
# including row metadata and the document itself. This obviously does not include
# the last_synced column.
op.add_column(
"document",
sa.Column(
"last_modified",
sa.DateTime(timezone=True),
nullable=False,
server_default=func.now(),
),
)
# last synced represents the last time this document was synced to Vespa
op.add_column(
"document",
sa.Column("last_synced", sa.DateTime(timezone=True), nullable=True),
)
# Set last_synced to the same value as last_modified for existing rows
op.execute(
"""
UPDATE document
SET last_synced = last_modified
"""
)
op.create_index(
op.f("ix_document_last_modified"),
"document",
["last_modified"],
unique=False,
)
op.create_index(
op.f("ix_document_last_synced"),
"document",
["last_synced"],
unique=False,
)
def downgrade() -> None:
op.drop_index(op.f("ix_document_last_synced"), table_name="document")
op.drop_index(op.f("ix_document_last_modified"), table_name="document")
op.drop_column("document", "last_synced")
op.drop_column("document", "last_modified")

View File

@@ -1,25 +0,0 @@
"""hybrid-enum
Revision ID: 5fc1f54cc252
Revises: 1d6ad76d1f37
Create Date: 2024-08-06 15:35:40.278485
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "5fc1f54cc252"
down_revision = "1d6ad76d1f37"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.drop_column("persona", "search_type")
def downgrade() -> None:
op.add_column("persona", sa.Column("search_type", sa.String(), nullable=True))
op.execute("UPDATE persona SET search_type = 'SEMANTIC'")
op.alter_column("persona", "search_type", nullable=False)

View File

@@ -1,68 +0,0 @@
"""More Descriptive Filestore
Revision ID: 70f00c45c0f2
Revises: 3879338f8ba1
Create Date: 2024-05-17 17:51:41.926893
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "70f00c45c0f2"
down_revision = "3879338f8ba1"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.add_column("file_store", sa.Column("display_name", sa.String(), nullable=True))
op.add_column(
"file_store",
sa.Column(
"file_origin",
sa.String(),
nullable=False,
server_default="connector", # Default to connector
),
)
op.add_column(
"file_store",
sa.Column(
"file_type", sa.String(), nullable=False, server_default="text/plain"
),
)
op.add_column(
"file_store",
sa.Column(
"file_metadata",
postgresql.JSONB(astext_type=sa.Text()),
nullable=True,
),
)
op.execute(
"""
UPDATE file_store
SET file_origin = CASE
WHEN file_name LIKE 'chat__%' THEN 'chat_upload'
ELSE 'connector'
END,
file_name = CASE
WHEN file_name LIKE 'chat__%' THEN SUBSTR(file_name, 7)
ELSE file_name
END,
file_type = CASE
WHEN file_name LIKE 'chat__%' THEN 'image/png'
ELSE 'text/plain'
END
"""
)
def downgrade() -> None:
op.drop_column("file_store", "file_metadata")
op.drop_column("file_store", "file_type")
op.drop_column("file_store", "file_origin")
op.drop_column("file_store", "display_name")

View File

@@ -1,24 +0,0 @@
"""Added model defaults for users
Revision ID: 7477a5f5d728
Revises: 213fd978c6d8
Create Date: 2024-08-04 19:00:04.512634
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "7477a5f5d728"
down_revision = "213fd978c6d8"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.add_column("user", sa.Column("default_model", sa.Text(), nullable=True))
def downgrade() -> None:
op.drop_column("user", "default_model")

View File

@@ -28,9 +28,5 @@ def upgrade() -> None:
def downgrade() -> None:
op.create_unique_constraint(
"connector_credential_pair__name__key", "connector_credential_pair", ["name"]
)
op.alter_column(
"connector_credential_pair", "name", existing_type=sa.String(), nullable=True
)
# This wasn't really required by the code either, no good reason to make it unique again
pass

View File

@@ -10,7 +10,7 @@ import sqlalchemy as sa
from danswer.db.models import IndexModelStatus
from danswer.search.enums import RecencyBiasSetting
from danswer.search.enums import SearchType
from danswer.search.models import SearchType
# revision identifiers, used by Alembic.
revision = "776b3bbe9092"

View File

@@ -1,41 +0,0 @@
"""add_llm_group_permissions_control
Revision ID: 795b20b85b4b
Revises: 05c07bf07c00
Create Date: 2024-07-19 11:54:35.701558
"""
from alembic import op
import sqlalchemy as sa
revision = "795b20b85b4b"
down_revision = "05c07bf07c00"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.create_table(
"llm_provider__user_group",
sa.Column("llm_provider_id", sa.Integer(), nullable=False),
sa.Column("user_group_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["llm_provider_id"],
["llm_provider.id"],
),
sa.ForeignKeyConstraint(
["user_group_id"],
["user_group.id"],
),
sa.PrimaryKeyConstraint("llm_provider_id", "user_group_id"),
)
op.add_column(
"llm_provider",
sa.Column("is_public", sa.Boolean(), nullable=False, server_default="true"),
)
def downgrade() -> None:
op.drop_table("llm_provider__user_group")
op.drop_column("llm_provider", "is_public")

View File

@@ -1,35 +0,0 @@
"""added slack_auto_filter
Revision ID: 7aea705850d5
Revises: 4505fd7302e1
Create Date: 2024-07-10 11:01:23.581015
"""
from alembic import op
import sqlalchemy as sa
revision = "7aea705850d5"
down_revision = "4505fd7302e1"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.add_column(
"slack_bot_config",
sa.Column("enable_auto_filters", sa.Boolean(), nullable=True),
)
op.execute(
"UPDATE slack_bot_config SET enable_auto_filters = FALSE WHERE enable_auto_filters IS NULL"
)
op.alter_column(
"slack_bot_config",
"enable_auto_filters",
existing_type=sa.Boolean(),
nullable=False,
server_default=sa.false(),
)
def downgrade() -> None:
op.drop_column("slack_bot_config", "enable_auto_filters")

View File

@@ -1,107 +0,0 @@
"""associate index attempts with ccpair
Revision ID: 8a87bd6ec550
Revises: 4ea2c93919c1
Create Date: 2024-07-22 15:15:52.558451
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "8a87bd6ec550"
down_revision = "4ea2c93919c1"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
# Add the new connector_credential_pair_id column
op.add_column(
"index_attempt",
sa.Column("connector_credential_pair_id", sa.Integer(), nullable=True),
)
# Create a foreign key constraint to the connector_credential_pair table
op.create_foreign_key(
"fk_index_attempt_connector_credential_pair_id",
"index_attempt",
"connector_credential_pair",
["connector_credential_pair_id"],
["id"],
)
# Populate the new connector_credential_pair_id column using existing connector_id and credential_id
op.execute(
"""
UPDATE index_attempt ia
SET connector_credential_pair_id = (
SELECT id FROM connector_credential_pair ccp
WHERE
(ia.connector_id IS NULL OR ccp.connector_id = ia.connector_id)
AND (ia.credential_id IS NULL OR ccp.credential_id = ia.credential_id)
LIMIT 1
)
WHERE ia.connector_id IS NOT NULL OR ia.credential_id IS NOT NULL
"""
)
# For good measure
op.execute(
"""
DELETE FROM index_attempt
WHERE connector_credential_pair_id IS NULL
"""
)
# Make the new connector_credential_pair_id column non-nullable
op.alter_column("index_attempt", "connector_credential_pair_id", nullable=False)
# Drop the old connector_id and credential_id columns
op.drop_column("index_attempt", "connector_id")
op.drop_column("index_attempt", "credential_id")
# Update the index to use connector_credential_pair_id
op.create_index(
"ix_index_attempt_latest_for_connector_credential_pair",
"index_attempt",
["connector_credential_pair_id", "time_created"],
)
def downgrade() -> None:
# Add back the old connector_id and credential_id columns
op.add_column(
"index_attempt", sa.Column("connector_id", sa.Integer(), nullable=True)
)
op.add_column(
"index_attempt", sa.Column("credential_id", sa.Integer(), nullable=True)
)
# Populate the old connector_id and credential_id columns using the connector_credential_pair_id
op.execute(
"""
UPDATE index_attempt ia
SET connector_id = ccp.connector_id, credential_id = ccp.credential_id
FROM connector_credential_pair ccp
WHERE ia.connector_credential_pair_id = ccp.id
"""
)
# Make the old connector_id and credential_id columns non-nullable
op.alter_column("index_attempt", "connector_id", nullable=False)
op.alter_column("index_attempt", "credential_id", nullable=False)
# Drop the new connector_credential_pair_id column
op.drop_constraint(
"fk_index_attempt_connector_credential_pair_id",
"index_attempt",
type_="foreignkey",
)
op.drop_column("index_attempt", "connector_credential_pair_id")
op.create_index(
"ix_index_attempt_latest_for_connector_credential_pair",
"index_attempt",
["connector_id", "credential_id", "time_created"],
)

View File

@@ -1,26 +0,0 @@
"""add expiry time
Revision ID: 91ffac7e65b3
Revises: bc9771dccadf
Create Date: 2024-06-24 09:39:56.462242
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "91ffac7e65b3"
down_revision = "795b20b85b4b"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.add_column(
"user", sa.Column("oidc_expiry", sa.DateTime(timezone=True), nullable=True)
)
def downgrade() -> None:
op.drop_column("user", "oidc_expiry")

View File

@@ -1,158 +0,0 @@
"""migration confluence to be explicit
Revision ID: a3795dce87be
Revises: 1f60f60c3401
Create Date: 2024-09-01 13:52:12.006740
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from sqlalchemy.sql import table, column
revision = "a3795dce87be"
down_revision = "1f60f60c3401"
branch_labels: None = None
depends_on: None = None
def extract_confluence_keys_from_url(wiki_url: str) -> tuple[str, str, str, bool]:
from urllib.parse import urlparse
def _extract_confluence_keys_from_cloud_url(wiki_url: str) -> tuple[str, str, str]:
parsed_url = urlparse(wiki_url)
wiki_base = f"{parsed_url.scheme}://{parsed_url.netloc}{parsed_url.path.split('/spaces')[0]}"
path_parts = parsed_url.path.split("/")
space = path_parts[3]
page_id = path_parts[5] if len(path_parts) > 5 else ""
return wiki_base, space, page_id
def _extract_confluence_keys_from_datacenter_url(
wiki_url: str,
) -> tuple[str, str, str]:
DISPLAY = "/display/"
PAGE = "/pages/"
parsed_url = urlparse(wiki_url)
wiki_base = f"{parsed_url.scheme}://{parsed_url.netloc}{parsed_url.path.split(DISPLAY)[0]}"
space = DISPLAY.join(parsed_url.path.split(DISPLAY)[1:]).split("/")[0]
page_id = ""
if (content := parsed_url.path.split(PAGE)) and len(content) > 1:
page_id = content[1]
return wiki_base, space, page_id
is_confluence_cloud = (
".atlassian.net/wiki/spaces/" in wiki_url
or ".jira.com/wiki/spaces/" in wiki_url
)
if is_confluence_cloud:
wiki_base, space, page_id = _extract_confluence_keys_from_cloud_url(wiki_url)
else:
wiki_base, space, page_id = _extract_confluence_keys_from_datacenter_url(
wiki_url
)
return wiki_base, space, page_id, is_confluence_cloud
def reconstruct_confluence_url(
wiki_base: str, space: str, page_id: str, is_cloud: bool
) -> str:
if is_cloud:
url = f"{wiki_base}/spaces/{space}"
if page_id:
url += f"/pages/{page_id}"
else:
url = f"{wiki_base}/display/{space}"
if page_id:
url += f"/pages/{page_id}"
return url
def upgrade() -> None:
connector = table(
"connector",
column("id", sa.Integer),
column("source", sa.String()),
column("input_type", sa.String()),
column("connector_specific_config", postgresql.JSONB),
)
# Fetch all Confluence connectors
connection = op.get_bind()
confluence_connectors = connection.execute(
sa.select(connector).where(
sa.and_(
connector.c.source == "CONFLUENCE", connector.c.input_type == "POLL"
)
)
).fetchall()
for row in confluence_connectors:
config = row.connector_specific_config
wiki_page_url = config["wiki_page_url"]
wiki_base, space, page_id, is_cloud = extract_confluence_keys_from_url(
wiki_page_url
)
new_config = {
"wiki_base": wiki_base,
"space": space,
"page_id": page_id,
"is_cloud": is_cloud,
}
for key, value in config.items():
if key not in ["wiki_page_url"]:
new_config[key] = value
op.execute(
connector.update()
.where(connector.c.id == row.id)
.values(connector_specific_config=new_config)
)
def downgrade() -> None:
connector = table(
"connector",
column("id", sa.Integer),
column("source", sa.String()),
column("input_type", sa.String()),
column("connector_specific_config", postgresql.JSONB),
)
confluence_connectors = (
op.get_bind()
.execute(
sa.select(connector).where(
connector.c.source == "CONFLUENCE", connector.c.input_type == "POLL"
)
)
.fetchall()
)
for row in confluence_connectors:
config = row.connector_specific_config
if all(key in config for key in ["wiki_base", "space", "is_cloud"]):
wiki_page_url = reconstruct_confluence_url(
config["wiki_base"],
config["space"],
config.get("page_id", ""),
config["is_cloud"],
)
new_config = {"wiki_page_url": wiki_page_url}
new_config.update(
{
k: v
for k, v in config.items()
if k not in ["wiki_base", "space", "page_id", "is_cloud"]
}
)
op.execute(
connector.update()
.where(connector.c.id == row.id)
.values(connector_specific_config=new_config)
)

View File

@@ -1,27 +0,0 @@
"""Add chosen_assistants to User table
Revision ID: a3bfd0d64902
Revises: ec85f2b3c544
Create Date: 2024-05-26 17:22:24.834741
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "a3bfd0d64902"
down_revision = "ec85f2b3c544"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.add_column(
"user",
sa.Column("chosen_assistants", postgresql.ARRAY(sa.Integer()), nullable=True),
)
def downgrade() -> None:
op.drop_column("user", "chosen_assistants")

View File

@@ -16,6 +16,7 @@ depends_on: None = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column(
"connector_credential_pair",
"last_attempt_status",
@@ -28,9 +29,11 @@ def upgrade() -> None:
),
nullable=True,
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column(
"connector_credential_pair",
"last_attempt_status",
@@ -43,3 +46,4 @@ def downgrade() -> None:
),
nullable=False,
)
# ### end Alembic commands ###

View File

@@ -1,28 +0,0 @@
"""fix-file-type-migration
Revision ID: b85f02ec1308
Revises: a3bfd0d64902
Create Date: 2024-05-31 18:09:26.658164
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "b85f02ec1308"
down_revision = "a3bfd0d64902"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.execute(
"""
UPDATE file_store
SET file_origin = UPPER(file_origin)
"""
)
def downgrade() -> None:
# Let's not break anything on purpose :)
pass

View File

@@ -1,23 +0,0 @@
"""backfill is_internet data to False
Revision ID: b896bbd0d5a7
Revises: 44f856ae2a4a
Create Date: 2024-07-16 15:21:05.718571
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "b896bbd0d5a7"
down_revision = "44f856ae2a4a"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.execute("UPDATE search_doc SET is_internet = FALSE WHERE is_internet IS NULL")
def downgrade() -> None:
pass

View File

@@ -1,26 +0,0 @@
"""add support for litellm proxy in reranking
Revision ID: ba98eba0f66a
Revises: bceb1e139447
Create Date: 2024-09-06 10:36:04.507332
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "ba98eba0f66a"
down_revision = "bceb1e139447"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.add_column(
"search_settings", sa.Column("rerank_api_url", sa.String(), nullable=True)
)
def downgrade() -> None:
op.drop_column("search_settings", "rerank_api_url")

View File

@@ -1,51 +0,0 @@
"""create usage reports table
Revision ID: bc9771dccadf
Revises: 0568ccf46a6b
Create Date: 2024-06-18 10:04:26.800282
"""
from alembic import op
import sqlalchemy as sa
import fastapi_users_db_sqlalchemy
# revision identifiers, used by Alembic.
revision = "bc9771dccadf"
down_revision = "0568ccf46a6b"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.create_table(
"usage_reports",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("report_name", sa.String(), nullable=False),
sa.Column(
"requestor_user_id",
fastapi_users_db_sqlalchemy.generics.GUID(),
nullable=True,
),
sa.Column(
"time_created",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column("period_from", sa.DateTime(timezone=True), nullable=True),
sa.Column("period_to", sa.DateTime(timezone=True), nullable=True),
sa.ForeignKeyConstraint(
["report_name"],
["file_store.file_name"],
),
sa.ForeignKeyConstraint(
["requestor_user_id"],
["user.id"],
),
sa.PrimaryKeyConstraint("id"),
)
def downgrade() -> None:
op.drop_table("usage_reports")

View File

@@ -1,26 +0,0 @@
"""Add base_url to CloudEmbeddingProvider
Revision ID: bceb1e139447
Revises: a3795dce87be
Create Date: 2024-08-28 17:00:52.554580
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "bceb1e139447"
down_revision = "a3795dce87be"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.add_column(
"embedding_provider", sa.Column("api_url", sa.String(), nullable=True)
)
def downgrade() -> None:
op.drop_column("embedding_provider", "api_url")

View File

@@ -1,75 +0,0 @@
"""Add standard_answer tables
Revision ID: c18cdf4b497e
Revises: 3a7802814195
Create Date: 2024-06-06 15:15:02.000648
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "c18cdf4b497e"
down_revision = "3a7802814195"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.create_table(
"standard_answer",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("keyword", sa.String(), nullable=False),
sa.Column("answer", sa.String(), nullable=False),
sa.Column("active", sa.Boolean(), nullable=False),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("keyword"),
)
op.create_table(
"standard_answer_category",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("name", sa.String(), nullable=False),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("name"),
)
op.create_table(
"standard_answer__standard_answer_category",
sa.Column("standard_answer_id", sa.Integer(), nullable=False),
sa.Column("standard_answer_category_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["standard_answer_category_id"],
["standard_answer_category.id"],
),
sa.ForeignKeyConstraint(
["standard_answer_id"],
["standard_answer.id"],
),
sa.PrimaryKeyConstraint("standard_answer_id", "standard_answer_category_id"),
)
op.create_table(
"slack_bot_config__standard_answer_category",
sa.Column("slack_bot_config_id", sa.Integer(), nullable=False),
sa.Column("standard_answer_category_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["slack_bot_config_id"],
["slack_bot_config.id"],
),
sa.ForeignKeyConstraint(
["standard_answer_category_id"],
["standard_answer_category.id"],
),
sa.PrimaryKeyConstraint("slack_bot_config_id", "standard_answer_category_id"),
)
op.add_column(
"chat_session", sa.Column("slack_thread_id", sa.String(), nullable=True)
)
def downgrade() -> None:
op.drop_column("chat_session", "slack_thread_id")
op.drop_table("slack_bot_config__standard_answer_category")
op.drop_table("standard_answer__standard_answer_category")
op.drop_table("standard_answer_category")
op.drop_table("standard_answer")

View File

@@ -1,57 +0,0 @@
"""Add index_attempt_errors table
Revision ID: c5b692fa265c
Revises: 4a951134c801
Create Date: 2024-08-08 14:06:39.581972
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "c5b692fa265c"
down_revision = "4a951134c801"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.create_table(
"index_attempt_errors",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("index_attempt_id", sa.Integer(), nullable=True),
sa.Column("batch", sa.Integer(), nullable=True),
sa.Column(
"doc_summaries",
postgresql.JSONB(astext_type=sa.Text()),
nullable=False,
),
sa.Column("error_msg", sa.Text(), nullable=True),
sa.Column("traceback", sa.Text(), nullable=True),
sa.Column(
"time_created",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.ForeignKeyConstraint(
["index_attempt_id"],
["index_attempt.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"index_attempt_id",
"index_attempt_errors",
["time_created"],
unique=False,
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index("index_attempt_id", table_name="index_attempt_errors")
op.drop_table("index_attempt_errors")
# ### end Alembic commands ###

View File

@@ -19,9 +19,6 @@ depends_on: None = None
def upgrade() -> None:
op.drop_table("deletion_attempt")
# Remove the DeletionStatus enum
op.execute("DROP TYPE IF EXISTS deletionstatus;")
def downgrade() -> None:
op.create_table(

View File

@@ -1,45 +0,0 @@
"""combined slack id fields
Revision ID: d716b0791ddd
Revises: 7aea705850d5
Create Date: 2024-07-10 17:57:45.630550
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "d716b0791ddd"
down_revision = "7aea705850d5"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.execute(
"""
UPDATE slack_bot_config
SET channel_config = jsonb_set(
channel_config,
'{respond_member_group_list}',
coalesce(channel_config->'respond_team_member_list', '[]'::jsonb) ||
coalesce(channel_config->'respond_slack_group_list', '[]'::jsonb)
) - 'respond_team_member_list' - 'respond_slack_group_list'
"""
)
def downgrade() -> None:
op.execute(
"""
UPDATE slack_bot_config
SET channel_config = jsonb_set(
jsonb_set(
channel_config - 'respond_member_group_list',
'{respond_team_member_list}',
'[]'::jsonb
),
'{respond_slack_group_list}',
'[]'::jsonb
)
"""
)

View File

@@ -1,31 +0,0 @@
"""Remove _alt suffix from model_name
Revision ID: d9ec13955951
Revises: da4c21c69164
Create Date: 2024-08-20 16:31:32.955686
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "d9ec13955951"
down_revision = "da4c21c69164"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.execute(
"""
UPDATE embedding_model
SET model_name = regexp_replace(model_name, '__danswer_alt_index$', '')
WHERE model_name LIKE '%__danswer_alt_index'
"""
)
def downgrade() -> None:
# We can't reliably add the __danswer_alt_index suffix back, so we'll leave this empty
pass

View File

@@ -1,65 +0,0 @@
"""chosen_assistants changed to jsonb
Revision ID: da4c21c69164
Revises: c5b692fa265c
Create Date: 2024-08-18 19:06:47.291491
"""
import json
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "da4c21c69164"
down_revision = "c5b692fa265c"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
conn = op.get_bind()
existing_ids_and_chosen_assistants = conn.execute(
sa.text("select id, chosen_assistants from public.user")
)
op.drop_column(
"user",
"chosen_assistants",
)
op.add_column(
"user",
sa.Column(
"chosen_assistants",
postgresql.JSONB(astext_type=sa.Text()),
nullable=True,
),
)
for id, chosen_assistants in existing_ids_and_chosen_assistants:
conn.execute(
sa.text(
"update public.user set chosen_assistants = :chosen_assistants where id = :id"
),
{"chosen_assistants": json.dumps(chosen_assistants), "id": id},
)
def downgrade() -> None:
conn = op.get_bind()
existing_ids_and_chosen_assistants = conn.execute(
sa.text("select id, chosen_assistants from public.user")
)
op.drop_column(
"user",
"chosen_assistants",
)
op.add_column(
"user",
sa.Column("chosen_assistants", postgresql.ARRAY(sa.Integer()), nullable=True),
)
for id, chosen_assistants in existing_ids_and_chosen_assistants:
conn.execute(
sa.text(
"update public.user set chosen_assistants = :chosen_assistants where id = :id"
),
{"chosen_assistants": chosen_assistants, "id": id},
)

View File

@@ -9,7 +9,7 @@ from alembic import op
import sqlalchemy as sa
from sqlalchemy import table, column, String, Integer, Boolean
from danswer.db.search_settings import (
from danswer.db.embedding_model import (
get_new_default_embedding_model,
get_old_default_embedding_model,
user_has_overridden_embedding_model,
@@ -71,14 +71,14 @@ def upgrade() -> None:
"query_prefix": old_embedding_model.query_prefix,
"passage_prefix": old_embedding_model.passage_prefix,
"index_name": old_embedding_model.index_name,
"status": IndexModelStatus.PRESENT,
"status": old_embedding_model.status,
}
],
)
# if the user has not overridden the default embedding model via env variables,
# insert the new default model into the database to auto-upgrade them
if not user_has_overridden_embedding_model():
new_embedding_model = get_new_default_embedding_model()
new_embedding_model = get_new_default_embedding_model(is_present=False)
op.bulk_insert(
EmbeddingModel,
[
@@ -136,4 +136,4 @@ def downgrade() -> None:
)
op.drop_column("index_attempt", "embedding_model_id")
op.drop_table("embedding_model")
op.execute("DROP TYPE IF EXISTS indexmodelstatus;")
op.execute("DROP TYPE indexmodelstatus;")

View File

@@ -1,58 +0,0 @@
"""Added input prompts
Revision ID: e1392f05e840
Revises: 08a1eda20fe1
Create Date: 2024-07-13 19:09:22.556224
"""
import fastapi_users_db_sqlalchemy
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "e1392f05e840"
down_revision = "08a1eda20fe1"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.create_table(
"inputprompt",
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
sa.Column("prompt", sa.String(), nullable=False),
sa.Column("content", sa.String(), nullable=False),
sa.Column("active", sa.Boolean(), nullable=False),
sa.Column("is_public", sa.Boolean(), nullable=False),
sa.Column(
"user_id",
fastapi_users_db_sqlalchemy.generics.GUID(),
nullable=True,
),
sa.ForeignKeyConstraint(
["user_id"],
["user.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"inputprompt__user",
sa.Column("input_prompt_id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["input_prompt_id"],
["inputprompt.id"],
),
sa.ForeignKeyConstraint(
["user_id"],
["inputprompt.id"],
),
sa.PrimaryKeyConstraint("input_prompt_id", "user_id"),
)
def downgrade() -> None:
op.drop_table("inputprompt__user")
op.drop_table("inputprompt")

View File

@@ -1,22 +0,0 @@
"""added-prune-frequency
Revision ID: e209dc5a8156
Revises: 48d14957fe80
Create Date: 2024-06-16 16:02:35.273231
"""
from alembic import op
import sqlalchemy as sa
revision = "e209dc5a8156"
down_revision = "48d14957fe80"
branch_labels = None # type: ignore
depends_on = None # type: ignore
def upgrade() -> None:
op.add_column("connector", sa.Column("prune_freq", sa.Integer(), nullable=True))
def downgrade() -> None:
op.drop_column("connector", "prune_freq")

View File

@@ -1,31 +0,0 @@
"""Remove Last Attempt Status from CC Pair
Revision ID: ec85f2b3c544
Revises: 3879338f8ba1
Create Date: 2024-05-23 21:39:46.126010
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "ec85f2b3c544"
down_revision = "70f00c45c0f2"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.drop_column("connector_credential_pair", "last_attempt_status")
def downgrade() -> None:
op.add_column(
"connector_credential_pair",
sa.Column(
"last_attempt_status",
sa.VARCHAR(),
autoincrement=False,
nullable=True,
),
)

View File

@@ -1,28 +0,0 @@
"""Added alternate model to chat message
Revision ID: ee3f4b47fad5
Revises: 2d2304e27d8c
Create Date: 2024-08-12 00:11:50.915845
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "ee3f4b47fad5"
down_revision = "2d2304e27d8c"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.add_column(
"chat_message",
sa.Column("overridden_model", sa.String(length=255), nullable=True),
)
def downgrade() -> None:
op.drop_column("chat_message", "overridden_model")

View File

@@ -1,172 +0,0 @@
"""embedding provider by provider type
Revision ID: f17bf3b0d9f1
Revises: 351faebd379d
Create Date: 2024-08-21 13:13:31.120460
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "f17bf3b0d9f1"
down_revision = "351faebd379d"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
# Add provider_type column to embedding_provider
op.add_column(
"embedding_provider",
sa.Column("provider_type", sa.String(50), nullable=True),
)
# Update provider_type with existing name values
op.execute("UPDATE embedding_provider SET provider_type = UPPER(name)")
# Make provider_type not nullable
op.alter_column("embedding_provider", "provider_type", nullable=False)
# Drop the foreign key constraint in embedding_model table
op.drop_constraint(
"fk_embedding_model_cloud_provider", "embedding_model", type_="foreignkey"
)
# Drop the existing primary key constraint
op.drop_constraint("embedding_provider_pkey", "embedding_provider", type_="primary")
# Create a new primary key constraint on provider_type
op.create_primary_key(
"embedding_provider_pkey", "embedding_provider", ["provider_type"]
)
# Add provider_type column to embedding_model
op.add_column(
"embedding_model",
sa.Column("provider_type", sa.String(50), nullable=True),
)
# Update provider_type for existing embedding models
op.execute(
"""
UPDATE embedding_model
SET provider_type = (
SELECT provider_type
FROM embedding_provider
WHERE embedding_provider.id = embedding_model.cloud_provider_id
)
"""
)
# Drop the old id column from embedding_provider
op.drop_column("embedding_provider", "id")
# Drop the name column from embedding_provider
op.drop_column("embedding_provider", "name")
# Drop the default_model_id column from embedding_provider
op.drop_column("embedding_provider", "default_model_id")
# Drop the old cloud_provider_id column from embedding_model
op.drop_column("embedding_model", "cloud_provider_id")
# Create the new foreign key constraint
op.create_foreign_key(
"fk_embedding_model_cloud_provider",
"embedding_model",
"embedding_provider",
["provider_type"],
["provider_type"],
)
def downgrade() -> None:
# Drop the foreign key constraint in embedding_model table
op.drop_constraint(
"fk_embedding_model_cloud_provider", "embedding_model", type_="foreignkey"
)
# Add back the cloud_provider_id column to embedding_model
op.add_column(
"embedding_model", sa.Column("cloud_provider_id", sa.Integer(), nullable=True)
)
op.add_column("embedding_provider", sa.Column("id", sa.Integer(), nullable=True))
# Assign incrementing IDs to embedding providers
op.execute(
"""
CREATE SEQUENCE IF NOT EXISTS embedding_provider_id_seq;"""
)
op.execute(
"""
UPDATE embedding_provider SET id = nextval('embedding_provider_id_seq');
"""
)
# Update cloud_provider_id based on provider_type
op.execute(
"""
UPDATE embedding_model
SET cloud_provider_id = CASE
WHEN provider_type IS NULL THEN NULL
ELSE (
SELECT id
FROM embedding_provider
WHERE embedding_provider.provider_type = embedding_model.provider_type
)
END
"""
)
# Drop the provider_type column from embedding_model
op.drop_column("embedding_model", "provider_type")
# Add back the columns to embedding_provider
op.add_column("embedding_provider", sa.Column("name", sa.String(50), nullable=True))
op.add_column(
"embedding_provider", sa.Column("default_model_id", sa.Integer(), nullable=True)
)
# Drop the existing primary key constraint on provider_type
op.drop_constraint("embedding_provider_pkey", "embedding_provider", type_="primary")
# Create the original primary key constraint on id
op.create_primary_key("embedding_provider_pkey", "embedding_provider", ["id"])
# Update name with existing provider_type values
op.execute(
"""
UPDATE embedding_provider
SET name = CASE
WHEN provider_type = 'OPENAI' THEN 'OpenAI'
WHEN provider_type = 'COHERE' THEN 'Cohere'
WHEN provider_type = 'GOOGLE' THEN 'Google'
WHEN provider_type = 'VOYAGE' THEN 'Voyage'
ELSE provider_type
END
"""
)
# Drop the provider_type column from embedding_provider
op.drop_column("embedding_provider", "provider_type")
# Recreate the foreign key constraint in embedding_model table
op.create_foreign_key(
"fk_embedding_model_cloud_provider",
"embedding_model",
"embedding_provider",
["cloud_provider_id"],
["id"],
)
# Recreate the foreign key constraint in embedding_model table
op.create_foreign_key(
"fk_embedding_provider_default_model",
"embedding_provider",
"embedding_model",
["default_model_id"],
["id"],
)

View File

@@ -1,26 +0,0 @@
"""add has_web_login column to user
Revision ID: f7e58d357687
Revises: bceb1e139447
Create Date: 2024-09-07 20:20:54.522620
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "f7e58d357687"
down_revision = "ba98eba0f66a"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.add_column(
"user",
sa.Column("has_web_login", sa.Boolean(), nullable=False, server_default="true"),
)
def downgrade() -> None:
op.drop_column("user", "has_web_login")

View File

@@ -1,2 +0,0 @@
*
!.gitignore

View File

@@ -1,51 +1,25 @@
from sqlalchemy.orm import Session
from danswer.access.models import DocumentAccess
from danswer.access.utils import prefix_user
from danswer.configs.constants import PUBLIC_DOC_PAT
from danswer.db.document import get_access_info_for_document
from danswer.db.document import get_access_info_for_documents
from danswer.db.document import get_acccess_info_for_documents
from danswer.db.models import User
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
from danswer.utils.variable_functionality import fetch_versioned_implementation
def _get_access_for_document(
document_id: str,
db_session: Session,
) -> DocumentAccess:
info = get_access_info_for_document(
db_session=db_session,
document_id=document_id,
)
if not info:
return DocumentAccess.build(user_ids=[], user_groups=[], is_public=False)
return DocumentAccess.build(user_ids=info[1], user_groups=[], is_public=info[2])
def get_access_for_document(
document_id: str,
db_session: Session,
) -> DocumentAccess:
versioned_get_access_for_document_fn = fetch_versioned_implementation(
"danswer.access.access", "_get_access_for_document"
)
return versioned_get_access_for_document_fn(document_id, db_session) # type: ignore
def _get_access_for_documents(
document_ids: list[str],
db_session: Session,
cc_pair_to_delete: ConnectorCredentialPairIdentifier | None = None,
) -> dict[str, DocumentAccess]:
document_access_info = get_access_info_for_documents(
document_access_info = get_acccess_info_for_documents(
db_session=db_session,
document_ids=document_ids,
cc_pair_to_delete=cc_pair_to_delete,
)
return {
document_id: DocumentAccess.build(
user_ids=user_ids, user_groups=[], is_public=is_public
)
document_id: DocumentAccess.build(user_ids, is_public)
for document_id, user_ids, is_public in document_access_info
}
@@ -53,16 +27,23 @@ def _get_access_for_documents(
def get_access_for_documents(
document_ids: list[str],
db_session: Session,
cc_pair_to_delete: ConnectorCredentialPairIdentifier | None = None,
) -> dict[str, DocumentAccess]:
"""Fetches all access information for the given documents."""
versioned_get_access_for_documents_fn = fetch_versioned_implementation(
"danswer.access.access", "_get_access_for_documents"
)
return versioned_get_access_for_documents_fn(
document_ids, db_session
document_ids, db_session, cc_pair_to_delete
) # type: ignore
def prefix_user(user_id: str) -> str:
"""Prefixes a user ID to eliminate collision with group names.
This assumes that groups are prefixed with a different prefix."""
return f"user_id:{user_id}"
def _get_acl_for_user(user: User | None, db_session: Session) -> set[str]:
"""Returns a list of ACL entries that the user has access to. This is meant to be
used downstream to filter out documents that the user does not have access to. The

View File

@@ -1,30 +1,20 @@
from dataclasses import dataclass
from uuid import UUID
from danswer.access.utils import prefix_user
from danswer.access.utils import prefix_user_group
from danswer.configs.constants import PUBLIC_DOC_PAT
@dataclass(frozen=True)
class DocumentAccess:
user_ids: set[str] # stringified UUIDs
user_groups: set[str] # names of user groups associated with this document
is_public: bool
def to_acl(self) -> list[str]:
return (
[prefix_user(user_id) for user_id in self.user_ids]
+ [prefix_user_group(group_name) for group_name in self.user_groups]
+ ([PUBLIC_DOC_PAT] if self.is_public else [])
)
return list(self.user_ids) + ([PUBLIC_DOC_PAT] if self.is_public else [])
@classmethod
def build(
cls, user_ids: list[UUID | None], user_groups: list[str], is_public: bool
) -> "DocumentAccess":
def build(cls, user_ids: list[UUID | None], is_public: bool) -> "DocumentAccess":
return cls(
user_ids={str(user_id) for user_id in user_ids if user_id},
user_groups=set(user_groups),
is_public=is_public,
)

View File

@@ -1,10 +0,0 @@
def prefix_user(user_id: str) -> str:
"""Prefixes a user ID to eliminate collision with group names.
This assumes that groups are prefixed with a different prefix."""
return f"user_id:{user_id}"
def prefix_user_group(user_group_name: str) -> str:
"""Prefixes a user group name to eliminate collision with user IDs.
This assumes that user ids are prefixed with a different prefix."""
return f"group:{user_group_name}"

View File

@@ -1,20 +0,0 @@
from typing import cast
from danswer.configs.constants import KV_USER_STORE_KEY
from danswer.dynamic_configs.factory import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.dynamic_configs.interface import JSON_ro
def get_invited_users() -> list[str]:
try:
store = get_dynamic_config_store()
return cast(list, store.load(KV_USER_STORE_KEY))
except ConfigNotFoundError:
return list()
def write_invited_users(emails: list[str]) -> int:
store = get_dynamic_config_store()
store.store(KV_USER_STORE_KEY, cast(JSON_ro, emails))
return len(emails)

View File

@@ -1,38 +0,0 @@
from collections.abc import Mapping
from typing import Any
from typing import cast
from danswer.auth.schemas import UserRole
from danswer.configs.constants import KV_NO_AUTH_USER_PREFERENCES_KEY
from danswer.dynamic_configs.store import ConfigNotFoundError
from danswer.dynamic_configs.store import DynamicConfigStore
from danswer.server.manage.models import UserInfo
from danswer.server.manage.models import UserPreferences
def set_no_auth_user_preferences(
store: DynamicConfigStore, preferences: UserPreferences
) -> None:
store.store(KV_NO_AUTH_USER_PREFERENCES_KEY, preferences.model_dump())
def load_no_auth_user_preferences(store: DynamicConfigStore) -> UserPreferences:
try:
preferences_data = cast(
Mapping[str, Any], store.load(KV_NO_AUTH_USER_PREFERENCES_KEY)
)
return UserPreferences(**preferences_data)
except ConfigNotFoundError:
return UserPreferences(chosen_assistants=None, default_model=None)
def fetch_no_auth_user(store: DynamicConfigStore) -> UserInfo:
return UserInfo(
id="__no_auth_user__",
email="anonymous@danswer.ai",
is_active=True,
is_superuser=False,
is_verified=True,
role=UserRole.ADMIN,
preferences=load_no_auth_user_preferences(store),
)

View File

@@ -5,26 +5,8 @@ from fastapi_users import schemas
class UserRole(str, Enum):
"""
User roles
- Basic can't perform any admin actions
- Admin can perform all admin actions
- Curator can perform admin actions for
groups they are curators of
- Global Curator can perform admin actions
for all groups they are a member of
"""
BASIC = "basic"
ADMIN = "admin"
CURATOR = "curator"
GLOBAL_CURATOR = "global_curator"
class UserStatus(str, Enum):
LIVE = "live"
INVITED = "invited"
DEACTIVATED = "deactivated"
class UserRead(schemas.BaseUser[uuid.UUID]):
@@ -33,9 +15,7 @@ class UserRead(schemas.BaseUser[uuid.UUID]):
class UserCreate(schemas.BaseUserCreate):
role: UserRole = UserRole.BASIC
has_web_login: bool | None = True
class UserUpdate(schemas.BaseUserUpdate):
role: UserRole
has_web_login: bool | None = True

View File

@@ -1,24 +1,19 @@
import os
import smtplib
import uuid
from collections.abc import AsyncGenerator
from datetime import datetime
from datetime import timezone
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from typing import Optional
from typing import Tuple
from email_validator import EmailNotValidError
from email_validator import validate_email
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Request
from fastapi import Response
from fastapi import status
from fastapi.security import OAuth2PasswordRequestForm
from fastapi_users import BaseUserManager
from fastapi_users import exceptions
from fastapi_users import FastAPIUsers
from fastapi_users import models
from fastapi_users import schemas
@@ -32,10 +27,8 @@ from fastapi_users.openapi import OpenAPIResponseType
from fastapi_users_db_sqlalchemy import SQLAlchemyUserDatabase
from sqlalchemy.orm import Session
from danswer.auth.invited_users import get_invited_users
from danswer.auth.schemas import UserCreate
from danswer.auth.schemas import UserRole
from danswer.auth.schemas import UserUpdate
from danswer.configs.app_configs import AUTH_TYPE
from danswer.configs.app_configs import DISABLE_AUTH
from danswer.configs.app_configs import EMAIL_FROM
@@ -45,7 +38,6 @@ from danswer.configs.app_configs import SMTP_PASS
from danswer.configs.app_configs import SMTP_PORT
from danswer.configs.app_configs import SMTP_SERVER
from danswer.configs.app_configs import SMTP_USER
from danswer.configs.app_configs import TRACK_EXTERNAL_IDP_EXPIRY
from danswer.configs.app_configs import USER_AUTH_SECRET
from danswer.configs.app_configs import VALID_EMAIL_DOMAINS
from danswer.configs.app_configs import WEB_DOMAIN
@@ -54,28 +46,21 @@ from danswer.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
from danswer.configs.constants import DANSWER_API_KEY_PREFIX
from danswer.configs.constants import UNNAMED_KEY_PLACEHOLDER
from danswer.db.auth import get_access_token_db
from danswer.db.auth import get_default_admin_user_emails
from danswer.db.auth import get_user_count
from danswer.db.auth import get_user_db
from danswer.db.engine import get_session
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import AccessToken
from danswer.db.models import User
from danswer.db.users import get_user_by_email
from danswer.utils.logger import setup_logger
from danswer.utils.telemetry import optional_telemetry
from danswer.utils.telemetry import RecordType
from danswer.utils.variable_functionality import fetch_versioned_implementation
logger = setup_logger()
def is_user_admin(user: User | None) -> bool:
if AUTH_TYPE == AuthType.DISABLED:
return True
if user and user.role == UserRole.ADMIN:
return True
return False
USER_WHITELIST_FILE = "/home/danswer_whitelist.txt"
_user_whitelist: list[str] | None = None
def verify_auth_setting() -> None:
@@ -84,7 +69,7 @@ def verify_auth_setting() -> None:
"User must choose a valid user authentication method: "
"disabled, basic, or google_oauth"
)
logger.notice(f"Using Auth Type: {AUTH_TYPE.value}")
logger.info(f"Using Auth Type: {AUTH_TYPE.value}")
def get_display_email(email: str | None, space_less: bool = False) -> str:
@@ -107,36 +92,22 @@ def user_needs_to_be_verified() -> bool:
return AUTH_TYPE != AuthType.BASIC or REQUIRE_EMAIL_VERIFICATION
def verify_email_is_invited(email: str) -> None:
whitelist = get_invited_users()
if not whitelist:
return
def get_user_whitelist() -> list[str]:
global _user_whitelist
if _user_whitelist is None:
if os.path.exists(USER_WHITELIST_FILE):
with open(USER_WHITELIST_FILE, "r") as file:
_user_whitelist = [line.strip() for line in file]
else:
_user_whitelist = []
if not email:
raise PermissionError("Email must be specified")
email_info = validate_email(email) # can raise EmailNotValidError
for email_whitelist in whitelist:
try:
# normalized emails are now being inserted into the db
# we can remove this normalization on read after some time has passed
email_info_whitelist = validate_email(email_whitelist)
except EmailNotValidError:
continue
# oddly, normalization does not include lowercasing the user part of the
# email address ... which we want to allow
if email_info.normalized.lower() == email_info_whitelist.normalized.lower():
return
raise PermissionError("User not on allowed user whitelist")
return _user_whitelist
def verify_email_in_whitelist(email: str) -> None:
with Session(get_sqlalchemy_engine()) as db_session:
if not get_user_by_email(email, db_session):
verify_email_is_invited(email)
whitelist = get_user_whitelist()
if (whitelist and email not in whitelist) or not email:
raise PermissionError("User not on allowed user whitelist")
def verify_email_domain(email: str) -> None:
@@ -187,36 +158,16 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
user_create: schemas.UC | UserCreate,
safe: bool = False,
request: Optional[Request] = None,
) -> User:
verify_email_is_invited(user_create.email)
) -> models.UP:
verify_email_in_whitelist(user_create.email)
verify_email_domain(user_create.email)
if hasattr(user_create, "role"):
user_count = await get_user_count()
if user_count == 0 or user_create.email in get_default_admin_user_emails():
if user_count == 0:
user_create.role = UserRole.ADMIN
else:
user_create.role = UserRole.BASIC
user = None
try:
user = await super().create(user_create, safe=safe, request=request) # type: ignore
except exceptions.UserAlreadyExists:
user = await self.get_by_email(user_create.email)
# Handle case where user has used product outside of web and is now creating an account through web
if (
not user.has_web_login
and hasattr(user_create, "has_web_login")
and user_create.has_web_login
):
user_update = UserUpdate(
password=user_create.password,
has_web_login=True,
role=user_create.role,
is_verified=user_create.is_verified,
)
user = await self.update(user_update, user)
else:
raise exceptions.UserAlreadyExists()
return user
return await super().create(user_create, safe=safe, request=request) # type: ignore
async def oauth_callback(
self: "BaseUserManager[models.UOAP, models.ID]",
@@ -234,7 +185,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
verify_email_in_whitelist(account_email)
verify_email_domain(account_email)
user = await super().oauth_callback( # type: ignore
return await super().oauth_callback( # type: ignore
oauth_name=oauth_name,
access_token=access_token,
account_id=account_id,
@@ -246,34 +197,10 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
is_verified_by_default=is_verified_by_default,
)
# NOTE: Most IdPs have very short expiry times, and we don't want to force the user to
# re-authenticate that frequently, so by default this is disabled
if expires_at and TRACK_EXTERNAL_IDP_EXPIRY:
oidc_expiry = datetime.fromtimestamp(expires_at, tz=timezone.utc)
await self.user_db.update(user, update_dict={"oidc_expiry": oidc_expiry})
# this is needed if an organization goes from `TRACK_EXTERNAL_IDP_EXPIRY=true` to `false`
# otherwise, the oidc expiry will always be old, and the user will never be able to login
if user.oidc_expiry and not TRACK_EXTERNAL_IDP_EXPIRY:
await self.user_db.update(user, update_dict={"oidc_expiry": None})
# Handle case where user has used product outside of web and is now creating an account through web
if not user.has_web_login:
await self.user_db.update(
user,
update_dict={
"is_verified": is_verified_by_default,
"has_web_login": True,
},
)
user.is_verified = is_verified_by_default
user.has_web_login = True
return user
async def on_after_register(
self, user: User, request: Optional[Request] = None
) -> None:
logger.notice(f"User {user.id} has registered.")
logger.info(f"User {user.id} has registered.")
optional_telemetry(
record_type=RecordType.SIGN_UP,
data={"action": "create"},
@@ -283,35 +210,19 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
async def on_after_forgot_password(
self, user: User, token: str, request: Optional[Request] = None
) -> None:
logger.notice(f"User {user.id} has forgot their password. Reset token: {token}")
logger.info(f"User {user.id} has forgot their password. Reset token: {token}")
async def on_after_request_verify(
self, user: User, token: str, request: Optional[Request] = None
) -> None:
verify_email_domain(user.email)
logger.notice(
logger.info(
f"Verification requested for user {user.id}. Verification token: {token}"
)
send_user_verification_email(user.email, token)
async def authenticate(
self, credentials: OAuth2PasswordRequestForm
) -> Optional[User]:
user = await super().authenticate(credentials)
if user is None:
try:
user = await self.get_by_email(credentials.username)
if not user.has_web_login:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="NO_WEB_LOGIN_AND_HAS_NO_PASSWORD",
)
except exceptions.UserNotExists:
pass
return user
async def get_user_manager(
user_db: SQLAlchemyUserDatabase = Depends(get_user_db),
@@ -328,12 +239,10 @@ cookie_transport = CookieTransport(
def get_database_strategy(
access_token_db: AccessTokenDatabase[AccessToken] = Depends(get_access_token_db),
) -> DatabaseStrategy:
strategy = DatabaseStrategy(
return DatabaseStrategy(
access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS # type: ignore
)
return strategy
auth_backend = AuthenticationBackend(
name="database",
@@ -430,12 +339,6 @@ async def double_check_user(
detail="Access denied. User is not verified.",
)
if user.oidc_expiry and user.oidc_expiry < datetime.now(timezone.utc):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied. User's OIDC token has expired.",
)
return user
@@ -445,28 +348,6 @@ async def current_user(
return await double_check_user(user)
async def current_curator_or_admin_user(
user: User | None = Depends(current_user),
) -> User | None:
if DISABLE_AUTH:
return None
if not user or not hasattr(user, "role"):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied. User is not authenticated or lacks role information.",
)
allowed_roles = {UserRole.GLOBAL_CURATOR, UserRole.CURATOR, UserRole.ADMIN}
if user.role not in allowed_roles:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied. User is not a curator or admin.",
)
return user
async def current_admin_user(user: User | None = Depends(current_user)) -> User | None:
if DISABLE_AUTH:
return None
@@ -474,12 +355,6 @@ async def current_admin_user(user: User | None = Depends(current_user)) -> User
if not user or not hasattr(user, "role") or user.role != UserRole.ADMIN:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied. User must be an admin to perform this action.",
detail="Access denied. User is not an admin.",
)
return user
def get_default_admin_user_emails_() -> list[str]:
# No default seeding available for Danswer MIT
return []

View File

@@ -0,0 +1,212 @@
from datetime import timedelta
from typing import cast
from celery import Celery # type: ignore
from sqlalchemy.orm import Session
from danswer.background.connector_deletion import delete_connector_credential_pair
from danswer.background.task_utils import build_celery_task_wrapper
from danswer.background.task_utils import name_cc_cleanup_task
from danswer.background.task_utils import name_document_set_sync_task
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.db.connector_credential_pair import get_connector_credential_pair
from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
from danswer.db.document import prepare_to_modify_documents
from danswer.db.document_set import delete_document_set
from danswer.db.document_set import fetch_document_sets
from danswer.db.document_set import fetch_document_sets_for_documents
from danswer.db.document_set import fetch_documents_for_document_set_paginated
from danswer.db.document_set import get_document_set_by_id
from danswer.db.document_set import mark_document_set_as_synced
from danswer.db.engine import build_connection_string
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.engine import SYNC_DB_API
from danswer.db.models import DocumentSet
from danswer.db.tasks import check_live_task_not_timed_out
from danswer.db.tasks import get_latest_task
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import UpdateRequest
from danswer.utils.logger import setup_logger
logger = setup_logger()
connection_string = build_connection_string(db_api=SYNC_DB_API)
celery_broker_url = f"sqla+{connection_string}"
celery_backend_url = f"db+{connection_string}"
celery_app = Celery(__name__, broker=celery_broker_url, backend=celery_backend_url)
_SYNC_BATCH_SIZE = 100
#####
# Tasks that need to be run in job queue, registered via APIs
#
# If imports from this module are needed, use local imports to avoid circular importing
#####
@build_celery_task_wrapper(name_cc_cleanup_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def cleanup_connector_credential_pair_task(
connector_id: int,
credential_id: int,
) -> int:
"""Connector deletion task. This is run as an async task because it is a somewhat slow job.
Needs to potentially update a large number of Postgres and Vespa docs, including deleting them
or updating the ACL"""
engine = get_sqlalchemy_engine()
with Session(engine) as db_session:
# validate that the connector / credential pair is deletable
cc_pair = get_connector_credential_pair(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
)
if not cc_pair:
raise ValueError(
f"Cannot run deletion attempt - connector_credential_pair with Connector ID: "
f"{connector_id} and Credential ID: {credential_id} does not exist."
)
deletion_attempt_disallowed_reason = check_deletion_attempt_is_allowed(cc_pair)
if deletion_attempt_disallowed_reason:
raise ValueError(deletion_attempt_disallowed_reason)
try:
# The bulk of the work is in here, updates Postgres and Vespa
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
document_index = get_default_document_index(
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
)
return delete_connector_credential_pair(
db_session=db_session,
document_index=document_index,
cc_pair=cc_pair,
)
except Exception as e:
logger.exception(f"Failed to run connector_deletion due to {e}")
raise e
@build_celery_task_wrapper(name_document_set_sync_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def sync_document_set_task(document_set_id: int) -> None:
"""For document sets marked as not up to date, sync the state from postgres
into the datastore. Also handles deletions."""
def _sync_document_batch(document_ids: list[str], db_session: Session) -> None:
logger.debug(f"Syncing document sets for: {document_ids}")
# Acquires a lock on the documents so that no other process can modify them
with prepare_to_modify_documents(
db_session=db_session, document_ids=document_ids
):
# get current state of document sets for these documents
document_set_map = {
document_id: document_sets
for document_id, document_sets in fetch_document_sets_for_documents(
document_ids=document_ids, db_session=db_session
)
}
# update Vespa
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
document_index = get_default_document_index(
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
)
update_requests = [
UpdateRequest(
document_ids=[document_id],
document_sets=set(document_set_map.get(document_id, [])),
)
for document_id in document_ids
]
document_index.update(update_requests=update_requests)
with Session(get_sqlalchemy_engine()) as db_session:
try:
cursor = None
while True:
document_batch, cursor = fetch_documents_for_document_set_paginated(
document_set_id=document_set_id,
db_session=db_session,
current_only=False,
last_document_id=cursor,
limit=_SYNC_BATCH_SIZE,
)
_sync_document_batch(
document_ids=[document.id for document in document_batch],
db_session=db_session,
)
if cursor is None:
break
# if there are no connectors, then delete the document set. Otherwise, just
# mark it as successfully synced.
document_set = cast(
DocumentSet,
get_document_set_by_id(
db_session=db_session, document_set_id=document_set_id
),
) # casting since we "know" a document set with this ID exists
if not document_set.connector_credential_pairs:
delete_document_set(
document_set_row=document_set, db_session=db_session
)
logger.info(
f"Successfully deleted document set with ID: '{document_set_id}'!"
)
else:
mark_document_set_as_synced(
document_set_id=document_set_id, db_session=db_session
)
logger.info(f"Document set sync for '{document_set_id}' complete!")
except Exception:
logger.exception("Failed to sync document set %s", document_set_id)
raise
#####
# Periodic Tasks
#####
@celery_app.task(
name="check_for_document_sets_sync_task",
soft_time_limit=JOB_TIMEOUT,
)
def check_for_document_sets_sync_task() -> None:
"""Runs periodically to check if any document sets are out of sync
Creates a task to sync the set if needed"""
with Session(get_sqlalchemy_engine()) as db_session:
# check if any document sets are not synced
document_set_info = fetch_document_sets(
user_id=None, db_session=db_session, include_outdated=True
)
for document_set, _ in document_set_info:
if not document_set.is_up_to_date:
task_name = name_document_set_sync_task(document_set.id)
latest_sync = get_latest_task(task_name, db_session)
if latest_sync and check_live_task_not_timed_out(
latest_sync, db_session
):
logger.info(
f"Document set '{document_set.id}' is already syncing. Skipping."
)
continue
logger.info(f"Document set {document_set.id} syncing now!")
sync_document_set_task.apply_async(
kwargs=dict(document_set_id=document_set.id),
)
#####
# Celery Beat (Periodic Tasks) Settings
#####
celery_app.conf.beat_schedule = {
"check-for-document-set-sync": {
"task": "check_for_document_sets_sync_task",
"schedule": timedelta(seconds=5),
},
}

View File

@@ -1,921 +0,0 @@
import json
from datetime import timedelta
from typing import Any
from typing import cast
import redis
from celery import Celery
from celery import signals
from celery import Task
from celery.contrib.abortable import AbortableTask # type: ignore
from celery.exceptions import SoftTimeLimitExceeded
from celery.exceptions import TaskRevokedError
from celery.signals import beat_init
from celery.signals import worker_init
from celery.states import READY_STATES
from celery.utils.log import get_task_logger
from redis import Redis
from sqlalchemy import inspect
from sqlalchemy import text
from sqlalchemy.orm import Session
from danswer.access.access import get_access_for_document
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
from danswer.background.celery.celery_redis import RedisDocumentSet
from danswer.background.celery.celery_redis import RedisUserGroup
from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector
from danswer.background.celery.celery_utils import should_kick_off_deletion_of_cc_pair
from danswer.background.celery.celery_utils import should_prune_cc_pair
from danswer.background.connector_deletion import delete_connector_credential_pair
from danswer.background.connector_deletion import delete_connector_credential_pair_batch
from danswer.background.task_utils import build_celery_task_wrapper
from danswer.background.task_utils import name_cc_cleanup_task
from danswer.background.task_utils import name_cc_prune_task
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerRedisLocks
from danswer.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME
from danswer.configs.constants import POSTGRES_CELERY_WORKER_APP_NAME
from danswer.configs.constants import PostgresAdvisoryLocks
from danswer.connectors.factory import instantiate_connector
from danswer.connectors.models import InputType
from danswer.db.connector_credential_pair import (
get_connector_credential_pair,
)
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
from danswer.db.document import count_documents_by_needs_sync
from danswer.db.document import get_document
from danswer.db.document import get_documents_for_connector_credential_pair
from danswer.db.document import mark_document_as_synced
from danswer.db.document_set import delete_document_set
from danswer.db.document_set import fetch_document_set_for_document
from danswer.db.document_set import fetch_document_sets
from danswer.db.document_set import get_document_set_by_id
from danswer.db.document_set import mark_document_set_as_synced
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.engine import init_sqlalchemy_engine
from danswer.db.models import DocumentSet
from danswer.db.models import UserGroup
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import UpdateRequest
from danswer.redis.redis_pool import RedisPool
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,
)
from danswer.utils.variable_functionality import noop_fallback
logger = setup_logger()
# use this within celery tasks to get celery task specific logging
task_logger = get_task_logger(__name__)
redis_pool = RedisPool()
celery_app = Celery(__name__)
celery_app.config_from_object(
"danswer.background.celery.celeryconfig"
) # Load configuration from 'celeryconfig.py'
#####
# Tasks that need to be run in job queue, registered via APIs
#
# If imports from this module are needed, use local imports to avoid circular importing
#####
@build_celery_task_wrapper(name_cc_cleanup_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def cleanup_connector_credential_pair_task(
connector_id: int,
credential_id: int,
) -> int:
"""Connector deletion task. This is run as an async task because it is a somewhat slow job.
Needs to potentially update a large number of Postgres and Vespa docs, including deleting them
or updating the ACL"""
engine = get_sqlalchemy_engine()
with Session(engine) as db_session:
# validate that the connector / credential pair is deletable
cc_pair = get_connector_credential_pair(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
)
if not cc_pair:
raise ValueError(
f"Cannot run deletion attempt - connector_credential_pair with Connector ID: "
f"{connector_id} and Credential ID: {credential_id} does not exist."
)
deletion_attempt_disallowed_reason = check_deletion_attempt_is_allowed(
connector_credential_pair=cc_pair, db_session=db_session
)
if deletion_attempt_disallowed_reason:
raise ValueError(deletion_attempt_disallowed_reason)
try:
# The bulk of the work is in here, updates Postgres and Vespa
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
document_index = get_default_document_index(
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
)
return delete_connector_credential_pair(
db_session=db_session,
document_index=document_index,
cc_pair=cc_pair,
)
except Exception as e:
task_logger.exception(
f"Failed to run connector_deletion. "
f"connector_id={connector_id} credential_id={credential_id}"
)
raise e
@build_celery_task_wrapper(name_cc_prune_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def prune_documents_task(connector_id: int, credential_id: int) -> None:
"""connector pruning task. For a cc pair, this task pulls all document IDs from the source
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
from the most recently pulled document ID list"""
with Session(get_sqlalchemy_engine()) as db_session:
try:
cc_pair = get_connector_credential_pair(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
)
if not cc_pair:
task_logger.warning(
f"ccpair not found for {connector_id} {credential_id}"
)
return
runnable_connector = instantiate_connector(
cc_pair.connector.source,
InputType.PRUNE,
cc_pair.connector.connector_specific_config,
cc_pair.credential,
db_session,
)
all_connector_doc_ids: set[str] = extract_ids_from_runnable_connector(
runnable_connector
)
all_indexed_document_ids = {
doc.id
for doc in get_documents_for_connector_credential_pair(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
)
}
doc_ids_to_remove = list(all_indexed_document_ids - all_connector_doc_ids)
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
document_index = get_default_document_index(
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
)
if len(doc_ids_to_remove) == 0:
task_logger.info(
f"No docs to prune from {cc_pair.connector.source} connector"
)
return
task_logger.info(
f"pruning {len(doc_ids_to_remove)} doc(s) from {cc_pair.connector.source} connector"
)
delete_connector_credential_pair_batch(
document_ids=doc_ids_to_remove,
connector_id=connector_id,
credential_id=credential_id,
document_index=document_index,
)
except Exception as e:
task_logger.exception(
f"Failed to run pruning for connector id {connector_id}."
)
raise e
def try_generate_stale_document_sync_tasks(
db_session: Session, r: Redis, lock_beat: redis.lock.Lock
) -> int | None:
# the fence is up, do nothing
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
return None
r.delete(RedisConnectorCredentialPair.get_taskset_key()) # delete the taskset
# add tasks to celery and build up the task set to monitor in redis
stale_doc_count = count_documents_by_needs_sync(db_session)
if stale_doc_count == 0:
return None
task_logger.info(
f"Stale documents found (at least {stale_doc_count}). Generating sync tasks by cc pair."
)
# rkuo: we could technically sync all stale docs in one big pass.
# but I feel it's more understandable to group the docs by cc_pair
total_tasks_generated = 0
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in cc_pairs:
rc = RedisConnectorCredentialPair(cc_pair.id)
tasks_generated = rc.generate_tasks(celery_app, db_session, r, lock_beat)
if tasks_generated is None:
continue
if tasks_generated == 0:
continue
task_logger.info(
f"RedisConnector.generate_tasks finished. "
f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}"
)
total_tasks_generated += tasks_generated
task_logger.info(
f"All per connector generate_tasks finished. total_tasks_generated={total_tasks_generated}"
)
r.set(RedisConnectorCredentialPair.get_fence_key(), total_tasks_generated)
return total_tasks_generated
def try_generate_document_set_sync_tasks(
document_set: DocumentSet, db_session: Session, r: Redis, lock_beat: redis.lock.Lock
) -> int | None:
lock_beat.reacquire()
rds = RedisDocumentSet(document_set.id)
# don't generate document set sync tasks if tasks are still pending
if r.exists(rds.fence_key):
return None
# don't generate sync tasks if we're up to date
if document_set.is_up_to_date:
return None
# add tasks to celery and build up the task set to monitor in redis
r.delete(rds.taskset_key)
task_logger.info(
f"RedisDocumentSet.generate_tasks starting. document_set_id={document_set.id}"
)
# Add all documents that need to be updated into the queue
tasks_generated = rds.generate_tasks(celery_app, db_session, r, lock_beat)
if tasks_generated is None:
return None
# Currently we are allowing the sync to proceed with 0 tasks.
# It's possible for sets/groups to be generated initially with no entries
# and they still need to be marked as up to date.
# if tasks_generated == 0:
# return 0
task_logger.info(
f"RedisDocumentSet.generate_tasks finished. "
f"document_set_id={document_set.id} tasks_generated={tasks_generated}"
)
# set this only after all tasks have been added
r.set(rds.fence_key, tasks_generated)
return tasks_generated
def try_generate_user_group_sync_tasks(
usergroup: UserGroup, db_session: Session, r: Redis, lock_beat: redis.lock.Lock
) -> int | None:
lock_beat.reacquire()
rug = RedisUserGroup(usergroup.id)
# don't generate sync tasks if tasks are still pending
if r.exists(rug.fence_key):
return None
if usergroup.is_up_to_date:
return None
# add tasks to celery and build up the task set to monitor in redis
r.delete(rug.taskset_key)
# Add all documents that need to be updated into the queue
task_logger.info(f"generate_tasks starting. usergroup_id={usergroup.id}")
tasks_generated = rug.generate_tasks(celery_app, db_session, r, lock_beat)
if tasks_generated is None:
return None
# Currently we are allowing the sync to proceed with 0 tasks.
# It's possible for sets/groups to be generated initially with no entries
# and they still need to be marked as up to date.
# if tasks_generated == 0:
# return 0
task_logger.info(
f"generate_tasks finished. "
f"usergroup_id={usergroup.id} tasks_generated={tasks_generated}"
)
# set this only after all tasks have been added
r.set(rug.fence_key, tasks_generated)
return tasks_generated
#####
# Periodic Tasks
#####
@celery_app.task(
name="check_for_vespa_sync_task",
soft_time_limit=JOB_TIMEOUT,
)
def check_for_vespa_sync_task() -> None:
"""Runs periodically to check if any document needs syncing.
Generates sets of tasks for Celery if syncing is needed."""
r = redis_pool.get_client()
lock_beat = r.lock(
DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
try:
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return
with Session(get_sqlalchemy_engine()) as db_session:
try_generate_stale_document_sync_tasks(db_session, r, lock_beat)
# check if any document sets are not synced
document_set_info = fetch_document_sets(
user_id=None, db_session=db_session, include_outdated=True
)
for document_set, _ in document_set_info:
try_generate_document_set_sync_tasks(
document_set, db_session, r, lock_beat
)
# check if any user groups are not synced
try:
fetch_user_groups = fetch_versioned_implementation(
"danswer.db.user_group", "fetch_user_groups"
)
user_groups = fetch_user_groups(
db_session=db_session, only_up_to_date=False
)
for usergroup in user_groups:
try_generate_user_group_sync_tasks(
usergroup, db_session, r, lock_beat
)
except ModuleNotFoundError:
# Always exceptions on the MIT version, which is expected
pass
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception:
task_logger.exception("Unexpected exception")
finally:
if lock_beat.owned():
lock_beat.release()
@celery_app.task(
name="check_for_cc_pair_deletion_task",
soft_time_limit=JOB_TIMEOUT,
)
def check_for_cc_pair_deletion_task() -> None:
"""Runs periodically to check if any deletion tasks should be run"""
with Session(get_sqlalchemy_engine()) as db_session:
# check if any cc pairs are up for deletion
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in cc_pairs:
if should_kick_off_deletion_of_cc_pair(cc_pair, db_session):
task_logger.info(
f"Deleting the {cc_pair.name} connector credential pair"
)
cleanup_connector_credential_pair_task.apply_async(
kwargs=dict(
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
),
)
@celery_app.task(
name="kombu_message_cleanup_task",
soft_time_limit=JOB_TIMEOUT,
bind=True,
base=AbortableTask,
)
def kombu_message_cleanup_task(self: Any) -> int:
"""Runs periodically to clean up the kombu_message table"""
# we will select messages older than this amount to clean up
KOMBU_MESSAGE_CLEANUP_AGE = 7 # days
KOMBU_MESSAGE_CLEANUP_PAGE_LIMIT = 1000
ctx = {}
ctx["last_processed_id"] = 0
ctx["deleted"] = 0
ctx["cleanup_age"] = KOMBU_MESSAGE_CLEANUP_AGE
ctx["page_limit"] = KOMBU_MESSAGE_CLEANUP_PAGE_LIMIT
with Session(get_sqlalchemy_engine()) as db_session:
# Exit the task if we can't take the advisory lock
result = db_session.execute(
text("SELECT pg_try_advisory_lock(:id)"),
{"id": PostgresAdvisoryLocks.KOMBU_MESSAGE_CLEANUP_LOCK_ID.value},
).scalar()
if not result:
return 0
while True:
if self.is_aborted():
raise TaskRevokedError("kombu_message_cleanup_task was aborted.")
b = kombu_message_cleanup_task_helper(ctx, db_session)
if not b:
break
db_session.commit()
if ctx["deleted"] > 0:
task_logger.info(
f"Deleted {ctx['deleted']} orphaned messages from kombu_message."
)
return ctx["deleted"]
def kombu_message_cleanup_task_helper(ctx: dict, db_session: Session) -> bool:
"""
Helper function to clean up old messages from the `kombu_message` table that are no longer relevant.
This function retrieves messages from the `kombu_message` table that are no longer visible and
older than a specified interval. It checks if the corresponding task_id exists in the
`celery_taskmeta` table. If the task_id does not exist, the message is deleted.
Args:
ctx (dict): A context dictionary containing configuration parameters such as:
- 'cleanup_age' (int): The age in days after which messages are considered old.
- 'page_limit' (int): The maximum number of messages to process in one batch.
- 'last_processed_id' (int): The ID of the last processed message to handle pagination.
- 'deleted' (int): A counter to track the number of deleted messages.
db_session (Session): The SQLAlchemy database session for executing queries.
Returns:
bool: Returns True if there are more rows to process, False if not.
"""
inspector = inspect(db_session.bind)
if not inspector:
return False
# With the move to redis as celery's broker and backend, kombu tables may not even exist.
# We can fail silently.
if not inspector.has_table("kombu_message"):
return False
query = text(
"""
SELECT id, timestamp, payload
FROM kombu_message WHERE visible = 'false'
AND timestamp < CURRENT_TIMESTAMP - INTERVAL :interval_days
AND id > :last_processed_id
ORDER BY id
LIMIT :page_limit
"""
)
kombu_messages = db_session.execute(
query,
{
"interval_days": f"{ctx['cleanup_age']} days",
"page_limit": ctx["page_limit"],
"last_processed_id": ctx["last_processed_id"],
},
).fetchall()
if len(kombu_messages) == 0:
return False
for msg in kombu_messages:
payload = json.loads(msg[2])
task_id = payload["headers"]["id"]
# Check if task_id exists in celery_taskmeta
task_exists = db_session.execute(
text("SELECT 1 FROM celery_taskmeta WHERE task_id = :task_id"),
{"task_id": task_id},
).fetchone()
# If task_id does not exist, delete the message
if not task_exists:
result = db_session.execute(
text("DELETE FROM kombu_message WHERE id = :message_id"),
{"message_id": msg[0]},
)
if result.rowcount > 0: # type: ignore
ctx["deleted"] += 1
ctx["last_processed_id"] = msg[0]
return True
@celery_app.task(
name="check_for_prune_task",
soft_time_limit=JOB_TIMEOUT,
)
def check_for_prune_task() -> None:
"""Runs periodically to check if any prune tasks should be run and adds them
to the queue"""
with Session(get_sqlalchemy_engine()) as db_session:
all_cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in all_cc_pairs:
if should_prune_cc_pair(
connector=cc_pair.connector,
credential=cc_pair.credential,
db_session=db_session,
):
task_logger.info(f"Pruning the {cc_pair.connector.name} connector")
prune_documents_task.apply_async(
kwargs=dict(
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
)
)
@celery_app.task(
name="vespa_metadata_sync_task",
bind=True,
soft_time_limit=45,
time_limit=60,
max_retries=3,
)
def vespa_metadata_sync_task(self: Task, document_id: str) -> bool:
task_logger.info(f"document_id={document_id}")
try:
with Session(get_sqlalchemy_engine()) as db_session:
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
document_index = get_default_document_index(
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
)
doc = get_document(document_id, db_session)
if not doc:
return False
# document set sync
doc_sets = fetch_document_set_for_document(document_id, db_session)
update_doc_sets: set[str] = set(doc_sets)
# User group sync
doc_access = get_access_for_document(
document_id=document_id, db_session=db_session
)
update_request = UpdateRequest(
document_ids=[document_id],
document_sets=update_doc_sets,
access=doc_access,
boost=doc.boost,
hidden=doc.hidden,
)
# update Vespa
document_index.update(update_requests=[update_request])
# update db last. Worst case = we crash right before this and
# the sync might repeat again later
mark_document_as_synced(document_id, db_session)
except SoftTimeLimitExceeded:
task_logger.info(f"SoftTimeLimitExceeded exception. doc_id={document_id}")
except Exception as e:
task_logger.exception("Unexpected exception")
# Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
countdown = 2 ** (self.request.retries + 4)
self.retry(exc=e, countdown=countdown)
return True
@signals.task_postrun.connect
def celery_task_postrun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
kwargs: dict | None = None,
retval: Any | None = None,
state: str | None = None,
**kwds: Any,
) -> None:
"""We handle this signal in order to remove completed tasks
from their respective tasksets. This allows us to track the progress of document set
and user group syncs.
This function runs after any task completes (both success and failure)
Note that this signal does not fire on a task that failed to complete and is going
to be retried.
"""
if not task:
return
task_logger.debug(f"Task {task.name} (ID: {task_id}) completed with state: {state}")
# logger.debug(f"Result: {retval}")
if state not in READY_STATES:
return
if not task_id:
return
if task_id.startswith(RedisConnectorCredentialPair.PREFIX):
r = redis_pool.get_client()
r.srem(RedisConnectorCredentialPair.get_taskset_key(), task_id)
return
if task_id.startswith(RedisDocumentSet.PREFIX):
r = redis_pool.get_client()
document_set_id = RedisDocumentSet.get_id_from_task_id(task_id)
if document_set_id is not None:
rds = RedisDocumentSet(document_set_id)
r.srem(rds.taskset_key, task_id)
return
if task_id.startswith(RedisUserGroup.PREFIX):
r = redis_pool.get_client()
usergroup_id = RedisUserGroup.get_id_from_task_id(task_id)
if usergroup_id is not None:
rug = RedisUserGroup(usergroup_id)
r.srem(rug.taskset_key, task_id)
return
def monitor_connector_taskset(r: Redis) -> None:
fence_value = r.get(RedisConnectorCredentialPair.get_fence_key())
if fence_value is None:
return
try:
initial_count = int(cast(int, fence_value))
except ValueError:
task_logger.error("The value is not an integer.")
return
count = r.scard(RedisConnectorCredentialPair.get_taskset_key())
task_logger.info(f"Stale documents: remaining={count} initial={initial_count}")
if count == 0:
r.delete(RedisConnectorCredentialPair.get_taskset_key())
r.delete(RedisConnectorCredentialPair.get_fence_key())
task_logger.info(f"Successfully synced stale documents. count={initial_count}")
def monitor_document_set_taskset(
key_bytes: bytes, r: Redis, db_session: Session
) -> None:
fence_key = key_bytes.decode("utf-8")
document_set_id = RedisDocumentSet.get_id_from_fence_key(fence_key)
if document_set_id is None:
task_logger.warning("could not parse document set id from {key}")
return
rds = RedisDocumentSet(document_set_id)
fence_value = r.get(rds.fence_key)
if fence_value is None:
return
try:
initial_count = int(cast(int, fence_value))
except ValueError:
task_logger.error("The value is not an integer.")
return
count = cast(int, r.scard(rds.taskset_key))
task_logger.info(
f"document_set_id={document_set_id} remaining={count} initial={initial_count}"
)
if count > 0:
return
document_set = cast(
DocumentSet,
get_document_set_by_id(db_session=db_session, document_set_id=document_set_id),
) # casting since we "know" a document set with this ID exists
if document_set:
if not document_set.connector_credential_pairs:
# if there are no connectors, then delete the document set.
delete_document_set(document_set_row=document_set, db_session=db_session)
task_logger.info(
f"Successfully deleted document set with ID: '{document_set_id}'!"
)
else:
mark_document_set_as_synced(document_set_id, db_session)
task_logger.info(
f"Successfully synced document set with ID: '{document_set_id}'!"
)
r.delete(rds.taskset_key)
r.delete(rds.fence_key)
def monitor_usergroup_taskset(key_bytes: bytes, r: Redis, db_session: Session) -> None:
key = key_bytes.decode("utf-8")
usergroup_id = RedisUserGroup.get_id_from_fence_key(key)
if not usergroup_id:
task_logger.warning("Could not parse usergroup id from {key}")
return
rug = RedisUserGroup(usergroup_id)
fence_value = r.get(rug.fence_key)
if fence_value is None:
return
try:
initial_count = int(cast(int, fence_value))
except ValueError:
task_logger.error("The value is not an integer.")
return
count = cast(int, r.scard(rug.taskset_key))
task_logger.info(
f"usergroup_id={usergroup_id} remaining={count} initial={initial_count}"
)
if count > 0:
return
try:
fetch_user_group = fetch_versioned_implementation(
"danswer.db.user_group", "fetch_user_group"
)
except ModuleNotFoundError:
task_logger.exception(
"fetch_versioned_implementation failed to look up fetch_user_group."
)
return
user_group: UserGroup | None = fetch_user_group(
db_session=db_session, user_group_id=usergroup_id
)
if user_group:
if user_group.is_up_for_deletion:
delete_user_group = fetch_versioned_implementation_with_fallback(
"danswer.db.user_group", "delete_user_group", noop_fallback
)
delete_user_group(db_session=db_session, user_group=user_group)
task_logger.info(f" Deleted usergroup. id='{usergroup_id}'")
else:
mark_user_group_as_synced = fetch_versioned_implementation_with_fallback(
"danswer.db.user_group", "mark_user_group_as_synced", noop_fallback
)
mark_user_group_as_synced(db_session=db_session, user_group=user_group)
task_logger.info(f"Synced usergroup. id='{usergroup_id}'")
r.delete(rug.taskset_key)
r.delete(rug.fence_key)
@celery_app.task(name="monitor_vespa_sync", soft_time_limit=300)
def monitor_vespa_sync() -> None:
"""This is a celery beat task that monitors and finalizes metadata sync tasksets.
It scans for fence values and then gets the counts of any associated tasksets.
If the count is 0, that means all tasks finished and we should clean up.
This task lock timeout is CELERY_METADATA_SYNC_BEAT_LOCK_TIMEOUT seconds, so don't
do anything too expensive in this function!
"""
r = redis_pool.get_client()
lock_beat = r.lock(
DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
try:
# prevent overlapping tasks
if not lock_beat.acquire(blocking=False):
return
with Session(get_sqlalchemy_engine()) as db_session:
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
monitor_connector_taskset(r)
for key_bytes in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
monitor_document_set_taskset(key_bytes, r, db_session)
for key_bytes in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
monitor_usergroup_taskset(key_bytes, r, db_session)
#
# r_celery = celery_app.broker_connection().channel().client
# length = celery_get_queue_length(DanswerCeleryQueues.VESPA_METADATA_SYNC, r_celery)
# task_logger.warning(f"queue={DanswerCeleryQueues.VESPA_METADATA_SYNC} length={length}")
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
finally:
if lock_beat.owned():
lock_beat.release()
@beat_init.connect
def on_beat_init(sender: Any, **kwargs: Any) -> None:
init_sqlalchemy_engine(POSTGRES_CELERY_BEAT_APP_NAME)
@worker_init.connect
def on_worker_init(sender: Any, **kwargs: Any) -> None:
init_sqlalchemy_engine(POSTGRES_CELERY_WORKER_APP_NAME)
# TODO(rkuo): this is singleton work that should be done on startup exactly once
# if we run multiple workers, we'll need to centralize where this cleanup happens
r = redis_pool.get_client()
r.delete(DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK)
r.delete(DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
r.delete(RedisConnectorCredentialPair.get_taskset_key())
r.delete(RedisConnectorCredentialPair.get_fence_key())
for key in r.scan_iter(RedisDocumentSet.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisUserGroup.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
r.delete(key)
#####
# Celery Beat (Periodic Tasks) Settings
#####
celery_app.conf.beat_schedule = {
"check-for-vespa-sync": {
"task": "check_for_vespa_sync_task",
"schedule": timedelta(seconds=5),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
"check-for-cc-pair-deletion": {
"task": "check_for_cc_pair_deletion_task",
# don't need to check too often, since we kick off a deletion initially
# during the API call that actually marks the CC pair for deletion
"schedule": timedelta(minutes=1),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
}
celery_app.conf.beat_schedule.update(
{
"check-for-prune": {
"task": "check_for_prune_task",
"schedule": timedelta(seconds=5),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
}
)
celery_app.conf.beat_schedule.update(
{
"kombu-message-cleanup": {
"task": "kombu_message_cleanup_task",
"schedule": timedelta(seconds=3600),
"options": {"priority": DanswerCeleryPriority.LOWEST},
},
}
)
celery_app.conf.beat_schedule.update(
{
"monitor-vespa-sync": {
"task": "monitor_vespa_sync",
"schedule": timedelta(seconds=5),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
}
)

View File

@@ -1,299 +0,0 @@
# These are helper objects for tracking the keys we need to write in redis
import time
from abc import ABC
from abc import abstractmethod
from typing import cast
from uuid import uuid4
import redis
from celery import Celery
from redis import Redis
from sqlalchemy.orm import Session
from danswer.background.celery.celeryconfig import CELERY_SEPARATOR
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.document import (
construct_document_select_for_connector_credential_pair_by_needs_sync,
)
from danswer.db.document_set import construct_document_select_by_docset
from danswer.utils.variable_functionality import fetch_versioned_implementation
class RedisObjectHelper(ABC):
PREFIX = "base"
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def __init__(self, id: int):
self._id: int = id
@property
def task_id_prefix(self) -> str:
return f"{self.PREFIX}_{self._id}"
@property
def fence_key(self) -> str:
# example: documentset_fence_1
return f"{self.FENCE_PREFIX}_{self._id}"
@property
def taskset_key(self) -> str:
# example: documentset_taskset_1
return f"{self.TASKSET_PREFIX}_{self._id}"
@staticmethod
def get_id_from_fence_key(key: str) -> int | None:
"""
Extracts the object ID from a fence key in the format `PREFIX_fence_X`.
Args:
key (str): The fence key string.
Returns:
Optional[int]: The extracted ID if the key is in the correct format, otherwise None.
"""
parts = key.split("_")
if len(parts) != 3:
return None
try:
object_id = int(parts[2])
except ValueError:
return None
return object_id
@staticmethod
def get_id_from_task_id(task_id: str) -> int | None:
"""
Extracts the object ID from a task ID string.
This method assumes the task ID is formatted as `prefix_objectid_suffix`, where:
- `prefix` is an arbitrary string (e.g., the name of the task or entity),
- `objectid` is the ID you want to extract,
- `suffix` is another arbitrary string (e.g., a UUID).
Example:
If the input `task_id` is `documentset_1_cbfdc96a-80ca-4312-a242-0bb68da3c1dc`,
this method will return the string `"1"`.
Args:
task_id (str): The task ID string from which to extract the object ID.
Returns:
str | None: The extracted object ID if the task ID is in the correct format, otherwise None.
"""
# example: task_id=documentset_1_cbfdc96a-80ca-4312-a242-0bb68da3c1dc
parts = task_id.split("_")
if len(parts) != 3:
return None
try:
object_id = int(parts[1])
except ValueError:
return None
return object_id
@abstractmethod
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
) -> int | None:
pass
class RedisDocumentSet(RedisObjectHelper):
PREFIX = "documentset"
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
) -> int | None:
last_lock_time = time.monotonic()
async_results = []
stmt = construct_document_select_by_docset(self._id)
for doc in db_session.scalars(stmt).yield_per(1):
current_time = time.monotonic()
if current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
):
lock.reacquire()
last_lock_time = current_time
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
# add to the set BEFORE creating the task.
redis_client.sadd(self.taskset_key, custom_task_id)
result = celery_app.send_task(
"vespa_metadata_sync_task",
kwargs=dict(document_id=doc.id),
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
task_id=custom_task_id,
priority=DanswerCeleryPriority.LOW,
)
async_results.append(result)
return len(async_results)
class RedisUserGroup(RedisObjectHelper):
PREFIX = "usergroup"
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
) -> int | None:
last_lock_time = time.monotonic()
async_results = []
try:
construct_document_select_by_usergroup = fetch_versioned_implementation(
"danswer.db.user_group",
"construct_document_select_by_usergroup",
)
except ModuleNotFoundError:
return 0
stmt = construct_document_select_by_usergroup(self._id)
for doc in db_session.scalars(stmt).yield_per(1):
current_time = time.monotonic()
if current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
):
lock.reacquire()
last_lock_time = current_time
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
# add to the set BEFORE creating the task.
redis_client.sadd(self.taskset_key, custom_task_id)
result = celery_app.send_task(
"vespa_metadata_sync_task",
kwargs=dict(document_id=doc.id),
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
task_id=custom_task_id,
priority=DanswerCeleryPriority.LOW,
)
async_results.append(result)
return len(async_results)
class RedisConnectorCredentialPair(RedisObjectHelper):
PREFIX = "connectorsync"
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
@classmethod
def get_fence_key(cls) -> str:
return RedisConnectorCredentialPair.FENCE_PREFIX
@classmethod
def get_taskset_key(cls) -> str:
return RedisConnectorCredentialPair.TASKSET_PREFIX
@property
def taskset_key(self) -> str:
"""Notice that this is intentionally reusing the same taskset for all
connector syncs"""
# example: connector_taskset
return f"{self.TASKSET_PREFIX}"
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
) -> int | None:
last_lock_time = time.monotonic()
async_results = []
cc_pair = get_connector_credential_pair_from_id(self._id, db_session)
if not cc_pair:
return None
stmt = construct_document_select_for_connector_credential_pair_by_needs_sync(
cc_pair.connector_id, cc_pair.credential_id
)
for doc in db_session.scalars(stmt).yield_per(1):
current_time = time.monotonic()
if current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
):
lock.reacquire()
last_lock_time = current_time
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
# add to the tracking taskset in redis BEFORE creating the celery task.
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
redis_client.sadd(
RedisConnectorCredentialPair.get_taskset_key(), custom_task_id
)
# Priority on sync's triggered by new indexing should be medium
result = celery_app.send_task(
"vespa_metadata_sync_task",
kwargs=dict(document_id=doc.id),
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
task_id=custom_task_id,
priority=DanswerCeleryPriority.MEDIUM,
)
async_results.append(result)
return len(async_results)
def celery_get_queue_length(queue: str, r: Redis) -> int:
"""This is a redis specific way to get the length of a celery queue.
It is priority aware and knows how to count across the multiple redis lists
used to implement task prioritization.
This operation is not atomic."""
total_length = 0
for i in range(len(DanswerCeleryPriority)):
queue_name = queue
if i > 0:
queue_name += CELERY_SEPARATOR
queue_name += str(i)
length = r.llen(queue_name)
total_length += cast(int, length)
return total_length

View File

@@ -1,9 +0,0 @@
"""Entry point for running celery worker / celery beat."""
from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
set_is_ee_based_on_env_variable()
celery_app = fetch_versioned_implementation(
"danswer.background.celery.celery_app", "celery_app"
)

View File

@@ -1,153 +1,23 @@
from datetime import datetime
from datetime import timezone
from sqlalchemy.orm import Session
from danswer.background.task_utils import name_cc_cleanup_task
from danswer.background.task_utils import name_cc_prune_task
from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
)
from danswer.connectors.interfaces import BaseConnector
from danswer.connectors.interfaces import IdConnector
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.models import Document
from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
from danswer.db.engine import get_db_current_time
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import Connector
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import Credential
from danswer.db.models import TaskQueueState
from danswer.db.tasks import check_task_is_live_and_not_timed_out
from danswer.db.tasks import get_latest_task
from danswer.db.tasks import get_latest_task_by_type
from danswer.server.documents.models import DeletionAttemptSnapshot
from danswer.utils.logger import setup_logger
logger = setup_logger()
def _get_deletion_status(
def get_deletion_status(
connector_id: int, credential_id: int, db_session: Session
) -> TaskQueueState | None:
) -> DeletionAttemptSnapshot | None:
cleanup_task_name = name_cc_cleanup_task(
connector_id=connector_id, credential_id=credential_id
)
return get_latest_task(task_name=cleanup_task_name, db_session=db_session)
task_state = get_latest_task(task_name=cleanup_task_name, db_session=db_session)
def get_deletion_attempt_snapshot(
connector_id: int, credential_id: int, db_session: Session
) -> DeletionAttemptSnapshot | None:
deletion_task = _get_deletion_status(connector_id, credential_id, db_session)
if not deletion_task:
if not task_state:
return None
return DeletionAttemptSnapshot(
connector_id=connector_id,
credential_id=credential_id,
status=deletion_task.status,
status=task_state.status,
)
def should_kick_off_deletion_of_cc_pair(
cc_pair: ConnectorCredentialPair, db_session: Session
) -> bool:
if cc_pair.status != ConnectorCredentialPairStatus.DELETING:
return False
if check_deletion_attempt_is_allowed(cc_pair, db_session):
return False
deletion_task = _get_deletion_status(
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
db_session=db_session,
)
if deletion_task and check_task_is_live_and_not_timed_out(
deletion_task,
db_session,
# 1 hour timeout
timeout=60 * 60,
):
return False
return True
def should_prune_cc_pair(
connector: Connector, credential: Credential, db_session: Session
) -> bool:
if not connector.prune_freq:
return False
pruning_task_name = name_cc_prune_task(
connector_id=connector.id, credential_id=credential.id
)
last_pruning_task = get_latest_task(pruning_task_name, db_session)
current_db_time = get_db_current_time(db_session)
if not last_pruning_task:
time_since_initialization = current_db_time - connector.time_created
if time_since_initialization.total_seconds() >= connector.prune_freq:
return True
return False
if not ALLOW_SIMULTANEOUS_PRUNING:
pruning_type_task_name = name_cc_prune_task()
last_pruning_type_task = get_latest_task_by_type(
pruning_type_task_name, db_session
)
if last_pruning_type_task and check_task_is_live_and_not_timed_out(
last_pruning_type_task, db_session
):
return False
if check_task_is_live_and_not_timed_out(last_pruning_task, db_session):
return False
if not last_pruning_task.start_time:
return False
time_since_last_pruning = current_db_time - last_pruning_task.start_time
return time_since_last_pruning.total_seconds() >= connector.prune_freq
def document_batch_to_ids(doc_batch: list[Document]) -> set[str]:
return {doc.id for doc in doc_batch}
def extract_ids_from_runnable_connector(runnable_connector: BaseConnector) -> set[str]:
"""
If the PruneConnector hasnt been implemented for the given connector, just pull
all docs using the load_from_state and grab out the IDs
"""
all_connector_doc_ids: set[str] = set()
doc_batch_generator = None
if isinstance(runnable_connector, IdConnector):
all_connector_doc_ids = runnable_connector.retrieve_all_source_ids()
elif isinstance(runnable_connector, LoadConnector):
doc_batch_generator = runnable_connector.load_from_state()
elif isinstance(runnable_connector, PollConnector):
start = datetime(1970, 1, 1, tzinfo=timezone.utc).timestamp()
end = datetime.now(timezone.utc).timestamp()
doc_batch_generator = runnable_connector.poll_source(start=start, end=end)
else:
raise RuntimeError("Pruning job could not find a valid runnable_connector.")
if doc_batch_generator:
doc_batch_processing_func = document_batch_to_ids
if MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE:
doc_batch_processing_func = rate_limit_builder(
max_calls=MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE, period=60
)(document_batch_to_ids)
for doc_batch in doc_batch_generator:
all_connector_doc_ids.update(doc_batch_processing_func(doc_batch))
return all_connector_doc_ids

View File

@@ -1,35 +0,0 @@
# docs: https://docs.celeryq.dev/en/stable/userguide/configuration.html
from danswer.configs.app_configs import REDIS_DB_NUMBER_CELERY
from danswer.configs.app_configs import REDIS_HOST
from danswer.configs.app_configs import REDIS_PASSWORD
from danswer.configs.app_configs import REDIS_PORT
from danswer.configs.constants import DanswerCeleryPriority
CELERY_SEPARATOR = ":"
CELERY_PASSWORD_PART = ""
if REDIS_PASSWORD:
CELERY_PASSWORD_PART = f":{REDIS_PASSWORD}@"
# example celery_broker_url: "redis://:password@localhost:6379/15"
broker_url = (
f"redis://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY}"
)
result_backend = (
f"redis://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY}"
)
# NOTE: prefetch 4 is significantly faster than prefetch 1
# however, prefetching is bad when tasks are lengthy as those tasks
# can stall other tasks.
worker_prefetch_multiplier = 4
broker_transport_options = {
"priority_steps": list(range(len(DanswerCeleryPriority))),
"sep": CELERY_SEPARATOR,
"queue_order_strategy": "priority",
}
task_default_priority = DanswerCeleryPriority.MEDIUM
task_acks_late = True

View File

@@ -10,6 +10,8 @@ are multiple connector / credential pairs that have indexed it
connector / credential pair from the access list
(6) delete all relevant entries from postgres
"""
import time
from sqlalchemy.orm import Session
from danswer.access.access import get_access_for_documents
@@ -22,8 +24,10 @@ from danswer.db.document import delete_documents_complete__no_commit
from danswer.db.document import get_document_connector_cnts
from danswer.db.document import get_documents_for_connector_credential_pair
from danswer.db.document import prepare_to_modify_documents
from danswer.db.document_set import delete_document_set_cc_pair_relationship__no_commit
from danswer.db.document_set import fetch_document_sets_for_documents
from danswer.db.document_set import get_document_sets_by_ids
from danswer.db.document_set import (
mark_cc_pair__document_set_relationships_to_be_deleted__no_commit,
)
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.index_attempt import delete_index_attempts
from danswer.db.models import ConnectorCredentialPair
@@ -31,17 +35,13 @@ from danswer.document_index.interfaces import DocumentIndex
from danswer.document_index.interfaces import UpdateRequest
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,
)
from danswer.utils.variable_functionality import noop_fallback
logger = setup_logger()
_DELETION_BATCH_SIZE = 1000
def delete_connector_credential_pair_batch(
def _delete_connector_credential_pair_batch(
document_ids: list[str],
connector_id: int,
credential_id: int,
@@ -78,37 +78,25 @@ def delete_connector_credential_pair_batch(
document_ids_to_update = [
document_id for document_id, cnt in document_connector_cnts if cnt > 1
]
# maps document id to list of document set names
new_doc_sets_for_documents: dict[str, set[str]] = {
document_id_and_document_set_names_tuple[0]: set(
document_id_and_document_set_names_tuple[1]
)
for document_id_and_document_set_names_tuple in fetch_document_sets_for_documents(
db_session=db_session,
document_ids=document_ids_to_update,
)
}
# determine future ACLs for documents in batch
access_for_documents = get_access_for_documents(
document_ids=document_ids_to_update,
db_session=db_session,
cc_pair_to_delete=ConnectorCredentialPairIdentifier(
connector_id=connector_id,
credential_id=credential_id,
),
)
# update Vespa
logger.debug(f"Updating documents: {document_ids_to_update}")
update_requests = [
UpdateRequest(
document_ids=[document_id],
access=access,
document_sets=new_doc_sets_for_documents[document_id],
)
for document_id, access in access_for_documents.items()
]
logger.debug(f"Updating documents: {document_ids_to_update}")
document_index.update(update_requests=update_requests)
# clean up Postgres
delete_document_by_connector_credential_pair__no_commit(
db_session=db_session,
document_ids=document_ids_to_update,
@@ -120,6 +108,48 @@ def delete_connector_credential_pair_batch(
db_session.commit()
def cleanup_synced_entities(
cc_pair: ConnectorCredentialPair, db_session: Session
) -> None:
"""Updates the document sets associated with the connector / credential pair,
then relies on the document set sync script to kick off Celery jobs which will
sync these updates to Vespa.
Waits until the document sets are synced before returning."""
logger.info(f"Cleaning up Document Sets for CC Pair with ID: '{cc_pair.id}'")
document_sets_ids_to_sync = list(
mark_cc_pair__document_set_relationships_to_be_deleted__no_commit(
cc_pair_id=cc_pair.id,
db_session=db_session,
)
)
db_session.commit()
# wait till all document sets are synced before continuing
while True:
all_synced = True
document_sets = get_document_sets_by_ids(
db_session=db_session, document_set_ids=document_sets_ids_to_sync
)
for document_set in document_sets:
if not document_set.is_up_to_date:
all_synced = False
if all_synced:
break
# wait for 30 seconds before checking again
db_session.commit() # end transaction
logger.info(
f"Document sets '{document_sets_ids_to_sync}' not synced yet, waiting 30s"
)
time.sleep(30)
logger.info(
f"Finished cleaning up Document Sets for CC Pair with ID: '{cc_pair.id}'"
)
def delete_connector_credential_pair(
db_session: Session,
document_index: DocumentIndex,
@@ -139,7 +169,7 @@ def delete_connector_credential_pair(
if not documents:
break
delete_connector_credential_pair_batch(
_delete_connector_credential_pair_batch(
document_ids=[document.id for document in documents],
connector_id=connector_id,
credential_id=credential_id,
@@ -147,32 +177,17 @@ def delete_connector_credential_pair(
)
num_docs_deleted += len(documents)
# clean up the rest of the related Postgres entities
# index attempts
delete_index_attempts(
db_session=db_session,
cc_pair_id=cc_pair.id,
)
# Clean up document sets / access information from Postgres
# and sync these updates to Vespa
# TODO: add user group cleanup with `fetch_versioned_implementation`
cleanup_synced_entities(cc_pair, db_session)
# document sets
delete_document_set_cc_pair_relationship__no_commit(
# clean up the rest of the related Postgres entities
delete_index_attempts(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
)
# user groups
cleanup_user_groups = fetch_versioned_implementation_with_fallback(
"danswer.db.user_group",
"delete_user_group_cc_pair_relationship__no_commit",
noop_fallback,
)
cleanup_user_groups(
cc_pair_id=cc_pair.id,
db_session=db_session,
)
# finally, delete the cc-pair
delete_connector_credential_pair__no_commit(
db_session=db_session,
connector_id=connector_id,
@@ -184,11 +199,11 @@ def delete_connector_credential_pair(
connector_id=connector_id,
)
if not connector or not len(connector.credentials):
logger.info("Found no credentials left for connector, deleting connector")
logger.debug("Found no credentials left for connector, deleting connector")
db_session.delete(connector)
db_session.commit()
logger.notice(
logger.info(
"Successfully deleted connector_credential_pair with connector_id:"
f" '{connector_id}' and credential_id: '{credential_id}'. Deleted {num_docs_deleted} docs."
)

View File

@@ -41,12 +41,6 @@ def _initializer(
return func(*args, **kwargs)
def _run_in_process(
func: Callable, args: list | tuple, kwargs: dict[str, Any] | None = None
) -> None:
_initializer(func, args, kwargs)
@dataclass
class SimpleJob:
"""Drop in replacement for `dask.distributed.Future`"""
@@ -111,15 +105,13 @@ class SimpleJobClient:
"""NOTE: `pure` arg is needed so this can be a drop in replacement for Dask"""
self._cleanup_completed_jobs()
if len(self.jobs) >= self.n_workers:
logger.debug(
f"No available workers to run job. Currently running '{len(self.jobs)}' jobs, with a limit of '{self.n_workers}'."
)
logger.debug("No available workers to run job")
return None
job_id = self.job_id_counter
self.job_id_counter += 1
process = Process(target=_run_in_process, args=(func, args), daemon=True)
process = Process(target=_initializer(func=func, args=args), daemon=True)
job = SimpleJob(id=job_id, process=process)
process.start()

View File

@@ -6,22 +6,27 @@ from datetime import timezone
from sqlalchemy.orm import Session
from danswer.background.connector_deletion import (
_delete_connector_credential_pair_batch,
)
from danswer.background.indexing.checkpointing import get_time_windows_for_index_attempt
from danswer.background.indexing.tracer import DanswerTracer
from danswer.configs.app_configs import INDEXING_SIZE_WARNING_THRESHOLD
from danswer.configs.app_configs import INDEXING_TRACER_INTERVAL
from danswer.configs.app_configs import DISABLE_DOCUMENT_CLEANUP
from danswer.configs.app_configs import POLL_CONNECTOR_OFFSET
from danswer.connectors.connector_runner import ConnectorRunner
from danswer.connectors.factory import instantiate_connector
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.models import IndexAttemptMetadata
from danswer.connectors.models import InputType
from danswer.db.connector import disable_connector
from danswer.db.connector_credential_pair import get_last_successful_attempt_time
from danswer.db.connector_credential_pair import update_connector_credential_pair
from danswer.db.credentials import backend_update_credential_json
from danswer.db.document import get_documents_for_connector_credential_pair
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.index_attempt import mark_attempt_in_progress
from danswer.db.index_attempt import mark_attempt_partially_succeeded
from danswer.db.index_attempt import mark_attempt_in_progress__no_commit
from danswer.db.index_attempt import mark_attempt_succeeded
from danswer.db.index_attempt import update_docs_indexed
from danswer.db.models import IndexAttempt
@@ -32,19 +37,16 @@ from danswer.indexing.embedder import DefaultIndexingEmbedder
from danswer.indexing.indexing_pipeline import build_indexing_pipeline
from danswer.utils.logger import IndexAttemptSingleton
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import global_version
logger = setup_logger()
INDEXING_TRACER_NUM_PRINT_ENTRIES = 5
def _get_connector_runner(
def _get_document_generator(
db_session: Session,
attempt: IndexAttempt,
start_time: datetime,
end_time: datetime,
) -> ConnectorRunner:
) -> tuple[GenerateDocumentsOutput, bool]:
"""
NOTE: `start_time` and `end_time` are only used for poll connectors
@@ -52,31 +54,47 @@ def _get_connector_runner(
are the complete list of existing documents of the connector. If the task
of type LOAD_STATE, the list will be considered complete and otherwise incomplete.
"""
task = attempt.connector_credential_pair.connector.input_type
task = attempt.connector.input_type
try:
runnable_connector = instantiate_connector(
attempt.connector_credential_pair.connector.source,
runnable_connector, new_credential_json = instantiate_connector(
attempt.connector.source,
task,
attempt.connector_credential_pair.connector.connector_specific_config,
attempt.connector_credential_pair.credential,
db_session,
attempt.connector.connector_specific_config,
attempt.credential.credential_json,
)
if new_credential_json is not None:
backend_update_credential_json(
attempt.credential, new_credential_json, db_session
)
except Exception as e:
logger.exception(f"Unable to instantiate connector due to {e}")
# since we failed to even instantiate the connector, we pause the CCPair since
# it will never succeed
update_connector_credential_pair(
db_session=db_session,
connector_id=attempt.connector_credential_pair.connector.id,
credential_id=attempt.connector_credential_pair.credential.id,
status=ConnectorCredentialPairStatus.PAUSED,
)
disable_connector(attempt.connector.id, db_session)
raise e
return ConnectorRunner(
connector=runnable_connector, time_range=(start_time, end_time)
)
if task == InputType.LOAD_STATE:
assert isinstance(runnable_connector, LoadConnector)
doc_batch_generator = runnable_connector.load_from_state()
is_listing_complete = True
elif task == InputType.POLL:
assert isinstance(runnable_connector, PollConnector)
if attempt.connector_id is None or attempt.credential_id is None:
raise ValueError(
f"Polling attempt {attempt.id} is missing connector_id or credential_id, "
f"can't fetch time range."
)
logger.info(f"Polling for updates between {start_time} and {end_time}")
doc_batch_generator = runnable_connector.poll_source(
start=start_time.timestamp(), end=end_time.timestamp()
)
is_listing_complete = False
else:
# Event types cannot be handled by a background type
raise RuntimeError(f"Invalid task type: {task}")
return doc_batch_generator, is_listing_complete
def _run_indexing(
@@ -90,62 +108,46 @@ def _run_indexing(
"""
start_time = time.time()
search_settings = index_attempt.search_settings
index_name = search_settings.index_name
db_embedding_model = index_attempt.embedding_model
index_name = db_embedding_model.index_name
# Only update cc-pair status for primary index jobs
# Secondary index syncs at the end when swapping
is_primary = search_settings.status == IndexModelStatus.PRESENT
is_primary = index_attempt.embedding_model.status == IndexModelStatus.PRESENT
# Indexing is only done into one index at a time
document_index = get_default_document_index(
primary_index_name=index_name, secondary_index_name=None
)
embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
search_settings=search_settings
embedding_model = DefaultIndexingEmbedder(
model_name=db_embedding_model.model_name,
normalize=db_embedding_model.normalize,
query_prefix=db_embedding_model.query_prefix,
passage_prefix=db_embedding_model.passage_prefix,
)
indexing_pipeline = build_indexing_pipeline(
attempt_id=index_attempt.id,
embedder=embedding_model,
document_index=document_index,
ignore_time_skip=index_attempt.from_beginning
or (search_settings.status == IndexModelStatus.FUTURE),
or (db_embedding_model.status == IndexModelStatus.FUTURE),
db_session=db_session,
)
db_cc_pair = index_attempt.connector_credential_pair
db_connector = index_attempt.connector_credential_pair.connector
db_credential = index_attempt.connector_credential_pair.credential
earliest_index_time = (
db_connector.indexing_start.timestamp() if db_connector.indexing_start else 0
)
db_connector = index_attempt.connector
db_credential = index_attempt.credential
last_successful_index_time = (
earliest_index_time
0.0
if index_attempt.from_beginning
else get_last_successful_attempt_time(
connector_id=db_connector.id,
credential_id=db_credential.id,
earliest_index=earliest_index_time,
search_settings=index_attempt.search_settings,
embedding_model=index_attempt.embedding_model,
db_session=db_session,
)
)
if INDEXING_TRACER_INTERVAL > 0:
logger.debug(f"Memory tracer starting: interval={INDEXING_TRACER_INTERVAL}")
tracer = DanswerTracer()
tracer.start()
tracer.snap()
index_attempt_md = IndexAttemptMetadata(
connector_id=db_connector.id,
credential_id=db_credential.id,
)
batch_num = 0
net_doc_change = 0
document_count = 0
chunk_count = 0
@@ -158,37 +160,29 @@ def _run_indexing(
source_type=db_connector.source,
)
):
window_start = max(
window_start - timedelta(minutes=POLL_CONNECTOR_OFFSET),
datetime(1970, 1, 1, tzinfo=timezone.utc),
)
doc_batch_generator, is_listing_complete = _get_document_generator(
db_session=db_session,
attempt=index_attempt,
start_time=window_start,
end_time=window_end,
)
try:
window_start = max(
window_start - timedelta(minutes=POLL_CONNECTOR_OFFSET),
datetime(1970, 1, 1, tzinfo=timezone.utc),
)
connector_runner = _get_connector_runner(
db_session=db_session,
attempt=index_attempt,
start_time=window_start,
end_time=window_end,
)
all_connector_doc_ids: set[str] = set()
tracer_counter = 0
if INDEXING_TRACER_INTERVAL > 0:
tracer.snap()
for doc_batch in connector_runner.run():
for doc_batch in doc_batch_generator:
# Check if connector is disabled mid run and stop if so unless it's the secondary
# index being built. We want to populate it even for paused connectors
# Often paused connectors are sources that aren't updated frequently but the
# contents still need to be initially pulled.
db_session.refresh(db_connector)
if (
(
db_cc_pair.status == ConnectorCredentialPairStatus.PAUSED
and search_settings.status != IndexModelStatus.FUTURE
)
# if it's deleting, we don't care if this is a secondary index
or db_cc_pair.status == ConnectorCredentialPairStatus.DELETING
db_connector.disabled
and db_embedding_model.status != IndexModelStatus.FUTURE
):
# let the `except` block handle this
raise RuntimeError("Connector was disabled mid run")
@@ -198,30 +192,17 @@ def _run_indexing(
# Likely due to user manually disabling it or model swap
raise RuntimeError("Index Attempt was canceled")
batch_description = []
for doc in doc_batch:
batch_description.append(doc.to_short_descriptor())
doc_size = 0
for section in doc.sections:
doc_size += len(section.text)
if doc_size > INDEXING_SIZE_WARNING_THRESHOLD:
logger.warning(
f"Document size: doc='{doc.to_short_descriptor()}' "
f"size={doc_size} "
f"threshold={INDEXING_SIZE_WARNING_THRESHOLD}"
)
logger.debug(f"Indexing batch of documents: {batch_description}")
index_attempt_md.batch_num = batch_num + 1 # use 1-index for this
new_docs, total_batch_chunks = indexing_pipeline(
document_batch=doc_batch,
index_attempt_metadata=index_attempt_md,
logger.debug(
f"Indexing batch of documents: {[doc.to_short_descriptor() for doc in doc_batch]}"
)
batch_num += 1
new_docs, total_batch_chunks = indexing_pipeline(
documents=doc_batch,
index_attempt_metadata=IndexAttemptMetadata(
connector_id=db_connector.id,
credential_id=db_credential.id,
),
)
net_doc_change += new_docs
chunk_count += total_batch_chunks
document_count += len(doc_batch)
@@ -243,16 +224,38 @@ def _run_indexing(
docs_removed_from_index=0,
)
tracer_counter += 1
if (
INDEXING_TRACER_INTERVAL > 0
and tracer_counter % INDEXING_TRACER_INTERVAL == 0
):
logger.debug(
f"Running trace comparison for batch {tracer_counter}. interval={INDEXING_TRACER_INTERVAL}"
if is_listing_complete and not DISABLE_DOCUMENT_CLEANUP:
# clean up all documents from the index that have not been returned from the connector
all_indexed_document_ids = {
d.id
for d in get_documents_for_connector_credential_pair(
db_session=db_session,
connector_id=db_connector.id,
credential_id=db_credential.id,
)
tracer.snap()
tracer.log_previous_diff(INDEXING_TRACER_NUM_PRINT_ENTRIES)
}
doc_ids_to_remove = list(
all_indexed_document_ids - all_connector_doc_ids
)
logger.debug(
f"Cleaning up {len(doc_ids_to_remove)} documents that are not contained in the newest connector state"
)
# delete docs from cc-pair and receive the number of completely deleted docs in return
_delete_connector_credential_pair_batch(
document_ids=doc_ids_to_remove,
connector_id=db_connector.id,
credential_id=db_credential.id,
document_index=document_index,
)
update_docs_indexed(
db_session=db_session,
index_attempt=index_attempt,
total_docs_indexed=document_count,
new_docs_indexed=net_doc_change,
docs_removed_from_index=len(doc_ids_to_remove),
)
run_end_dt = window_end
if is_primary:
@@ -260,11 +263,12 @@ def _run_indexing(
db_session=db_session,
connector_id=db_connector.id,
credential_id=db_credential.id,
attempt_status=IndexingStatus.IN_PROGRESS,
net_docs=net_doc_change,
run_dt=run_end_dt,
)
except Exception as e:
logger.exception(
logger.info(
f"Connector run ran into exception after elapsed time: {time.time() - start_time} seconds"
)
# Only mark the attempt as a complete failure if this is the first indexing window.
@@ -276,7 +280,7 @@ def _run_indexing(
# to give better clarity in the UI, as the next run will never happen.
if (
ind == 0
or not db_cc_pair.status.is_active()
or db_connector.disabled
or index_attempt.status != IndexingStatus.IN_PROGRESS
):
mark_attempt_failed(
@@ -288,74 +292,34 @@ def _run_indexing(
if is_primary:
update_connector_credential_pair(
db_session=db_session,
connector_id=db_connector.id,
credential_id=db_credential.id,
connector_id=index_attempt.connector.id,
credential_id=index_attempt.credential.id,
attempt_status=IndexingStatus.FAILED,
net_docs=net_doc_change,
)
if INDEXING_TRACER_INTERVAL > 0:
tracer.stop()
raise e
# break => similar to success case. As mentioned above, if the next run fails for the same
# reason it will then be marked as a failure
break
if INDEXING_TRACER_INTERVAL > 0:
logger.debug(
f"Running trace comparison between start and end of indexing. {tracer_counter} batches processed."
)
tracer.snap()
tracer.log_first_diff(INDEXING_TRACER_NUM_PRINT_ENTRIES)
tracer.stop()
logger.debug("Memory tracer stopped.")
if (
index_attempt_md.num_exceptions > 0
and index_attempt_md.num_exceptions >= batch_num
):
mark_attempt_failed(
index_attempt,
db_session,
failure_reason="All batches exceptioned.",
)
if is_primary:
update_connector_credential_pair(
db_session=db_session,
connector_id=index_attempt.connector_credential_pair.connector.id,
credential_id=index_attempt.connector_credential_pair.credential.id,
)
raise Exception(
f"Connector failed - All batches exceptioned: batches={batch_num}"
)
elapsed_time = time.time() - start_time
if index_attempt_md.num_exceptions == 0:
mark_attempt_succeeded(index_attempt, db_session)
logger.info(
f"Connector succeeded: "
f"docs={document_count} chunks={chunk_count} elapsed={elapsed_time:.2f}s"
)
else:
mark_attempt_partially_succeeded(index_attempt, db_session)
logger.info(
f"Connector completed with some errors: "
f"exceptions={index_attempt_md.num_exceptions} "
f"batches={batch_num} "
f"docs={document_count} "
f"chunks={chunk_count} "
f"elapsed={elapsed_time:.2f}s"
)
mark_attempt_succeeded(index_attempt, db_session)
if is_primary:
update_connector_credential_pair(
db_session=db_session,
connector_id=db_connector.id,
credential_id=db_credential.id,
attempt_status=IndexingStatus.SUCCESS,
run_dt=run_end_dt,
)
logger.info(
f"Indexed or refreshed {document_count} total documents for a total of {chunk_count} indexed chunks"
)
logger.info(
f"Connector successfully finished, elapsed time: {time.time() - start_time} seconds"
)
def _prepare_index_attempt(db_session: Session, index_attempt_id: int) -> IndexAttempt:
# make sure that the index attempt can't change in between checking the
@@ -368,7 +332,6 @@ def _prepare_index_attempt(db_session: Session, index_attempt_id: int) -> IndexA
db_session=db_session,
index_attempt_id=index_attempt_id,
)
if attempt is None:
raise RuntimeError(f"Unable to find IndexAttempt for ID '{index_attempt_id}'")
@@ -379,27 +342,29 @@ def _prepare_index_attempt(db_session: Session, index_attempt_id: int) -> IndexA
)
# only commit once, to make sure this all happens in a single transaction
mark_attempt_in_progress(attempt, db_session)
mark_attempt_in_progress__no_commit(attempt)
is_primary = attempt.embedding_model.status == IndexModelStatus.PRESENT
if is_primary:
update_connector_credential_pair(
db_session=db_session,
connector_id=attempt.connector.id,
credential_id=attempt.credential.id,
attempt_status=IndexingStatus.IN_PROGRESS,
)
else:
db_session.commit()
return attempt
def run_indexing_entrypoint(
index_attempt_id: int, connector_credential_pair_id: int, is_ee: bool = False
) -> None:
def run_indexing_entrypoint(index_attempt_id: int) -> None:
"""Entrypoint for indexing run when using dask distributed.
Wraps the actual logic in a `try` block so that we can catch any exceptions
and mark the attempt as failed."""
try:
if is_ee:
global_version.set_ee()
# set the indexing attempt ID so that all log messages from this process
# will have it added as a prefix
IndexAttemptSingleton.set_cc_and_index_id(
index_attempt_id, connector_credential_pair_id
)
IndexAttemptSingleton.set_index_attempt_id(index_attempt_id)
with Session(get_sqlalchemy_engine()) as db_session:
# make sure that it is valid to run this indexing attempt + mark it
@@ -407,19 +372,17 @@ def run_indexing_entrypoint(
attempt = _prepare_index_attempt(db_session, index_attempt_id)
logger.info(
f"Indexing starting: "
f"connector='{attempt.connector_credential_pair.connector.name}' "
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
f"credentials='{attempt.connector_credential_pair.connector_id}'"
f"Running indexing attempt for connector: '{attempt.connector.name}', "
f"with config: '{attempt.connector.connector_specific_config}', and "
f"with credentials: '{attempt.credential_id}'"
)
_run_indexing(db_session, attempt)
logger.info(
f"Indexing finished: "
f"connector='{attempt.connector_credential_pair.connector.name}' "
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
f"credentials='{attempt.connector_credential_pair.connector_id}'"
f"Completed indexing attempt for connector: '{attempt.connector.name}', "
f"with config: '{attempt.connector.connector_specific_config}', and "
f"with credentials: '{attempt.credential_id}'"
)
except Exception as e:
logger.exception(f"Indexing job with ID '{index_attempt_id}' failed due to {e}")

View File

@@ -1,77 +0,0 @@
import tracemalloc
from danswer.utils.logger import setup_logger
logger = setup_logger()
DANSWER_TRACEMALLOC_FRAMES = 10
class DanswerTracer:
def __init__(self) -> None:
self.snapshot_first: tracemalloc.Snapshot | None = None
self.snapshot_prev: tracemalloc.Snapshot | None = None
self.snapshot: tracemalloc.Snapshot | None = None
def start(self) -> None:
tracemalloc.start(DANSWER_TRACEMALLOC_FRAMES)
def stop(self) -> None:
tracemalloc.stop()
def snap(self) -> None:
snapshot = tracemalloc.take_snapshot()
# Filter out irrelevant frames (e.g., from tracemalloc itself or importlib)
snapshot = snapshot.filter_traces(
(
tracemalloc.Filter(False, tracemalloc.__file__), # Exclude tracemalloc
tracemalloc.Filter(
False, "<frozen importlib._bootstrap>"
), # Exclude importlib
tracemalloc.Filter(
False, "<frozen importlib._bootstrap_external>"
), # Exclude external importlib
)
)
if not self.snapshot_first:
self.snapshot_first = snapshot
if self.snapshot:
self.snapshot_prev = self.snapshot
self.snapshot = snapshot
def log_snapshot(self, numEntries: int) -> None:
if not self.snapshot:
return
stats = self.snapshot.statistics("traceback")
for s in stats[:numEntries]:
logger.debug(f"Tracer snap: {s}")
for line in s.traceback:
logger.debug(f"* {line}")
@staticmethod
def log_diff(
snap_current: tracemalloc.Snapshot,
snap_previous: tracemalloc.Snapshot,
numEntries: int,
) -> None:
stats = snap_current.compare_to(snap_previous, "traceback")
for s in stats[:numEntries]:
logger.debug(f"Tracer diff: {s}")
for line in s.traceback.format():
logger.debug(f"* {line}")
def log_previous_diff(self, numEntries: int) -> None:
if not self.snapshot or not self.snapshot_prev:
return
DanswerTracer.log_diff(self.snapshot, self.snapshot_prev, numEntries)
def log_first_diff(self, numEntries: int) -> None:
if not self.snapshot or not self.snapshot_first:
return
DanswerTracer.log_diff(self.snapshot, self.snapshot_first, numEntries)

View File

@@ -22,15 +22,6 @@ def name_document_set_sync_task(document_set_id: int) -> str:
return f"sync_doc_set_{document_set_id}"
def name_cc_prune_task(
connector_id: int | None = None, credential_id: int | None = None
) -> str:
task_name = f"prune_connector_credential_pair_{connector_id}_{credential_id}"
if not connector_id or not credential_id:
task_name = "prune_connector_credential_pair"
return task_name
T = TypeVar("T", bound=Callable)
@@ -93,16 +84,9 @@ def build_apply_async_wrapper(build_name_fn: Callable[..., str]) -> Callable[[AA
kwargs_for_build_name = kwargs or {}
task_name = build_name_fn(*args_for_build_name, **kwargs_for_build_name)
with Session(get_sqlalchemy_engine()) as db_session:
# register_task must come before fn = apply_async or else the task
# might run mark_task_start (and crash) before the task row exists
db_task = register_task(task_name, db_session)
# mark the task as started
task = fn(args, kwargs, *other_args, **other_kwargs)
# we update the celery task id for diagnostic purposes
# but it isn't currently used by any code
db_task.task_id = task.id
db_session.commit()
register_task(task.id, task_name, db_session)
return task

View File

@@ -16,33 +16,27 @@ from danswer.configs.app_configs import CLEANUP_INDEXING_JOBS_TIMEOUT
from danswer.configs.app_configs import DASK_JOB_CLIENT_ENABLED
from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from danswer.configs.app_configs import NUM_INDEXING_WORKERS
from danswer.configs.app_configs import NUM_SECONDARY_INDEXING_WORKERS
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import POSTGRES_INDEXER_APP_NAME
from danswer.db.connector import fetch_connectors
from danswer.db.connector_credential_pair import fetch_connector_credential_pairs
from danswer.db.connector_credential_pair import mark_all_in_progress_cc_pairs_failed
from danswer.db.connector_credential_pair import update_connector_credential_pair
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.embedding_model import get_secondary_db_embedding_model
from danswer.db.engine import get_db_current_time
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.engine import init_sqlalchemy_engine
from danswer.db.index_attempt import create_index_attempt
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import get_inprogress_index_attempts
from danswer.db.index_attempt import get_last_attempt_for_cc_pair
from danswer.db.index_attempt import get_last_attempt
from danswer.db.index_attempt import get_not_started_index_attempts
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import Connector
from danswer.db.models import EmbeddingModel
from danswer.db.models import IndexAttempt
from danswer.db.models import IndexingStatus
from danswer.db.models import IndexModelStatus
from danswer.db.models import SearchSettings
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.db.swap_index import check_index_swap
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from danswer.search.search_nlp_models import warm_up_encoders
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import global_version
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
from shared_configs.configs import LOG_LEVEL
from shared_configs.configs import MODEL_SERVER_PORT
@@ -59,68 +53,41 @@ _UNEXPECTED_STATE_FAILURE_REASON = (
def _should_create_new_indexing(
cc_pair: ConnectorCredentialPair,
connector: Connector,
last_index: IndexAttempt | None,
search_settings_instance: SearchSettings,
model: EmbeddingModel,
secondary_index_building: bool,
db_session: Session,
) -> bool:
connector = cc_pair.connector
# don't kick off indexing for `NOT_APPLICABLE` sources
if connector.source == DocumentSource.NOT_APPLICABLE:
return False
# User can still manually create single indexing attempts via the UI for the
# currently in use index
if DISABLE_INDEX_UPDATE_ON_SWAP:
if (
search_settings_instance.status == IndexModelStatus.PRESENT
and secondary_index_building
):
if model.status == IndexModelStatus.PRESENT and secondary_index_building:
return False
# When switching over models, always index at least once
if search_settings_instance.status == IndexModelStatus.FUTURE:
if last_index:
# No new index if the last index attempt succeeded
# Once is enough. The model will never be able to swap otherwise.
if last_index.status == IndexingStatus.SUCCESS:
return False
# No new index if the last index attempt is waiting to start
if last_index.status == IndexingStatus.NOT_STARTED:
return False
# No new index if the last index attempt is running
if last_index.status == IndexingStatus.IN_PROGRESS:
return False
else:
if connector.id == 0: # Ingestion API
return False
if model.status == IndexModelStatus.FUTURE and not last_index:
if connector.id == 0: # Ingestion API
return False
return True
# If the connector is paused or is the ingestion API, don't index
# NOTE: during an embedding model switch over, the following logic
# is bypassed by the above check for a future model
if not cc_pair.status.is_active() or connector.id == 0:
# If the connector is disabled, don't index
# NOTE: during an embedding model switch over, we ignore this
# and index the disabled connectors as well (which is why this if
# statement is below the first condition above)
if connector.disabled:
return False
if not last_index:
return True
if connector.refresh_freq is None:
return False
if not last_index:
return True
# Only one scheduled/ongoing job per connector at a time
# this prevents cases where
# (1) the "latest" index_attempt is scheduled so we show
# that in the UI despite another index_attempt being in-progress
# (2) multiple scheduled index_attempts at a time
if (
last_index.status == IndexingStatus.NOT_STARTED
or last_index.status == IndexingStatus.IN_PROGRESS
):
# Only one scheduled job per connector at a time
# Can schedule another one if the current one is already running however
# Because the currently running one will not be until the latest time
# Note, this last index is for the given embedding model
if last_index.status == IndexingStatus.NOT_STARTED:
return False
current_db_time = get_db_current_time(db_session)
@@ -128,20 +95,41 @@ def _should_create_new_indexing(
return time_since_index.total_seconds() >= connector.refresh_freq
def _is_indexing_job_marked_as_finished(index_attempt: IndexAttempt | None) -> bool:
if index_attempt is None:
return False
return (
index_attempt.status == IndexingStatus.FAILED
or index_attempt.status == IndexingStatus.SUCCESS
)
def _mark_run_failed(
db_session: Session, index_attempt: IndexAttempt, failure_reason: str
) -> None:
"""Marks the `index_attempt` row as failed + updates the `
connector_credential_pair` to reflect that the run failed"""
logger.warning(
f"Marking in-progress attempt 'connector: {index_attempt.connector_credential_pair.connector_id}, "
f"credential: {index_attempt.connector_credential_pair.credential_id}' as failed due to {failure_reason}"
f"Marking in-progress attempt 'connector: {index_attempt.connector_id}, "
f"credential: {index_attempt.credential_id}' as failed due to {failure_reason}"
)
mark_attempt_failed(
index_attempt=index_attempt,
db_session=db_session,
failure_reason=failure_reason,
)
if (
index_attempt.connector_id is not None
and index_attempt.credential_id is not None
and index_attempt.embedding_model.status == IndexModelStatus.PRESENT
):
update_connector_credential_pair(
db_session=db_session,
connector_id=index_attempt.connector_id,
credential_id=index_attempt.credential_id,
attempt_status=IndexingStatus.FAILED,
)
"""Main funcs"""
@@ -154,7 +142,7 @@ def create_indexing_jobs(existing_jobs: dict[int, Future | SimpleJob]) -> None:
3. There is not already an ongoing indexing attempt for this pair
"""
with Session(get_sqlalchemy_engine()) as db_session:
ongoing: set[tuple[int | None, int]] = set()
ongoing: set[tuple[int | None, int | None, int]] = set()
for attempt_id in existing_jobs:
attempt = get_index_attempt(
db_session=db_session, index_attempt_id=attempt_id
@@ -167,43 +155,52 @@ def create_indexing_jobs(existing_jobs: dict[int, Future | SimpleJob]) -> None:
continue
ongoing.add(
(
attempt.connector_credential_pair_id,
attempt.search_settings_id,
attempt.connector_id,
attempt.credential_id,
attempt.embedding_model_id,
)
)
# Get the primary search settings
primary_search_settings = get_current_search_settings(db_session)
search_settings = [primary_search_settings]
embedding_models = [get_current_db_embedding_model(db_session)]
secondary_embedding_model = get_secondary_db_embedding_model(db_session)
if secondary_embedding_model is not None:
embedding_models.append(secondary_embedding_model)
# Check for secondary search settings
secondary_search_settings = get_secondary_search_settings(db_session)
if secondary_search_settings is not None:
# If secondary settings exist, add them to the list
search_settings.append(secondary_search_settings)
all_connectors = fetch_connectors(db_session)
for connector in all_connectors:
for association in connector.credentials:
for model in embedding_models:
credential = association.credential
all_connector_credential_pairs = fetch_connector_credential_pairs(db_session)
for cc_pair in all_connector_credential_pairs:
for search_settings_instance in search_settings:
# Check if there is an ongoing indexing attempt for this connector credential pair
if (cc_pair.id, search_settings_instance.id) in ongoing:
continue
# Check if there is an ongoing indexing attempt for this connector + credential pair
if (connector.id, credential.id, model.id) in ongoing:
continue
last_attempt = get_last_attempt_for_cc_pair(
cc_pair.id, search_settings_instance.id, db_session
)
if not _should_create_new_indexing(
cc_pair=cc_pair,
last_index=last_attempt,
search_settings_instance=search_settings_instance,
secondary_index_building=len(search_settings) > 1,
db_session=db_session,
):
continue
last_attempt = get_last_attempt(
connector.id, credential.id, model.id, db_session
)
if not _should_create_new_indexing(
connector=connector,
last_index=last_attempt,
model=model,
secondary_index_building=len(embedding_models) > 1,
db_session=db_session,
):
continue
create_index_attempt(
cc_pair.id, search_settings_instance.id, db_session
)
create_index_attempt(
connector.id, credential.id, model.id, db_session
)
# CC-Pair will have the status that it should for the primary index
# Will be re-sync-ed once the indices are swapped
if model.status == IndexModelStatus.PRESENT:
update_connector_credential_pair(
db_session=db_session,
connector_id=connector.id,
credential_id=credential.id,
attempt_status=IndexingStatus.NOT_STARTED,
)
def cleanup_indexing_jobs(
@@ -220,12 +217,10 @@ def cleanup_indexing_jobs(
)
# do nothing for ongoing jobs that haven't been stopped
if not job.done():
if not index_attempt:
continue
if not index_attempt.is_finished():
continue
if not job.done() and not _is_indexing_job_marked_as_finished(
index_attempt
):
continue
if job.status == "error":
logger.error(job.exception())
@@ -297,28 +292,24 @@ def kickoff_indexing_jobs(
# Don't include jobs waiting in the Dask queue that just haven't started running
# Also (rarely) don't include for jobs that started but haven't updated the indexing tables yet
with Session(engine) as db_session:
# get_not_started_index_attempts orders its returned results from oldest to newest
# we must process attempts in a FIFO manner to prevent connector starvation
new_indexing_attempts = [
(attempt, attempt.search_settings)
(attempt, attempt.embedding_model)
for attempt in get_not_started_index_attempts(db_session)
if attempt.id not in existing_jobs
]
logger.debug(f"Found {len(new_indexing_attempts)} new indexing task(s).")
logger.info(f"Found {len(new_indexing_attempts)} new indexing tasks.")
if not new_indexing_attempts:
return existing_jobs
indexing_attempt_count = 0
for attempt, search_settings in new_indexing_attempts:
for attempt, embedding_model in new_indexing_attempts:
use_secondary_index = (
search_settings.status == IndexModelStatus.FUTURE
if search_settings is not None
embedding_model.status == IndexModelStatus.FUTURE
if embedding_model is not None
else False
)
if attempt.connector_credential_pair.connector is None:
if attempt.connector is None:
logger.warning(
f"Skipping index attempt as Connector has been deleted: {attempt}"
)
@@ -327,7 +318,7 @@ def kickoff_indexing_jobs(
attempt, db_session, failure_reason="Connector is null"
)
continue
if attempt.connector_credential_pair.credential is None:
if attempt.credential is None:
logger.warning(
f"Skipping index attempt as Credential has been deleted: {attempt}"
)
@@ -339,73 +330,39 @@ def kickoff_indexing_jobs(
if use_secondary_index:
run = secondary_client.submit(
run_indexing_entrypoint,
attempt.id,
attempt.connector_credential_pair_id,
global_version.get_is_ee_version(),
pure=False,
run_indexing_entrypoint, attempt.id, pure=False
)
else:
run = client.submit(
run_indexing_entrypoint,
attempt.id,
attempt.connector_credential_pair_id,
global_version.get_is_ee_version(),
pure=False,
)
run = client.submit(run_indexing_entrypoint, attempt.id, pure=False)
if run:
if indexing_attempt_count == 0:
logger.info(
f"Indexing dispatch starts: pending={len(new_indexing_attempts)}"
)
indexing_attempt_count += 1
secondary_str = " (secondary index)" if use_secondary_index else ""
secondary_str = "(secondary index) " if use_secondary_index else ""
logger.info(
f"Indexing dispatched{secondary_str}: "
f"attempt_id={attempt.id} "
f"connector='{attempt.connector_credential_pair.connector.name}' "
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
f"credentials='{attempt.connector_credential_pair.credential_id}'"
f"Kicked off {secondary_str}"
f"indexing attempt for connector: '{attempt.connector.name}', "
f"with config: '{attempt.connector.connector_specific_config}', and "
f"with credentials: '{attempt.credential_id}'"
)
existing_jobs_copy[attempt.id] = run
if indexing_attempt_count > 0:
logger.info(
f"Indexing dispatch results: "
f"initial_pending={len(new_indexing_attempts)} "
f"started={indexing_attempt_count} "
f"remaining={len(new_indexing_attempts) - indexing_attempt_count}"
)
return existing_jobs_copy
def update_loop(
delay: int = 10,
num_workers: int = NUM_INDEXING_WORKERS,
num_secondary_workers: int = NUM_SECONDARY_INDEXING_WORKERS,
) -> None:
def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> None:
engine = get_sqlalchemy_engine()
with Session(engine) as db_session:
check_index_swap(db_session=db_session)
search_settings = get_current_search_settings(db_session)
db_embedding_model = get_current_db_embedding_model(db_session)
# So that the first time users aren't surprised by really slow speed of first
# batch of documents indexed
if search_settings.provider_type is None:
logger.notice("Running a first inference to warm up embedding model")
embedding_model = EmbeddingModel.from_db_model(
search_settings=search_settings,
server_host=INDEXING_MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
warm_up_bi_encoder(
embedding_model=embedding_model,
)
# So that the first time users aren't surprised by really slow speed of first
# batch of documents indexed
logger.info("Running a first inference to warm up embedding model")
warm_up_encoders(
model_name=db_embedding_model.model_name,
normalize=db_embedding_model.normalize,
model_server_host=INDEXING_MODEL_SERVER_HOST,
model_server_port=MODEL_SERVER_PORT,
)
client_primary: Client | SimpleJobClient
client_secondary: Client | SimpleJobClient
@@ -420,7 +377,7 @@ def update_loop(
silence_logs=logging.ERROR,
)
cluster_secondary = LocalCluster(
n_workers=num_secondary_workers,
n_workers=num_workers,
threads_per_worker=1,
silence_logs=logging.ERROR,
)
@@ -430,18 +387,23 @@ def update_loop(
client_primary.register_worker_plugin(ResourceLogger())
else:
client_primary = SimpleJobClient(n_workers=num_workers)
client_secondary = SimpleJobClient(n_workers=num_secondary_workers)
client_secondary = SimpleJobClient(n_workers=num_workers)
existing_jobs: dict[int, Future | SimpleJob] = {}
with Session(engine) as db_session:
# Previous version did not always clean up cc-pairs well leaving some connectors undeleteable
# This ensures that bad states get cleaned up
mark_all_in_progress_cc_pairs_failed(db_session)
while True:
start = time.time()
start_time_utc = datetime.utcfromtimestamp(start).strftime("%Y-%m-%d %H:%M:%S")
logger.debug(f"Running update, current UTC time: {start_time_utc}")
logger.info(f"Running update, current UTC time: {start_time_utc}")
if existing_jobs:
# TODO: make this debug level once the "no jobs are being scheduled" issue is resolved
logger.debug(
logger.info(
"Found existing indexing jobs: "
f"{[(attempt_id, job.status) for attempt_id, job in existing_jobs.items()]}"
)
@@ -464,10 +426,7 @@ def update_loop(
def update__main() -> None:
set_is_ee_based_on_env_variable()
init_sqlalchemy_engine(POSTGRES_INDEXER_APP_NAME)
logger.notice("Starting indexing service")
logger.info("Starting Indexing Loop")
update_loop()

View File

@@ -1,4 +1,5 @@
import re
from collections.abc import Sequence
from typing import cast
from sqlalchemy.orm import Session
@@ -8,46 +9,53 @@ from danswer.chat.models import LlmDoc
from danswer.db.chat import get_chat_messages_by_session
from danswer.db.models import ChatMessage
from danswer.llm.answering.models import PreviousMessage
from danswer.search.models import InferenceChunk
from danswer.search.models import InferenceSection
from danswer.utils.logger import setup_logger
logger = setup_logger()
def llm_doc_from_inference_section(inference_section: InferenceSection) -> LlmDoc:
def llm_doc_from_inference_section(inf_chunk: InferenceSection) -> LlmDoc:
return LlmDoc(
document_id=inference_section.center_chunk.document_id,
document_id=inf_chunk.document_id,
# This one is using the combined content of all the chunks of the section
# In default settings, this is the same as just the content of base chunk
content=inference_section.combined_content,
blurb=inference_section.center_chunk.blurb,
semantic_identifier=inference_section.center_chunk.semantic_identifier,
source_type=inference_section.center_chunk.source_type,
metadata=inference_section.center_chunk.metadata,
updated_at=inference_section.center_chunk.updated_at,
link=inference_section.center_chunk.source_links[0]
if inference_section.center_chunk.source_links
else None,
source_links=inference_section.center_chunk.source_links,
content=inf_chunk.combined_content,
blurb=inf_chunk.blurb,
semantic_identifier=inf_chunk.semantic_identifier,
source_type=inf_chunk.source_type,
metadata=inf_chunk.metadata,
updated_at=inf_chunk.updated_at,
link=inf_chunk.source_links[0] if inf_chunk.source_links else None,
source_links=inf_chunk.source_links,
)
def map_document_id_order(
chunks: Sequence[InferenceChunk | LlmDoc], one_indexed: bool = True
) -> dict[str, int]:
order_mapping = {}
current = 1 if one_indexed else 0
for chunk in chunks:
if chunk.document_id not in order_mapping:
order_mapping[chunk.document_id] = current
current += 1
return order_mapping
def create_chat_chain(
chat_session_id: int,
db_session: Session,
prefetch_tool_calls: bool = True,
# Optional id at which we finish processing
stop_at_message_id: int | None = None,
) -> tuple[ChatMessage, list[ChatMessage]]:
"""Build the linear chain of messages without including the root message"""
mainline_messages: list[ChatMessage] = []
all_chat_messages = get_chat_messages_by_session(
chat_session_id=chat_session_id,
user_id=None,
db_session=db_session,
skip_permission_check=True,
prefetch_tool_calls=prefetch_tool_calls,
)
id_to_msg = {msg.id: msg for msg in all_chat_messages}
@@ -63,12 +71,7 @@ def create_chat_chain(
current_message: ChatMessage | None = root_message
while current_message is not None:
child_msg = current_message.latest_child_message
# Break if at the end of the chain
# or have reached the `final_id` of the submitted message
if not child_msg or (
stop_at_message_id and current_message.id == stop_at_message_id
):
if not child_msg:
break
current_message = id_to_msg.get(child_msg)

View File

@@ -1,24 +0,0 @@
input_prompts:
- id: -5
prompt: "Elaborate"
content: "Elaborate on the above, give me a more in depth explanation."
active: true
is_public: true
- id: -4
prompt: "Reword"
content: "Help me rewrite the following politely and concisely for professional communication:\n"
active: true
is_public: true
- id: -3
prompt: "Email"
content: "Write a professional email for me including a subject line, signature, etc. Template the parts that need editing with [ ]. The email should cover the following points:\n"
active: true
is_public: true
- id: -2
prompt: "Debug"
content: "Provide step-by-step troubleshooting instructions for the following issue:\n"
active: true
is_public: true

View File

@@ -1,20 +1,18 @@
from typing import cast
import yaml
from sqlalchemy.orm import Session
from danswer.configs.chat_configs import INPUT_PROMPT_YAML
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from danswer.configs.chat_configs import PERSONAS_YAML
from danswer.configs.chat_configs import PROMPTS_YAML
from danswer.db.chat import get_prompt_by_name
from danswer.db.chat import upsert_persona
from danswer.db.chat import upsert_prompt
from danswer.db.document_set import get_or_create_document_set_by_name
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.input_prompt import insert_input_prompt_if_not_exists
from danswer.db.models import DocumentSet as DocumentSetDBModel
from danswer.db.models import Persona
from danswer.db.models import Prompt as PromptDBModel
from danswer.db.models import Tool as ToolDBModel
from danswer.db.persona import get_prompt_by_name
from danswer.db.persona import upsert_persona
from danswer.db.persona import upsert_prompt
from danswer.search.enums import RecencyBiasSetting
@@ -52,7 +50,7 @@ def load_personas_from_yaml(
with Session(get_sqlalchemy_engine()) as db_session:
for persona in all_personas:
doc_set_names = persona["document_sets"]
doc_sets: list[DocumentSetDBModel] = [
doc_sets: list[DocumentSetDBModel] | None = [
get_or_create_document_set_by_name(db_session, name)
for name in doc_set_names
]
@@ -60,51 +58,27 @@ def load_personas_from_yaml(
# Assume if user hasn't set any document sets for the persona, the user may want
# to later attach document sets to the persona manually, therefore, don't overwrite/reset
# the document sets for the persona
doc_set_ids: list[int] | None = None
if doc_sets:
doc_set_ids = [doc_set.id for doc_set in doc_sets]
else:
doc_set_ids = None
if not doc_sets:
doc_sets = None
prompt_ids: list[int] | None = None
prompt_set_names = persona["prompts"]
if prompt_set_names:
prompts: list[PromptDBModel | None] = [
if not prompt_set_names:
prompts: list[PromptDBModel | None] | None = None
else:
prompts = [
get_prompt_by_name(prompt_name, user=None, db_session=db_session)
for prompt_name in prompt_set_names
]
if any([prompt is None for prompt in prompts]):
raise ValueError("Invalid Persona configs, not all prompts exist")
if prompts:
prompt_ids = [prompt.id for prompt in prompts if prompt is not None]
if not prompts:
prompts = None
p_id = persona.get("id")
tool_ids = []
if persona.get("image_generation"):
image_gen_tool = (
db_session.query(ToolDBModel)
.filter(ToolDBModel.name == "ImageGenerationTool")
.first()
)
if image_gen_tool:
tool_ids.append(image_gen_tool.id)
llm_model_provider_override = persona.get("llm_model_provider_override")
llm_model_version_override = persona.get("llm_model_version_override")
# Set specific overrides for image generation persona
if persona.get("image_generation"):
llm_model_version_override = "gpt-4o"
existing_persona = (
db_session.query(Persona)
.filter(Persona.name == persona["name"])
.first()
)
upsert_persona(
user=None,
# Negative to not conflict with existing personas
persona_id=(-1 * p_id) if p_id is not None else None,
name=persona["name"],
description=persona["description"],
@@ -114,52 +88,20 @@ def load_personas_from_yaml(
llm_relevance_filter=persona.get("llm_relevance_filter"),
starter_messages=persona.get("starter_messages"),
llm_filter_extraction=persona.get("llm_filter_extraction"),
icon_shape=persona.get("icon_shape"),
icon_color=persona.get("icon_color"),
llm_model_provider_override=llm_model_provider_override,
llm_model_version_override=llm_model_version_override,
llm_model_provider_override=None,
llm_model_version_override=None,
recency_bias=RecencyBiasSetting(persona["recency_bias"]),
prompt_ids=prompt_ids,
document_set_ids=doc_set_ids,
tool_ids=tool_ids,
prompts=cast(list[PromptDBModel] | None, prompts),
document_sets=doc_sets,
default_persona=True,
is_public=True,
display_priority=existing_persona.display_priority
if existing_persona is not None
else persona.get("display_priority"),
is_visible=existing_persona.is_visible
if existing_persona is not None
else persona.get("is_visible"),
db_session=db_session,
)
def load_input_prompts_from_yaml(input_prompts_yaml: str = INPUT_PROMPT_YAML) -> None:
with open(input_prompts_yaml, "r") as file:
data = yaml.safe_load(file)
all_input_prompts = data.get("input_prompts", [])
with Session(get_sqlalchemy_engine()) as db_session:
for input_prompt in all_input_prompts:
# If these prompts are deleted (which is a hard delete in the DB), on server startup
# they will be recreated, but the user can always just deactivate them, just a light inconvenience
insert_input_prompt_if_not_exists(
user=None,
input_prompt_id=input_prompt.get("id"),
prompt=input_prompt["prompt"],
content=input_prompt["content"],
is_public=input_prompt["is_public"],
active=input_prompt.get("active", True),
db_session=db_session,
commit=True,
)
def load_chat_yamls(
prompt_yaml: str = PROMPTS_YAML,
personas_yaml: str = PERSONAS_YAML,
input_prompts_yaml: str = INPUT_PROMPT_YAML,
) -> None:
load_prompts_from_yaml(prompt_yaml)
load_personas_from_yaml(personas_yaml)
load_input_prompts_from_yaml(input_prompts_yaml)

View File

@@ -1,6 +1,5 @@
from collections.abc import Iterator
from datetime import datetime
from enum import Enum
from typing import Any
from pydantic import BaseModel
@@ -10,7 +9,6 @@ from danswer.search.enums import QueryFlow
from danswer.search.enums import SearchType
from danswer.search.models import RetrievalDocs
from danswer.search.models import SearchResponse
from danswer.tools.custom.base_tool_types import ToolResultType
class LlmDoc(BaseModel):
@@ -36,51 +34,19 @@ class QADocsResponse(RetrievalDocs):
applied_time_cutoff: datetime | None
recency_bias_multiplier: float
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
initial_dict = super().model_dump(mode="json", *args, **kwargs) # type: ignore
def dict(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
initial_dict = super().dict(*args, **kwargs) # type: ignore
initial_dict["applied_time_cutoff"] = (
self.applied_time_cutoff.isoformat() if self.applied_time_cutoff else None
)
return initial_dict
class StreamStopReason(Enum):
CONTEXT_LENGTH = "context_length"
CANCELLED = "cancelled"
class StreamStopInfo(BaseModel):
stop_reason: StreamStopReason
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
data = super().model_dump(mode="json", *args, **kwargs) # type: ignore
data["stop_reason"] = self.stop_reason.name
return data
# Second chunk of info for streaming QA
class LLMRelevanceFilterResponse(BaseModel):
relevant_chunk_indices: list[int]
class RelevanceAnalysis(BaseModel):
relevant: bool
content: str | None = None
class SectionRelevancePiece(RelevanceAnalysis):
"""LLM analysis mapped to an Inference Section"""
document_id: str
chunk_id: int # ID of the center chunk for a given inference section
class DocumentRelevance(BaseModel):
"""Contains all relevance information for a given search"""
relevance_summaries: dict[str, RelevanceAnalysis]
class DanswerAnswerPiece(BaseModel):
# A small piece of a complete answer. Used for streaming back answers.
answer_piece: str | None # if None, specifies the end of an Answer
@@ -93,14 +59,8 @@ class CitationInfo(BaseModel):
document_id: str
class MessageResponseIDInfo(BaseModel):
user_message_id: int | None
reserved_assistant_message_id: int
class StreamingError(BaseModel):
error: str
stack_trace: str | None = None
class DanswerQuote(BaseModel):
@@ -146,20 +106,13 @@ class ImageGenerationDisplay(BaseModel):
file_ids: list[str]
class CustomToolResponse(BaseModel):
response: ToolResultType
tool_name: str
AnswerQuestionPossibleReturn = (
DanswerAnswerPiece
| DanswerQuotes
| CitationInfo
| DanswerContexts
| ImageGenerationDisplay
| CustomToolResponse
| StreamingError
| StreamStopInfo
)

View File

@@ -5,7 +5,7 @@ personas:
# this is for DanswerBot to use when tagged in a non-configured channel
# Careful setting specific IDs, this won't autoincrement the next ID value for postgres
- id: 0
name: "Knowledge"
name: "Danswer"
description: >
Assistant with access to documents from your Connected Sources.
# Default Prompt objects attached to the persona, see prompts.yaml
@@ -17,7 +17,7 @@ personas:
num_chunks: 10
# Enable/Disable usage of the LLM chunk filter feature whereby each chunk is passed to the LLM to determine
# if the chunk is useful or not towards the latest user query
# This feature can be overriden for all personas via DISABLE_LLM_DOC_RELEVANCE env variable
# This feature can be overriden for all personas via DISABLE_LLM_CHUNK_FILTER env variable
llm_relevance_filter: true
# Enable/Disable usage of the LLM to extract query time filters including source type and time range filters
llm_filter_extraction: true
@@ -37,15 +37,12 @@ personas:
# - "Engineer Onboarding"
# - "Benefits"
document_sets: []
icon_shape: 23013
icon_color: "#6FB1FF"
display_priority: 1
is_visible: true
- id: 1
name: "General"
name: "GPT"
description: >
Assistant with no access to documents. Chat with just the Large Language Model.
Assistant with no access to documents. Chat with just the Language Model.
prompts:
- "OnlyLLM"
num_chunks: 0
@@ -53,10 +50,7 @@ personas:
llm_filter_extraction: true
recency_bias: "auto"
document_sets: []
icon_shape: 50910
icon_color: "#FF6F6F"
display_priority: 0
is_visible: true
- id: 2
name: "Paraphrase"
@@ -69,25 +63,3 @@ personas:
llm_filter_extraction: true
recency_bias: "auto"
document_sets: []
icon_shape: 45519
icon_color: "#6FFF8D"
display_priority: 2
is_visible: false
- id: 3
name: "Art"
description: >
Assistant for generating images based on descriptions.
prompts:
- "ImageGeneration"
num_chunks: 0
llm_relevance_filter: false
llm_filter_extraction: false
recency_bias: "no_decay"
document_sets: []
icon_shape: 234124
icon_color: "#9B59B6"
image_generation: true
display_priority: 3
is_visible: true

View File

@@ -1,4 +1,3 @@
import traceback
from collections.abc import Callable
from collections.abc import Iterator
from functools import partial
@@ -8,20 +7,15 @@ from sqlalchemy.orm import Session
from danswer.chat.chat_utils import create_chat_chain
from danswer.chat.models import CitationInfo
from danswer.chat.models import CustomToolResponse
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import ImageGenerationDisplay
from danswer.chat.models import LlmDoc
from danswer.chat.models import LLMRelevanceFilterResponse
from danswer.chat.models import MessageResponseIDInfo
from danswer.chat.models import QADocsResponse
from danswer.chat.models import StreamingError
from danswer.configs.chat_configs import BING_API_KEY
from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from danswer.configs.constants import MessageType
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.db.chat import attach_files_to_chat_message
from danswer.db.chat import create_db_search_doc
from danswer.db.chat import create_new_chat_message
from danswer.db.chat import get_chat_message
@@ -29,16 +23,13 @@ from danswer.db.chat import get_chat_session_by_id
from danswer.db.chat import get_db_search_doc_by_id
from danswer.db.chat import get_doc_query_identifiers_from_model
from danswer.db.chat import get_or_create_root_message
from danswer.db.chat import reserve_message_id
from danswer.db.chat import translate_db_message_to_chat_message_detail
from danswer.db.chat import translate_db_search_doc_to_server_search_doc
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.engine import get_session_context_manager
from danswer.db.llm import fetch_existing_llm_providers
from danswer.db.models import SearchDoc as DbSearchDoc
from danswer.db.models import ToolCall
from danswer.db.models import User
from danswer.db.persona import get_persona_by_id
from danswer.db.search_settings import get_current_search_settings
from danswer.document_index.factory import get_default_document_index
from danswer.file_store.models import ChatFileType
from danswer.file_store.models import FileDescriptor
@@ -51,47 +42,25 @@ from danswer.llm.answering.models import DocumentPruningConfig
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.answering.models import PromptConfig
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_llms_for_persona
from danswer.llm.factory import get_main_llm_from_tuple
from danswer.llm.interfaces import LLMConfig
from danswer.llm.utils import litellm_exception_to_error_msg
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.search.enums import LLMEvaluationType
from danswer.llm.factory import get_llm_for_persona
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.search.enums import OptionalSearchSetting
from danswer.search.enums import QueryFlow
from danswer.search.enums import SearchType
from danswer.search.models import InferenceSection
from danswer.search.retrieval.search_runner import inference_sections_from_ids
from danswer.search.retrieval.search_runner import inference_documents_from_ids
from danswer.search.utils import chunks_or_sections_to_search_docs
from danswer.search.utils import dedupe_documents
from danswer.search.utils import drop_llm_indices
from danswer.search.utils import relevant_sections_to_indices
from danswer.server.query_and_chat.models import ChatMessageDetail
from danswer.server.query_and_chat.models import CreateChatMessageRequest
from danswer.server.utils import get_json_line
from danswer.tools.built_in_tools import get_built_in_tool_by_id
from danswer.tools.custom.custom_tool import build_custom_tools_from_openapi_schema
from danswer.tools.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID
from danswer.tools.custom.custom_tool import CustomToolCallSummary
from danswer.tools.factory import get_tool_cls
from danswer.tools.force import ForceUseTool
from danswer.tools.images.image_generation_tool import IMAGE_GENERATION_RESPONSE_ID
from danswer.tools.images.image_generation_tool import ImageGenerationResponse
from danswer.tools.images.image_generation_tool import ImageGenerationTool
from danswer.tools.internet_search.internet_search_tool import (
INTERNET_SEARCH_RESPONSE_ID,
)
from danswer.tools.internet_search.internet_search_tool import (
internet_search_response_to_search_docs,
)
from danswer.tools.internet_search.internet_search_tool import InternetSearchResponse
from danswer.tools.internet_search.internet_search_tool import InternetSearchTool
from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID
from danswer.tools.search.search_tool import SearchResponseSummary
from danswer.tools.search.search_tool import SearchTool
from danswer.tools.search.search_tool import SECTION_RELEVANCE_LIST_ID
from danswer.tools.tool import Tool
from danswer.tools.tool import ToolResponse
from danswer.tools.tool_runner import ToolCallFinalResult
from danswer.tools.utils import compute_all_tool_tokens
from danswer.tools.utils import explicit_tool_calling_supported
from danswer.utils.logger import setup_logger
@@ -124,21 +93,14 @@ def _handle_search_tool_response_summary(
packet: ToolResponse,
db_session: Session,
selected_search_docs: list[DbSearchDoc] | None,
dedupe_docs: bool = False,
) -> tuple[QADocsResponse, list[DbSearchDoc], list[int] | None]:
) -> tuple[QADocsResponse, list[DbSearchDoc]]:
response_sumary = cast(SearchResponseSummary, packet.response)
dropped_inds = None
if not selected_search_docs:
top_docs = chunks_or_sections_to_search_docs(response_sumary.top_sections)
deduped_docs = top_docs
if dedupe_docs:
deduped_docs, dropped_inds = dedupe_documents(top_docs)
reference_db_search_docs = [
create_db_search_doc(server_search_doc=doc, db_session=db_session)
for doc in deduped_docs
create_db_search_doc(server_search_doc=top_doc, db_session=db_session)
for top_doc in top_docs
]
else:
reference_db_search_docs = selected_search_docs
@@ -158,81 +120,35 @@ def _handle_search_tool_response_summary(
recency_bias_multiplier=response_sumary.recency_bias_multiplier,
),
reference_db_search_docs,
dropped_inds,
)
def _handle_internet_search_tool_response_summary(
packet: ToolResponse,
db_session: Session,
) -> tuple[QADocsResponse, list[DbSearchDoc]]:
internet_search_response = cast(InternetSearchResponse, packet.response)
server_search_docs = internet_search_response_to_search_docs(
internet_search_response
)
reference_db_search_docs = [
create_db_search_doc(server_search_doc=doc, db_session=db_session)
for doc in server_search_docs
]
response_docs = [
translate_db_search_doc_to_server_search_doc(db_search_doc)
for db_search_doc in reference_db_search_docs
]
return (
QADocsResponse(
rephrased_query=internet_search_response.revised_query,
top_documents=response_docs,
predicted_flow=QueryFlow.QUESTION_ANSWER,
predicted_search=SearchType.SEMANTIC,
applied_source_filters=[],
applied_time_cutoff=None,
recency_bias_multiplier=1.0,
),
reference_db_search_docs,
)
def _get_force_search_settings(
new_msg_req: CreateChatMessageRequest, tools: list[Tool]
) -> ForceUseTool:
internet_search_available = any(
isinstance(tool, InternetSearchTool) for tool in tools
)
search_tool_available = any(isinstance(tool, SearchTool) for tool in tools)
if not internet_search_available and not search_tool_available:
# Does not matter much which tool is set here as force is false and neither tool is available
return ForceUseTool(force_use=False, tool_name=SearchTool._NAME)
tool_name = SearchTool._NAME if search_tool_available else InternetSearchTool._NAME
# Currently, the internet search tool does not support query override
args = (
{"query": new_msg_req.query_override}
if new_msg_req.query_override and tool_name == SearchTool._NAME
else None
)
if new_msg_req.file_descriptors:
# If user has uploaded files they're using, don't run any of the search tools
return ForceUseTool(force_use=False, tool_name=tool_name)
should_force_search = any(
[
def _check_should_force_search(
new_msg_req: CreateChatMessageRequest,
) -> ForceUseTool | None:
if (
new_msg_req.query_override
or (
new_msg_req.retrieval_options
and new_msg_req.retrieval_options.run_search
== OptionalSearchSetting.ALWAYS,
new_msg_req.search_doc_ids,
DISABLE_LLM_CHOOSE_SEARCH,
]
)
and new_msg_req.retrieval_options.run_search == OptionalSearchSetting.ALWAYS
)
or new_msg_req.search_doc_ids
):
args = (
{"query": new_msg_req.query_override}
if new_msg_req.query_override
else None
)
# if we are using selected docs, just put something here so the Tool doesn't need
# to build its own args via an LLM call
if new_msg_req.search_doc_ids:
args = {"query": new_msg_req.message}
if should_force_search:
# If we are using selected docs, just put something here so the Tool doesn't need to build its own args via an LLM call
args = {"query": new_msg_req.message} if new_msg_req.search_doc_ids else args
return ForceUseTool(force_use=True, tool_name=tool_name, args=args)
return ForceUseTool(force_use=False, tool_name=tool_name, args=args)
return ForceUseTool(
tool_name=SearchTool.name(),
args=args,
)
return None
ChatPacket = (
@@ -243,8 +159,6 @@ ChatPacket = (
| DanswerAnswerPiece
| CitationInfo
| ImageGenerationDisplay
| CustomToolResponse
| MessageResponseIDInfo
)
ChatPacketStream = Iterator[ChatPacket]
@@ -260,21 +174,16 @@ def stream_chat_message_objects(
max_document_percentage: float = CHAT_TARGET_CHUNK_PERCENTAGE,
# if specified, uses the last user message and does not create a new user message based
# on the `new_msg_req.message`. Currently, requires a state where the last message is a
# user message (e.g. this can only be used for the chat-seeding flow).
use_existing_user_message: bool = False,
litellm_additional_headers: dict[str, str] | None = None,
is_connected: Callable[[], bool] | None = None,
) -> ChatPacketStream:
"""Streams in order:
1. [conditional] Retrieved documents if a search needs to be run
2. [conditional] LLM selected chunk indices if LLM chunk filtering is turned on
3. [always] A set of streamed LLM tokens or an error anywhere along the line if something fails
4. [always] Details on the final AI response message that is created
"""
# Currently surrounding context is not supported for chat
# Chat is already token heavy and harder for the model to process plus it would roll history over much faster
new_msg_req.chunks_above = 0
new_msg_req.chunks_below = 0
"""
try:
user_id = user.id if user is not None else None
@@ -289,18 +198,7 @@ def stream_chat_message_objects(
parent_id = new_msg_req.parent_message_id
reference_doc_ids = new_msg_req.search_doc_ids
retrieval_options = new_msg_req.retrieval_options
alternate_assistant_id = new_msg_req.alternate_assistant_id
# use alternate persona if alternative assistant id is passed in
if alternate_assistant_id is not None:
persona = get_persona_by_id(
alternate_assistant_id,
user=user,
db_session=db_session,
is_for_edit=False,
)
else:
persona = chat_session.persona
persona = chat_session.persona
prompt_id = new_msg_req.prompt_id
if prompt_id is None and persona.prompts:
@@ -312,28 +210,20 @@ def stream_chat_message_objects(
)
try:
llm, fast_llm = get_llms_for_persona(
persona=persona,
llm_override=new_msg_req.llm_override or chat_session.llm_override,
additional_headers=litellm_additional_headers,
llm = get_llm_for_persona(
persona, new_msg_req.llm_override or chat_session.llm_override
)
except GenAIDisabledException:
raise RuntimeError("LLM is disabled. Can't use chat flow without LLM.")
llm_provider = llm.config.model_provider
llm_model_name = llm.config.model_name
llm_tokenizer = get_tokenizer(
model_name=llm_model_name,
provider_type=llm_provider,
)
llm_tokenizer = get_default_llm_tokenizer()
llm_tokenizer_encode_func = cast(
Callable[[str], list[int]], llm_tokenizer.encode
)
search_settings = get_current_search_settings(db_session)
embedding_model = get_current_db_embedding_model(db_session)
document_index = get_default_document_index(
primary_index_name=search_settings.index_name, secondary_index_name=None
primary_index_name=embedding_model.index_name, secondary_index_name=None
)
# Every chat Session begins with an empty root message
@@ -350,16 +240,7 @@ def stream_chat_message_objects(
else:
parent_message = root_message
user_message = None
if new_msg_req.regenerate:
final_msg, history_msgs = create_chat_chain(
stop_at_message_id=parent_id,
chat_session_id=chat_session_id,
db_session=db_session,
)
elif not use_existing_user_message:
if not use_existing_user_message:
# Create new message at the right place in the tree and update the parent's child pointer
# Don't commit yet until we verify the chat message chain
user_message = create_new_chat_message(
@@ -369,7 +250,10 @@ def stream_chat_message_objects(
message=message_text,
token_count=len(llm_tokenizer_encode_func(message_text)),
message_type=MessageType.USER,
files=None, # Need to attach later for optimization to only load files once in parallel
files=[
{"id": str(file_id), "type": ChatFileType.IMAGE}
for file_id in new_msg_req.file_ids
],
db_session=db_session,
commit=False,
)
@@ -384,8 +268,8 @@ def stream_chat_message_objects(
"Be sure to update the chat pointers before calling this."
)
# NOTE: do not commit user message - it will be committed when the
# assistant message is successfully generated
# Save now to save the latest chat message
db_session.commit()
else:
# re-create linear history of messages
final_msg, history_msgs = create_chat_chain(
@@ -398,36 +282,14 @@ def stream_chat_message_objects(
"when the last message is not a user message."
)
# Disable Query Rephrasing for the first message
# This leads to a better first response since the LLM rephrasing the question
# leads to worst search quality
if not history_msgs:
new_msg_req.query_override = (
new_msg_req.query_override or new_msg_req.message
)
# load all files needed for this chat chain in memory
files = load_all_chat_files(
history_msgs, new_msg_req.file_descriptors, db_session
)
files = load_all_chat_files(history_msgs, new_msg_req.file_ids, db_session)
latest_query_files = [
file
for file in files
if file.file_id in [f["id"] for f in new_msg_req.file_descriptors]
file for file in files if file.file_id in new_msg_req.file_ids
]
if user_message:
attach_files_to_chat_message(
chat_message=user_message,
files=[
new_file.to_file_descriptor() for new_file in latest_query_files
],
db_session=db_session,
commit=False,
)
selected_db_search_docs = None
selected_sections: list[InferenceSection] | None = None
selected_llm_docs: list[LlmDoc] | None = None
if reference_doc_ids:
identifier_tuples = get_doc_query_identifiers_from_model(
search_doc_ids=reference_doc_ids,
@@ -437,8 +299,8 @@ def stream_chat_message_objects(
)
# Generates full documents currently
# May extend to use sections instead in the future
selected_sections = inference_sections_from_ids(
# May extend to include chunk ranges
selected_llm_docs = inference_documents_from_ids(
doc_identifiers=identifier_tuples,
document_index=document_index,
)
@@ -465,23 +327,9 @@ def stream_chat_message_objects(
else default_num_chunks
),
max_window_percentage=max_document_percentage,
use_sections=new_msg_req.chunks_above > 0
or new_msg_req.chunks_below > 0,
)
reserved_message_id = reserve_message_id(
db_session=db_session,
chat_session_id=chat_session_id,
parent_message=user_message.id
if user_message is not None
else parent_message.id,
message_type=MessageType.ASSISTANT,
)
yield MessageResponseIDInfo(
user_message_id=user_message.id if user_message else None,
reserved_assistant_message_id=reserved_message_id,
)
overridden_model = (
new_msg_req.llm_override.model_version if new_msg_req.llm_override else None
)
# Cannot determine these without the LLM step or breaking out early
partial_response = partial(
@@ -489,134 +337,83 @@ def stream_chat_message_objects(
chat_session_id=chat_session_id,
parent_message=final_msg,
prompt_id=prompt_id,
overridden_model=overridden_model,
# message=,
# rephrased_query=,
# token_count=,
message_type=MessageType.ASSISTANT,
alternate_assistant_id=new_msg_req.alternate_assistant_id,
# error=,
# reference_docs=,
db_session=db_session,
commit=False,
commit=True,
)
if not final_msg.prompt:
raise RuntimeError("No Prompt found")
prompt_config = (
PromptConfig.from_model(
final_msg.prompt,
prompt_override=(
new_msg_req.prompt_override or chat_session.prompt_override
),
)
if not persona
else PromptConfig.from_model(persona.prompts[0])
prompt_config = PromptConfig.from_model(
final_msg.prompt,
prompt_override=(
new_msg_req.prompt_override or chat_session.prompt_override
),
)
# find out what tools to use
search_tool: SearchTool | None = None
tool_dict: dict[int, list[Tool]] = {} # tool_id to tool
for db_tool_model in persona.tools:
# handle in-code tools specially
if db_tool_model.in_code_tool_id:
tool_cls = get_built_in_tool_by_id(db_tool_model.id, db_session)
if tool_cls.__name__ == SearchTool.__name__ and not latest_query_files:
search_tool = SearchTool(
db_session=db_session,
user=user,
persona=persona,
retrieval_options=retrieval_options,
prompt_config=prompt_config,
llm=llm,
fast_llm=fast_llm,
pruning_config=document_pruning_config,
selected_sections=selected_sections,
chunks_above=new_msg_req.chunks_above,
chunks_below=new_msg_req.chunks_below,
full_doc=new_msg_req.full_doc,
evaluation_type=LLMEvaluationType.BASIC
if persona.llm_relevance_filter
else LLMEvaluationType.SKIP,
)
tool_dict[db_tool_model.id] = [search_tool]
elif tool_cls.__name__ == ImageGenerationTool.__name__:
img_generation_llm_config: LLMConfig | None = None
if (
llm
and llm.config.api_key
and llm.config.model_provider == "openai"
):
img_generation_llm_config = llm.config
else:
llm_providers = fetch_existing_llm_providers(db_session)
openai_provider = next(
iter(
[
llm_provider
for llm_provider in llm_providers
if llm_provider.provider == "openai"
]
),
None,
)
if not openai_provider or not openai_provider.api_key:
raise ValueError(
"Image generation tool requires an OpenAI API key"
)
img_generation_llm_config = LLMConfig(
model_provider=openai_provider.provider,
model_name=openai_provider.default_model_name,
temperature=GEN_AI_TEMPERATURE,
api_key=openai_provider.api_key,
api_base=openai_provider.api_base,
api_version=openai_provider.api_version,
)
tool_dict[db_tool_model.id] = [
ImageGenerationTool(
api_key=cast(str, img_generation_llm_config.api_key),
api_base=img_generation_llm_config.api_base,
api_version=img_generation_llm_config.api_version,
additional_headers=litellm_additional_headers,
)
]
elif tool_cls.__name__ == InternetSearchTool.__name__:
bing_api_key = BING_API_KEY
if not bing_api_key:
raise ValueError(
"Internet search tool requires a Bing API key, please contact your Danswer admin to get it added!"
)
tool_dict[db_tool_model.id] = [
InternetSearchTool(api_key=bing_api_key)
]
continue
# handle all custom tools
if db_tool_model.openapi_schema:
tool_dict[db_tool_model.id] = cast(
list[Tool],
build_custom_tools_from_openapi_schema(
db_tool_model.openapi_schema
),
)
tools: list[Tool] = []
for tool_list in tool_dict.values():
tools.extend(tool_list)
persona_tool_classes = [
get_tool_cls(tool, db_session) for tool in persona.tools
]
# factor in tool definition size when pruning
document_pruning_config.tool_num_tokens = compute_all_tool_tokens(
tools, llm_tokenizer
persona_tool_classes
)
document_pruning_config.using_tool_message = explicit_tool_calling_supported(
llm_provider, llm_model_name
llm.config.model_provider, llm.config.model_name
)
# NOTE: for now, only support SearchTool and ImageGenerationTool
# in the future, will support arbitrary user-defined tools
search_tool: SearchTool | None = None
tools: list[Tool] = []
for tool_cls in persona_tool_classes:
if tool_cls.__name__ == SearchTool.__name__:
search_tool = SearchTool(
db_session=db_session,
user=user,
persona=persona,
retrieval_options=retrieval_options,
prompt_config=prompt_config,
llm_config=llm.config,
pruning_config=document_pruning_config,
selected_docs=selected_llm_docs,
chunks_above=new_msg_req.chunks_above,
chunks_below=new_msg_req.chunks_below,
full_doc=new_msg_req.full_doc,
)
tools.append(search_tool)
elif tool_cls.__name__ == ImageGenerationTool.__name__:
dalle_key = None
if llm and llm.config.api_key and llm.config.model_provider == "openai":
dalle_key = llm.config.api_key
else:
llm_providers = fetch_existing_llm_providers(db_session)
openai_provider = next(
iter(
[
llm_provider
for llm_provider in llm_providers
if llm_provider.provider == "openai"
]
),
None,
)
if not openai_provider or not openai_provider.api_key:
raise ValueError(
"Image generation tool requires an OpenAI API key"
)
dalle_key = openai_provider.api_key
tools.append(ImageGenerationTool(api_key=dalle_key))
# LLM prompt building, response capturing, etc.
answer = Answer(
is_connected=is_connected,
question=final_msg.message,
latest_query_files=latest_query_files,
answer_style_config=AnswerStyleConfig(
@@ -628,69 +425,34 @@ def stream_chat_message_objects(
prompt_config=prompt_config,
llm=(
llm
or get_main_llm_from_tuple(
get_llms_for_persona(
persona=persona,
llm_override=(
new_msg_req.llm_override or chat_session.llm_override
),
additional_headers=litellm_additional_headers,
)
or get_llm_for_persona(
persona, new_msg_req.llm_override or chat_session.llm_override
)
),
message_history=[
PreviousMessage.from_chat_message(msg, files) for msg in history_msgs
],
tools=tools,
force_use_tool=_get_force_search_settings(new_msg_req, tools),
force_use_tool=_check_should_force_search(new_msg_req),
)
reference_db_search_docs = None
qa_docs_response = None
ai_message_files = None # any files to associate with the AI message e.g. dall-e generated images
dropped_indices = None
tool_result = None
for packet in answer.processed_streamed_output:
if isinstance(packet, ToolResponse):
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
(
qa_docs_response,
reference_db_search_docs,
dropped_indices,
) = _handle_search_tool_response_summary(
packet=packet,
db_session=db_session,
selected_search_docs=selected_db_search_docs,
# Deduping happens at the last step to avoid harming quality by dropping content early on
dedupe_docs=retrieval_options.dedupe_docs
if retrieval_options
else False,
packet, db_session, selected_db_search_docs
)
yield qa_docs_response
elif packet.id == SECTION_RELEVANCE_LIST_ID:
relevance_sections = packet.response
if reference_db_search_docs is not None:
llm_indices = relevant_sections_to_indices(
relevance_sections=relevance_sections,
items=[
translate_db_search_doc_to_server_search_doc(doc)
for doc in reference_db_search_docs
],
)
if dropped_indices:
llm_indices = drop_llm_indices(
llm_indices=llm_indices,
search_docs=reference_db_search_docs,
dropped_indices=dropped_indices,
)
yield LLMRelevanceFilterResponse(
relevant_chunk_indices=llm_indices
)
yield LLMRelevanceFilterResponse(
relevant_chunk_indices=packet.response
)
elif packet.id == IMAGE_GENERATION_RESPONSE_ID:
img_generation_response = cast(
list[ImageGenerationResponse], packet.response
@@ -706,39 +468,16 @@ def stream_chat_message_objects(
yield ImageGenerationDisplay(
file_ids=[str(file_id) for file_id in file_ids]
)
elif packet.id == INTERNET_SEARCH_RESPONSE_ID:
(
qa_docs_response,
reference_db_search_docs,
) = _handle_internet_search_tool_response_summary(
packet=packet,
db_session=db_session,
)
yield qa_docs_response
elif packet.id == CUSTOM_TOOL_RESPONSE_ID:
custom_tool_response = cast(CustomToolCallSummary, packet.response)
yield CustomToolResponse(
response=custom_tool_response.tool_result,
tool_name=custom_tool_response.tool_name,
)
else:
if isinstance(packet, ToolCallFinalResult):
tool_result = packet
yield cast(ChatPacket, packet)
logger.debug("Reached end of stream")
except Exception as e:
error_msg = str(e)
logger.exception(f"Failed to process chat message: {error_msg}")
logger.exception(e)
stack_trace = traceback.format_exc()
client_error_msg = litellm_exception_to_error_msg(e, llm)
if llm.config.api_key and len(llm.config.api_key) > 2:
error_msg = error_msg.replace(llm.config.api_key, "[REDACTED_API_KEY]")
stack_trace = stack_trace.replace(llm.config.api_key, "[REDACTED_API_KEY]")
yield StreamingError(error=client_error_msg, stack_trace=stack_trace)
db_session.rollback()
# Frontend will erase whatever answer and show this instead
# This will be the issue 99% of the time
yield StreamingError(error="LLM failed to respond, have you set your API key?")
return
# Post-LLM answer processing
@@ -751,13 +490,7 @@ def stream_chat_message_objects(
)
# Saving Gen AI answer and responding with message info
tool_name_to_tool_id: dict[str, int] = {}
for tool_id, tool_list in tool_dict.items():
for tool in tool_list:
tool_name_to_tool_id[tool.name] = tool_id
gen_ai_response_message = partial_response(
reserved_message_id=reserved_message_id,
message=answer.llm_answer,
rephrased_query=(
qa_docs_response.rephrased_query if qa_docs_response else None
@@ -767,29 +500,15 @@ def stream_chat_message_objects(
token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
citations=db_citations,
error=None,
tool_calls=[
ToolCall(
tool_id=tool_name_to_tool_id[tool_result.tool_name],
tool_name=tool_result.tool_name,
tool_arguments=tool_result.tool_args,
tool_result=tool_result.tool_result,
)
]
if tool_result
else [],
)
logger.debug("Committing messages")
db_session.commit() # actually save user / assistant message
msg_detail_response = translate_db_message_to_chat_message_detail(
gen_ai_response_message
)
yield msg_detail_response
except Exception as e:
error_msg = str(e)
logger.exception(error_msg)
logger.exception(e)
# Frontend will erase whatever answer and show this instead
yield StreamingError(error="Failed to parse LLM output")
@@ -800,8 +519,6 @@ def stream_chat_message(
new_msg_req: CreateChatMessageRequest,
user: User | None,
use_existing_user_message: bool = False,
litellm_additional_headers: dict[str, str] | None = None,
is_connected: Callable[[], bool] | None = None,
) -> Iterator[str]:
with get_session_context_manager() as db_session:
objects = stream_chat_message_objects(
@@ -809,8 +526,6 @@ def stream_chat_message(
user=user,
db_session=db_session,
use_existing_user_message=use_existing_user_message,
litellm_additional_headers=litellm_additional_headers,
is_connected=is_connected,
)
for obj in objects:
yield get_json_line(obj.model_dump())
yield get_json_line(obj.dict())

Some files were not shown because too many files have changed in this diff Show More