mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-17 07:45:47 +00:00
Compare commits
1 Commits
redirect-a
...
eval/split
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4293543a6a |
76
.github/actions/custom-build-and-push/action.yml
vendored
76
.github/actions/custom-build-and-push/action.yml
vendored
@@ -1,76 +0,0 @@
|
||||
name: 'Build and Push Docker Image with Retry'
|
||||
description: 'Attempts to build and push a Docker image, with a retry on failure'
|
||||
inputs:
|
||||
context:
|
||||
description: 'Build context'
|
||||
required: true
|
||||
file:
|
||||
description: 'Dockerfile location'
|
||||
required: true
|
||||
platforms:
|
||||
description: 'Target platforms'
|
||||
required: true
|
||||
pull:
|
||||
description: 'Always attempt to pull a newer version of the image'
|
||||
required: false
|
||||
default: 'true'
|
||||
push:
|
||||
description: 'Push the image to registry'
|
||||
required: false
|
||||
default: 'true'
|
||||
load:
|
||||
description: 'Load the image into Docker daemon'
|
||||
required: false
|
||||
default: 'true'
|
||||
tags:
|
||||
description: 'Image tags'
|
||||
required: true
|
||||
cache-from:
|
||||
description: 'Cache sources'
|
||||
required: false
|
||||
cache-to:
|
||||
description: 'Cache destinations'
|
||||
required: false
|
||||
retry-wait-time:
|
||||
description: 'Time to wait before retry in seconds'
|
||||
required: false
|
||||
default: '5'
|
||||
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Build and push Docker image (First Attempt)
|
||||
id: buildx1
|
||||
uses: docker/build-push-action@v5
|
||||
continue-on-error: true
|
||||
with:
|
||||
context: ${{ inputs.context }}
|
||||
file: ${{ inputs.file }}
|
||||
platforms: ${{ inputs.platforms }}
|
||||
pull: ${{ inputs.pull }}
|
||||
push: ${{ inputs.push }}
|
||||
load: ${{ inputs.load }}
|
||||
tags: ${{ inputs.tags }}
|
||||
cache-from: ${{ inputs.cache-from }}
|
||||
cache-to: ${{ inputs.cache-to }}
|
||||
|
||||
- name: Wait to retry
|
||||
if: steps.buildx1.outcome != 'success'
|
||||
run: |
|
||||
echo "First attempt failed. Waiting ${{ inputs.retry-wait-time }} seconds before retry..."
|
||||
sleep ${{ inputs.retry-wait-time }}
|
||||
shell: bash
|
||||
|
||||
- name: Build and push Docker image (Retry Attempt)
|
||||
if: steps.buildx1.outcome != 'success'
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: ${{ inputs.context }}
|
||||
file: ${{ inputs.file }}
|
||||
platforms: ${{ inputs.platforms }}
|
||||
pull: ${{ inputs.pull }}
|
||||
push: ${{ inputs.push }}
|
||||
load: ${{ inputs.load }}
|
||||
tags: ${{ inputs.tags }}
|
||||
cache-from: ${{ inputs.cache-from }}
|
||||
cache-to: ${{ inputs.cache-to }}
|
||||
25
.github/pull_request_template.md
vendored
25
.github/pull_request_template.md
vendored
@@ -1,25 +0,0 @@
|
||||
## Description
|
||||
[Provide a brief description of the changes in this PR]
|
||||
|
||||
|
||||
## How Has This Been Tested?
|
||||
[Describe the tests you ran to verify your changes]
|
||||
|
||||
|
||||
## Accepted Risk
|
||||
[Any know risks or failure modes to point out to reviewers]
|
||||
|
||||
|
||||
## Related Issue(s)
|
||||
[If applicable, link to the issue(s) this PR addresses]
|
||||
|
||||
|
||||
## Checklist:
|
||||
- [ ] All of the automated tests pass
|
||||
- [ ] All PR comments are addressed and marked resolved
|
||||
- [ ] If there are migrations, they have been rebased to latest main
|
||||
- [ ] If there are new dependencies, they are added to the requirements
|
||||
- [ ] If there are new environment variables, they are added to all of the deployment methods
|
||||
- [ ] If there are new APIs that don't require auth, they are added to PUBLIC_ENDPOINT_SPECS
|
||||
- [ ] Docker images build and basic functionalities work
|
||||
- [ ] Author has done a final read through of the PR right before merge
|
||||
33
.github/workflows/docker-build-backend-container-on-merge-group.yml
vendored
Normal file
33
.github/workflows/docker-build-backend-container-on-merge-group.yml
vendored
Normal file
@@ -0,0 +1,33 @@
|
||||
name: Build Backend Image on Merge Group
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
types: [checks_requested]
|
||||
|
||||
env:
|
||||
REGISTRY_IMAGE: danswer/danswer-backend
|
||||
|
||||
jobs:
|
||||
build:
|
||||
# TODO: make this a matrix build like the web containers
|
||||
runs-on:
|
||||
group: amd64-image-builders
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Backend Image Docker Build
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
platforms: linux/amd64,linux/arm64
|
||||
push: false
|
||||
tags: |
|
||||
${{ env.REGISTRY_IMAGE }}:latest
|
||||
build-args: |
|
||||
DANSWER_VERSION=v0.0.1
|
||||
@@ -7,8 +7,7 @@ on:
|
||||
|
||||
jobs:
|
||||
build-and-push:
|
||||
runs-on:
|
||||
group: amd64-image-builders
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
|
||||
53
.github/workflows/docker-build-web-container-on-merge-group.yml
vendored
Normal file
53
.github/workflows/docker-build-web-container-on-merge-group.yml
vendored
Normal file
@@ -0,0 +1,53 @@
|
||||
name: Build Web Image on Merge Group
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
types: [checks_requested]
|
||||
|
||||
env:
|
||||
REGISTRY_IMAGE: danswer/danswer-web-server
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on:
|
||||
group: ${{ matrix.platform == 'linux/amd64' && 'amd64-image-builders' || 'arm64-image-builders' }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
platform:
|
||||
- linux/amd64
|
||||
- linux/arm64
|
||||
|
||||
steps:
|
||||
- name: Prepare
|
||||
run: |
|
||||
platform=${{ matrix.platform }}
|
||||
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ env.REGISTRY_IMAGE }}
|
||||
tags: |
|
||||
type=raw,value=${{ env.REGISTRY_IMAGE }}:latest
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Build by digest
|
||||
id: build
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: ./web
|
||||
file: ./web/Dockerfile
|
||||
platforms: ${{ matrix.platform }}
|
||||
push: false
|
||||
build-args: |
|
||||
DANSWER_VERSION=v0.0.1
|
||||
# needed due to weird interactions with the builds for different platforms
|
||||
no-cache: true
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
@@ -1,67 +0,0 @@
|
||||
# This workflow is intentionally disabled while we're still working on it
|
||||
# It's close to ready, but a race condition needs to be fixed with
|
||||
# API server and Vespa startup, and it needs to have a way to build/test against
|
||||
# local containers
|
||||
|
||||
name: Helm - Lint and Test Charts
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
|
||||
jobs:
|
||||
lint-test:
|
||||
runs-on: Amd64
|
||||
|
||||
# fetch-depth 0 is required for helm/chart-testing-action
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Helm
|
||||
uses: azure/setup-helm@v4.2.0
|
||||
with:
|
||||
version: v3.14.4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.11'
|
||||
cache: 'pip'
|
||||
cache-dependency-path: |
|
||||
backend/requirements/default.txt
|
||||
backend/requirements/dev.txt
|
||||
backend/requirements/model_server.txt
|
||||
- run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -r backend/requirements/default.txt
|
||||
pip install -r backend/requirements/dev.txt
|
||||
pip install -r backend/requirements/model_server.txt
|
||||
|
||||
- name: Set up chart-testing
|
||||
uses: helm/chart-testing-action@v2.6.1
|
||||
|
||||
- name: Run chart-testing (list-changed)
|
||||
id: list-changed
|
||||
run: |
|
||||
changed=$(ct list-changed --target-branch ${{ github.event.repository.default_branch }})
|
||||
if [[ -n "$changed" ]]; then
|
||||
echo "changed=true" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
|
||||
- name: Run chart-testing (lint)
|
||||
# if: steps.list-changed.outputs.changed == 'true'
|
||||
run: ct lint --all --config ct.yaml --target-branch ${{ github.event.repository.default_branch }}
|
||||
|
||||
- name: Create kind cluster
|
||||
# if: steps.list-changed.outputs.changed == 'true'
|
||||
uses: helm/kind-action@v1.10.0
|
||||
|
||||
- name: Run chart-testing (install)
|
||||
# if: steps.list-changed.outputs.changed == 'true'
|
||||
run: ct install --all --config ct.yaml
|
||||
# run: ct install --target-branch ${{ github.event.repository.default_branch }}
|
||||
|
||||
1
.github/workflows/pr-python-checks.yml
vendored
1
.github/workflows/pr-python-checks.yml
vendored
@@ -1,7 +1,6 @@
|
||||
name: Python Checks
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
|
||||
|
||||
57
.github/workflows/pr-python-connector-tests.yml
vendored
57
.github/workflows/pr-python-connector-tests.yml
vendored
@@ -1,57 +0,0 @@
|
||||
name: Connector Tests
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches: [main]
|
||||
schedule:
|
||||
# This cron expression runs the job daily at 16:00 UTC (9am PT)
|
||||
- cron: "0 16 * * *"
|
||||
|
||||
env:
|
||||
# Confluence
|
||||
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
|
||||
CONFLUENCE_TEST_SPACE: ${{ secrets.CONFLUENCE_TEST_SPACE }}
|
||||
CONFLUENCE_IS_CLOUD: ${{ secrets.CONFLUENCE_IS_CLOUD }}
|
||||
CONFLUENCE_TEST_PAGE_ID: ${{ secrets.CONFLUENCE_TEST_PAGE_ID }}
|
||||
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
|
||||
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
|
||||
|
||||
jobs:
|
||||
connectors-check:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
env:
|
||||
PYTHONPATH: ./backend
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.11"
|
||||
cache: "pip"
|
||||
cache-dependency-path: |
|
||||
backend/requirements/default.txt
|
||||
backend/requirements/dev.txt
|
||||
|
||||
- name: Install Dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -r backend/requirements/default.txt
|
||||
pip install -r backend/requirements/dev.txt
|
||||
|
||||
- name: Run Tests
|
||||
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
|
||||
run: py.test -o junit_family=xunit2 -xv --ff backend/tests/daily/connectors
|
||||
|
||||
- name: Alert on Failure
|
||||
if: failure() && github.event_name == 'schedule'
|
||||
env:
|
||||
SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK }}
|
||||
run: |
|
||||
curl -X POST \
|
||||
-H 'Content-type: application/json' \
|
||||
--data '{"text":"Scheduled Connector Tests failed! Check the run at: https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}"}' \
|
||||
$SLACK_WEBHOOK
|
||||
4
.github/workflows/pr-python-tests.yml
vendored
4
.github/workflows/pr-python-tests.yml
vendored
@@ -1,7 +1,6 @@
|
||||
name: Python Unit Tests
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
|
||||
@@ -11,8 +10,7 @@ jobs:
|
||||
|
||||
env:
|
||||
PYTHONPATH: ./backend
|
||||
REDIS_CLOUD_PYTEST_PASSWORD: ${{ secrets.REDIS_CLOUD_PYTEST_PASSWORD }}
|
||||
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
19
.github/workflows/pr-quality-checks.yml
vendored
19
.github/workflows/pr-quality-checks.yml
vendored
@@ -4,19 +4,18 @@ concurrency:
|
||||
cancel-in-progress: true
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
pull_request: null
|
||||
|
||||
jobs:
|
||||
quality-checks:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- uses: pre-commit/action@v3.0.0
|
||||
with:
|
||||
extra_args: ${{ github.event_name == 'pull_request' && format('--from-ref {0} --to-ref {1}', github.event.pull_request.base.sha, github.event.pull_request.head.sha) || '' }}
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.11'
|
||||
- uses: pre-commit/action@v3.0.0
|
||||
with:
|
||||
extra_args: --from-ref ${{ github.event.pull_request.base.sha }} --to-ref ${{ github.event.pull_request.head.sha }}
|
||||
|
||||
160
.github/workflows/run-it.yml
vendored
160
.github/workflows/run-it.yml
vendored
@@ -1,160 +0,0 @@
|
||||
name: Run Integration Tests
|
||||
concurrency:
|
||||
group: Run-Integration-Tests-${{ github.head_ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
|
||||
env:
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
|
||||
jobs:
|
||||
integration-tests:
|
||||
runs-on:
|
||||
group: 'arm64-image-builders'
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
# NOTE: we don't need to build the Web Docker image since it's not used
|
||||
# during the IT for now. We have a separate action to verify it builds
|
||||
# succesfully
|
||||
- name: Pull Web Docker image
|
||||
run: |
|
||||
docker pull danswer/danswer-web-server:latest
|
||||
docker tag danswer/danswer-web-server:latest danswer/danswer-web-server:it
|
||||
|
||||
- name: Build Backend Docker image
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
platforms: linux/arm64
|
||||
tags: danswer/danswer-backend:it
|
||||
cache-from: type=registry,ref=danswer/danswer-backend:it
|
||||
cache-to: |
|
||||
type=registry,ref=danswer/danswer-backend:it,mode=max
|
||||
type=inline
|
||||
|
||||
- name: Build Model Server Docker image
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile.model_server
|
||||
platforms: linux/arm64
|
||||
tags: danswer/danswer-model-server:it
|
||||
cache-from: type=registry,ref=danswer/danswer-model-server:it
|
||||
cache-to: |
|
||||
type=registry,ref=danswer/danswer-model-server:it,mode=max
|
||||
type=inline
|
||||
|
||||
- name: Build integration test Docker image
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/tests/integration/Dockerfile
|
||||
platforms: linux/arm64
|
||||
tags: danswer/integration-test-runner:it
|
||||
cache-from: type=registry,ref=danswer/integration-test-runner:it
|
||||
cache-to: |
|
||||
type=registry,ref=danswer/integration-test-runner:it,mode=max
|
||||
type=inline
|
||||
|
||||
- name: Start Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
|
||||
AUTH_TYPE=basic \
|
||||
REQUIRE_EMAIL_VERIFICATION=false \
|
||||
DISABLE_TELEMETRY=true \
|
||||
IMAGE_TAG=it \
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack up -d
|
||||
id: start_docker
|
||||
|
||||
- name: Wait for service to be ready
|
||||
run: |
|
||||
echo "Starting wait-for-service script..."
|
||||
|
||||
start_time=$(date +%s)
|
||||
timeout=300 # 5 minutes in seconds
|
||||
|
||||
while true; do
|
||||
current_time=$(date +%s)
|
||||
elapsed_time=$((current_time - start_time))
|
||||
|
||||
if [ $elapsed_time -ge $timeout ]; then
|
||||
echo "Timeout reached. Service did not become ready in 5 minutes."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Use curl with error handling to ignore specific exit code 56
|
||||
response=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:8080/health || echo "curl_error")
|
||||
|
||||
if [ "$response" = "200" ]; then
|
||||
echo "Service is ready!"
|
||||
break
|
||||
elif [ "$response" = "curl_error" ]; then
|
||||
echo "Curl encountered an error, possibly exit code 56. Continuing to retry..."
|
||||
else
|
||||
echo "Service not ready yet (HTTP status $response). Retrying in 5 seconds..."
|
||||
fi
|
||||
|
||||
sleep 5
|
||||
done
|
||||
echo "Finished waiting for service."
|
||||
|
||||
- name: Run integration tests
|
||||
run: |
|
||||
echo "Running integration tests..."
|
||||
docker run --rm --network danswer-stack_default \
|
||||
-e POSTGRES_HOST=relational_db \
|
||||
-e POSTGRES_USER=postgres \
|
||||
-e POSTGRES_PASSWORD=password \
|
||||
-e POSTGRES_DB=postgres \
|
||||
-e VESPA_HOST=index \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
danswer/integration-test-runner:it
|
||||
continue-on-error: true
|
||||
id: run_tests
|
||||
|
||||
- name: Check test results
|
||||
run: |
|
||||
if [ ${{ steps.run_tests.outcome }} == 'failure' ]; then
|
||||
echo "Integration tests failed. Exiting with error."
|
||||
exit 1
|
||||
else
|
||||
echo "All integration tests passed successfully."
|
||||
fi
|
||||
|
||||
- name: Save Docker logs
|
||||
if: success() || failure()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack logs > docker-compose.log
|
||||
mv docker-compose.log ${{ github.workspace }}/docker-compose.log
|
||||
|
||||
- name: Upload logs
|
||||
if: success() || failure()
|
||||
uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: docker-logs
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
|
||||
- name: Stop Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -4,6 +4,6 @@
|
||||
.mypy_cache
|
||||
.idea
|
||||
/deployment/data/nginx/app.conf
|
||||
.vscode/
|
||||
.vscode/launch.json
|
||||
*.sw?
|
||||
/backend/tests/regression/answer_quality/search_test_config.yaml
|
||||
|
||||
15
.vscode/env_template.txt
vendored
15
.vscode/env_template.txt
vendored
@@ -1,5 +1,5 @@
|
||||
# Copy this file to .env in the .vscode folder
|
||||
# Fill in the <REPLACE THIS> values as needed, it is recommended to set the GEN_AI_API_KEY value to avoid having to set up an LLM in the UI
|
||||
# Copy this file to .env at the base of the repo and fill in the <REPLACE THIS> values
|
||||
# This will help with development iteration speed and reduce repeat tasks for dev
|
||||
# Also check out danswer/backend/scripts/restart_containers.sh for a script to restart the containers which Danswer relies on outside of VSCode/Cursor processes
|
||||
|
||||
# For local dev, often user Authentication is not needed
|
||||
@@ -15,7 +15,7 @@ LOG_LEVEL=debug
|
||||
|
||||
# This passes top N results to LLM an additional time for reranking prior to answer generation
|
||||
# This step is quite heavy on token usage so we disable it for dev generally
|
||||
DISABLE_LLM_DOC_RELEVANCE=False
|
||||
DISABLE_LLM_CHUNK_FILTER=True
|
||||
|
||||
|
||||
# Useful if you want to toggle auth on/off (google_oauth/OIDC specifically)
|
||||
@@ -27,9 +27,9 @@ REQUIRE_EMAIL_VERIFICATION=False
|
||||
|
||||
# Set these so if you wipe the DB, you don't end up having to go through the UI every time
|
||||
GEN_AI_API_KEY=<REPLACE THIS>
|
||||
# If answer quality isn't important for dev, use gpt-4o-mini since it's cheaper
|
||||
GEN_AI_MODEL_VERSION=gpt-4o
|
||||
FAST_GEN_AI_MODEL_VERSION=gpt-4o
|
||||
# If answer quality isn't important for dev, use 3.5 turbo due to it being cheaper
|
||||
GEN_AI_MODEL_VERSION=gpt-3.5-turbo
|
||||
FAST_GEN_AI_MODEL_VERSION=gpt-3.5-turbo
|
||||
|
||||
# For Danswer Slack Bot, overrides the UI values so no need to set this up via UI every time
|
||||
# Only needed if using DanswerBot
|
||||
@@ -38,7 +38,7 @@ FAST_GEN_AI_MODEL_VERSION=gpt-4o
|
||||
|
||||
|
||||
# Python stuff
|
||||
PYTHONPATH=../backend
|
||||
PYTHONPATH=./backend
|
||||
PYTHONUNBUFFERED=1
|
||||
|
||||
|
||||
@@ -49,3 +49,4 @@ BING_API_KEY=<REPLACE THIS>
|
||||
# Enable the full set of Danswer Enterprise Edition features
|
||||
# NOTE: DO NOT ENABLE THIS UNLESS YOU HAVE A PAID ENTERPRISE LICENSE (or if you are using this for local testing/development)
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=False
|
||||
|
||||
|
||||
64
.vscode/launch.template.jsonc
vendored
64
.vscode/launch.template.jsonc
vendored
@@ -1,23 +1,15 @@
|
||||
/* Copy this file into '.vscode/launch.json' or merge its contents into your existing configurations. */
|
||||
/*
|
||||
|
||||
Copy this file into '.vscode/launch.json' or merge its
|
||||
contents into your existing configurations.
|
||||
|
||||
*/
|
||||
|
||||
{
|
||||
// Use IntelliSense to learn about possible attributes.
|
||||
// Hover to view descriptions of existing attributes.
|
||||
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
||||
"version": "0.2.0",
|
||||
"compounds": [
|
||||
{
|
||||
"name": "Run All Danswer Services",
|
||||
"configurations": [
|
||||
"Web Server",
|
||||
"Model Server",
|
||||
"API Server",
|
||||
"Indexing",
|
||||
"Background Jobs",
|
||||
"Slack Bot"
|
||||
]
|
||||
}
|
||||
],
|
||||
"configurations": [
|
||||
{
|
||||
"name": "Web Server",
|
||||
@@ -25,7 +17,7 @@
|
||||
"request": "launch",
|
||||
"cwd": "${workspaceRoot}/web",
|
||||
"runtimeExecutable": "npm",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"envFile": "${workspaceFolder}/.env",
|
||||
"runtimeArgs": [
|
||||
"run", "dev"
|
||||
],
|
||||
@@ -33,12 +25,11 @@
|
||||
},
|
||||
{
|
||||
"name": "Model Server",
|
||||
"consoleName": "Model Server",
|
||||
"type": "debugpy",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"module": "uvicorn",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"envFile": "${workspaceFolder}/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1"
|
||||
@@ -52,12 +43,11 @@
|
||||
},
|
||||
{
|
||||
"name": "API Server",
|
||||
"consoleName": "API Server",
|
||||
"type": "debugpy",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"module": "uvicorn",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"envFile": "${workspaceFolder}/.env",
|
||||
"env": {
|
||||
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
@@ -72,14 +62,13 @@
|
||||
},
|
||||
{
|
||||
"name": "Indexing",
|
||||
"consoleName": "Indexing",
|
||||
"type": "debugpy",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "danswer/background/update.py",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"envFile": "${workspaceFolder}/.env",
|
||||
"env": {
|
||||
"ENABLE_MULTIPASS_INDEXING": "false",
|
||||
"ENABLE_MINI_CHUNK": "false",
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
@@ -88,12 +77,11 @@
|
||||
// Celery and all async jobs, usually would include indexing as well but this is handled separately above for dev
|
||||
{
|
||||
"name": "Background Jobs",
|
||||
"consoleName": "Background Jobs",
|
||||
"type": "debugpy",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "scripts/dev_run_background_jobs.py",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"envFile": "${workspaceFolder}/.env",
|
||||
"env": {
|
||||
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
@@ -108,12 +96,11 @@
|
||||
// DANSWER_BOT_SLACK_APP_TOKEN & DANSWER_BOT_SLACK_BOT_TOKEN need to be set in .env file located in the root of the project
|
||||
{
|
||||
"name": "Slack Bot",
|
||||
"consoleName": "Slack Bot",
|
||||
"type": "debugpy",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "danswer/danswerbot/slack/listener.py",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"envFile": "${workspaceFolder}/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
@@ -122,12 +109,11 @@
|
||||
},
|
||||
{
|
||||
"name": "Pytest",
|
||||
"consoleName": "Pytest",
|
||||
"type": "debugpy",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"module": "pytest",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"envFile": "${workspaceFolder}/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
@@ -138,16 +124,6 @@
|
||||
// Specify a sepcific module/test to run or provide nothing to run all tests
|
||||
//"tests/unit/danswer/llm/answering/test_prune_and_merge.py"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Clear and Restart External Volumes and Containers",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"runtimeExecutable": "bash",
|
||||
"runtimeArgs": ["${workspaceFolder}/backend/scripts/restart_containers.sh"],
|
||||
"cwd": "${workspaceFolder}",
|
||||
"console": "integratedTerminal",
|
||||
"stopOnEntry": true
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
105
CONTRIBUTING.md
105
CONTRIBUTING.md
@@ -48,26 +48,23 @@ We would love to see you there!
|
||||
|
||||
|
||||
## Get Started 🚀
|
||||
Danswer being a fully functional app, relies on some external software, specifically:
|
||||
Danswer being a fully functional app, relies on some external pieces of software, specifically:
|
||||
- [Postgres](https://www.postgresql.org/) (Relational DB)
|
||||
- [Vespa](https://vespa.ai/) (Vector DB/Search Engine)
|
||||
- [Redis](https://redis.io/) (Cache)
|
||||
- [Nginx](https://nginx.org/) (Not needed for development flows generally)
|
||||
|
||||
|
||||
> **Note:**
|
||||
> This guide provides instructions to build and run Danswer locally from source with Docker containers providing the above external software. We believe this combination is easier for
|
||||
> development purposes. If you prefer to use pre-built container images, we provide instructions on running the full Danswer stack within Docker below.
|
||||
This guide provides instructions to set up the Danswer specific services outside of Docker because it's easier for
|
||||
development purposes but also feel free to just use the containers and update with local changes by providing the
|
||||
`--build` flag.
|
||||
|
||||
|
||||
### Local Set Up
|
||||
Be sure to use Python version 3.11. For instructions on installing Python 3.11 on macOS, refer to the [CONTRIBUTING_MACOS.md](./CONTRIBUTING_MACOS.md) readme.
|
||||
It is recommended to use Python version 3.11
|
||||
|
||||
If using a lower version, modifications will have to be made to the code.
|
||||
If using a higher version, sometimes some libraries will not be available (i.e. we had problems with Tensorflow in the past with higher versions of python).
|
||||
If using a higher version, the version of Tensorflow we use may not be available for your platform.
|
||||
|
||||
|
||||
#### Backend: Python requirements
|
||||
#### Installing Requirements
|
||||
Currently, we use pip and recommend creating a virtual environment.
|
||||
|
||||
For convenience here's a command for it:
|
||||
@@ -76,9 +73,8 @@ python -m venv .venv
|
||||
source .venv/bin/activate
|
||||
```
|
||||
|
||||
> **Note:**
|
||||
> This virtual environment MUST NOT be set up WITHIN the danswer directory if you plan on using mypy within certain IDEs.
|
||||
> For simplicity, we recommend setting up the virtual environment outside of the danswer directory.
|
||||
--> Note that this virtual environment MUST NOT be set up WITHIN the danswer
|
||||
directory
|
||||
|
||||
_For Windows, activate the virtual environment using Command Prompt:_
|
||||
```bash
|
||||
@@ -93,38 +89,34 @@ Install the required python dependencies:
|
||||
```bash
|
||||
pip install -r danswer/backend/requirements/default.txt
|
||||
pip install -r danswer/backend/requirements/dev.txt
|
||||
pip install -r danswer/backend/requirements/ee.txt
|
||||
pip install -r danswer/backend/requirements/model_server.txt
|
||||
```
|
||||
|
||||
Install Playwright for Python (headless browser required by the Web Connector)
|
||||
|
||||
In the activated Python virtualenv, install Playwright for Python by running:
|
||||
```bash
|
||||
playwright install
|
||||
```
|
||||
|
||||
You may have to deactivate and reactivate your virtualenv for `playwright` to appear on your path.
|
||||
|
||||
#### Frontend: Node dependencies
|
||||
|
||||
Install [Node.js and npm](https://docs.npmjs.com/downloading-and-installing-node-js-and-npm) for the frontend.
|
||||
Once the above is done, navigate to `danswer/web` run:
|
||||
```bash
|
||||
npm i
|
||||
```
|
||||
|
||||
#### Docker containers for external software
|
||||
You will need Docker installed to run these containers.
|
||||
Install Playwright (required by the Web Connector)
|
||||
|
||||
First navigate to `danswer/deployment/docker_compose`, then start up Postgres/Vespa/Redis with:
|
||||
> Note: If you have just done the pip install, open a new terminal and source the python virtual-env again.
|
||||
This will update the path to include playwright
|
||||
|
||||
Then install Playwright by running:
|
||||
```bash
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack up -d index relational_db cache
|
||||
playwright install
|
||||
```
|
||||
(index refers to Vespa, relational_db refers to Postgres, and cache refers to Redis)
|
||||
|
||||
|
||||
#### Running Danswer locally
|
||||
#### Dependent Docker Containers
|
||||
First navigate to `danswer/deployment/docker_compose`, then start up Vespa and Postgres with:
|
||||
```bash
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack up -d index relational_db
|
||||
```
|
||||
(index refers to Vespa and relational_db refers to Postgres)
|
||||
|
||||
#### Running Danswer
|
||||
To start the frontend, navigate to `danswer/web` and run:
|
||||
```bash
|
||||
npm run dev
|
||||
@@ -135,10 +127,11 @@ Navigate to `danswer/backend` and run:
|
||||
```bash
|
||||
uvicorn model_server.main:app --reload --port 9000
|
||||
```
|
||||
|
||||
_For Windows (for compatibility with both PowerShell and Command Prompt):_
|
||||
```bash
|
||||
powershell -Command "uvicorn model_server.main:app --reload --port 9000"
|
||||
powershell -Command "
|
||||
uvicorn model_server.main:app --reload --port 9000
|
||||
"
|
||||
```
|
||||
|
||||
The first time running Danswer, you will need to run the DB migrations for Postgres.
|
||||
@@ -161,7 +154,6 @@ To run the backend API server, navigate back to `danswer/backend` and run:
|
||||
```bash
|
||||
AUTH_TYPE=disabled uvicorn danswer.main:app --reload --port 8080
|
||||
```
|
||||
|
||||
_For Windows (for compatibility with both PowerShell and Command Prompt):_
|
||||
```bash
|
||||
powershell -Command "
|
||||
@@ -170,58 +162,20 @@ powershell -Command "
|
||||
"
|
||||
```
|
||||
|
||||
> **Note:**
|
||||
> If you need finer logging, add the additional environment variable `LOG_LEVEL=DEBUG` to the relevant services.
|
||||
|
||||
#### Wrapping up
|
||||
|
||||
You should now have 4 servers running:
|
||||
|
||||
- Web server
|
||||
- Backend API
|
||||
- Model server
|
||||
- Background jobs
|
||||
|
||||
Now, visit `http://localhost:3000` in your browser. You should see the Danswer onboarding wizard where you can connect your external LLM provider to Danswer.
|
||||
|
||||
You've successfully set up a local Danswer instance! 🏁
|
||||
|
||||
#### Running the Danswer application in a container
|
||||
|
||||
You can run the full Danswer application stack from pre-built images including all external software dependencies.
|
||||
|
||||
Navigate to `danswer/deployment/docker_compose` and run:
|
||||
|
||||
```bash
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack up -d
|
||||
```
|
||||
|
||||
After Docker pulls and starts these containers, navigate to `http://localhost:3000` to use Danswer.
|
||||
|
||||
If you want to make changes to Danswer and run those changes in Docker, you can also build a local version of the Danswer container images that incorporates your changes like so:
|
||||
|
||||
```bash
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack up -d --build
|
||||
```
|
||||
Note: if you need finer logging, add the additional environment variable `LOG_LEVEL=DEBUG` to the relevant services.
|
||||
|
||||
### Formatting and Linting
|
||||
#### Backend
|
||||
For the backend, you'll need to setup pre-commit hooks (black / reorder-python-imports).
|
||||
First, install pre-commit (if you don't have it already) following the instructions
|
||||
[here](https://pre-commit.com/#installation).
|
||||
|
||||
With the virtual environment active, install the pre-commit library with:
|
||||
```bash
|
||||
pip install pre-commit
|
||||
```
|
||||
|
||||
Then, from the `danswer/backend` directory, run:
|
||||
```bash
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
Additionally, we use `mypy` for static type checking.
|
||||
Danswer is fully type-annotated, and we want to keep it that way!
|
||||
Danswer is fully type-annotated, and we would like to keep it that way!
|
||||
To run the mypy checks manually, run `python -m mypy .` from the `danswer/backend` directory.
|
||||
|
||||
|
||||
@@ -232,7 +186,6 @@ Please double check that prettier passes before creating a pull request.
|
||||
|
||||
|
||||
### Release Process
|
||||
Danswer loosely follows the SemVer versioning standard.
|
||||
Major changes are released with a "minor" version bump. Currently we use patch release versions to indicate small feature changes.
|
||||
Danswer follows the semver versioning standard.
|
||||
A set of Docker containers will be pushed automatically to DockerHub with every tag.
|
||||
You can see the containers [here](https://hub.docker.com/search?q=danswer%2F).
|
||||
|
||||
@@ -1,31 +0,0 @@
|
||||
## Some additional notes for Mac Users
|
||||
The base instructions to set up the development environment are located in [CONTRIBUTING.md](https://github.com/danswer-ai/danswer/blob/main/CONTRIBUTING.md).
|
||||
|
||||
### Setting up Python
|
||||
Ensure [Homebrew](https://brew.sh/) is already set up.
|
||||
|
||||
Then install python 3.11.
|
||||
```bash
|
||||
brew install python@3.11
|
||||
```
|
||||
|
||||
Add python 3.11 to your path: add the following line to ~/.zshrc
|
||||
```
|
||||
export PATH="$(brew --prefix)/opt/python@3.11/libexec/bin:$PATH"
|
||||
```
|
||||
|
||||
> **Note:**
|
||||
> You will need to open a new terminal for the path change above to take effect.
|
||||
|
||||
|
||||
### Setting up Docker
|
||||
On macOS, you will need to install [Docker Desktop](https://www.docker.com/products/docker-desktop/) and
|
||||
ensure it is running before continuing with the docker commands.
|
||||
|
||||
|
||||
### Formatting and Linting
|
||||
MacOS will likely require you to remove some quarantine attributes on some of the hooks for them to execute properly.
|
||||
After installing pre-commit, run the following command:
|
||||
```bash
|
||||
sudo xattr -r -d com.apple.quarantine ~/.cache/pre-commit
|
||||
```
|
||||
@@ -9,8 +9,7 @@ founders@danswer.ai for more information. Please visit https://github.com/danswe
|
||||
|
||||
# Default DANSWER_VERSION, typically overriden during builds by GitHub Actions.
|
||||
ARG DANSWER_VERSION=0.3-dev
|
||||
ENV DANSWER_VERSION=${DANSWER_VERSION} \
|
||||
DANSWER_RUNNING_IN_DOCKER="true"
|
||||
ENV DANSWER_VERSION=${DANSWER_VERSION}
|
||||
|
||||
RUN echo "DANSWER_VERSION: ${DANSWER_VERSION}"
|
||||
# Install system dependencies
|
||||
@@ -69,15 +68,13 @@ RUN apt-get update && \
|
||||
rm -f /usr/local/lib/python3.11/site-packages/tornado/test/test.key
|
||||
|
||||
# Pre-downloading models for setups with limited egress
|
||||
RUN python -c "from tokenizers import Tokenizer; \
|
||||
Tokenizer.from_pretrained('nomic-ai/nomic-embed-text-v1')"
|
||||
|
||||
RUN python -c "from transformers import AutoTokenizer; AutoTokenizer.from_pretrained('intfloat/e5-base-v2')"
|
||||
|
||||
# Pre-downloading NLTK for setups with limited egress
|
||||
RUN python -c "import nltk; \
|
||||
nltk.download('stopwords', quiet=True); \
|
||||
nltk.download('wordnet', quiet=True); \
|
||||
nltk.download('punkt', quiet=True);"
|
||||
# nltk.download('wordnet', quiet=True); introduce this back if lemmatization is needed
|
||||
|
||||
# Set up application files
|
||||
WORKDIR /app
|
||||
|
||||
@@ -8,10 +8,7 @@ visit https://github.com/danswer-ai/danswer."
|
||||
|
||||
# Default DANSWER_VERSION, typically overriden during builds by GitHub Actions.
|
||||
ARG DANSWER_VERSION=0.3-dev
|
||||
ENV DANSWER_VERSION=${DANSWER_VERSION} \
|
||||
DANSWER_RUNNING_IN_DOCKER="true"
|
||||
|
||||
|
||||
ENV DANSWER_VERSION=${DANSWER_VERSION}
|
||||
RUN echo "DANSWER_VERSION: ${DANSWER_VERSION}"
|
||||
|
||||
COPY ./requirements/model_server.txt /tmp/requirements.txt
|
||||
@@ -21,22 +18,14 @@ RUN apt-get remove -y --allow-remove-essential perl-base && \
|
||||
apt-get autoremove -y
|
||||
|
||||
# Pre-downloading models for setups with limited egress
|
||||
# Download tokenizers, distilbert for the Danswer model
|
||||
# Download model weights
|
||||
# Run Nomic to pull in the custom architecture and have it cached locally
|
||||
RUN python -c "from transformers import AutoTokenizer; \
|
||||
AutoTokenizer.from_pretrained('distilbert-base-uncased'); \
|
||||
AutoTokenizer.from_pretrained('mixedbread-ai/mxbai-rerank-xsmall-v1'); \
|
||||
RUN python -c "from transformers import AutoModel, AutoTokenizer, TFDistilBertForSequenceClassification; \
|
||||
from huggingface_hub import snapshot_download; \
|
||||
snapshot_download(repo_id='danswer/hybrid-intent-token-classifier', revision='v1.0.3'); \
|
||||
snapshot_download('nomic-ai/nomic-embed-text-v1'); \
|
||||
snapshot_download('mixedbread-ai/mxbai-rerank-xsmall-v1'); \
|
||||
from sentence_transformers import SentenceTransformer; \
|
||||
SentenceTransformer(model_name_or_path='nomic-ai/nomic-embed-text-v1', trust_remote_code=True);"
|
||||
|
||||
# In case the user has volumes mounted to /root/.cache/huggingface that they've downloaded while
|
||||
# running Danswer, don't overwrite it with the built in cache folder
|
||||
RUN mv /root/.cache/huggingface /root/.cache/temp_huggingface
|
||||
AutoTokenizer.from_pretrained('danswer/intent-model'); \
|
||||
AutoTokenizer.from_pretrained('intfloat/e5-base-v2'); \
|
||||
AutoTokenizer.from_pretrained('mixedbread-ai/mxbai-rerank-xsmall-v1'); \
|
||||
snapshot_download('danswer/intent-model'); \
|
||||
snapshot_download('intfloat/e5-base-v2'); \
|
||||
snapshot_download('mixedbread-ai/mxbai-rerank-xsmall-v1')"
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
|
||||
@@ -8,7 +8,6 @@ from sqlalchemy import pool
|
||||
from sqlalchemy.engine import Connection
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from celery.backends.database.session import ResultModelBase # type: ignore
|
||||
from sqlalchemy.schema import SchemaItem
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
@@ -16,9 +15,7 @@ config = context.config
|
||||
|
||||
# Interpret the config file for Python logging.
|
||||
# This line sets up loggers basically.
|
||||
if config.config_file_name is not None and config.attributes.get(
|
||||
"configure_logger", True
|
||||
):
|
||||
if config.config_file_name is not None:
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
# add your model's MetaData object here
|
||||
@@ -32,20 +29,6 @@ target_metadata = [Base.metadata, ResultModelBase.metadata]
|
||||
# my_important_option = config.get_main_option("my_important_option")
|
||||
# ... etc.
|
||||
|
||||
EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
|
||||
|
||||
|
||||
def include_object(
|
||||
object: SchemaItem,
|
||||
name: str,
|
||||
type_: str,
|
||||
reflected: bool,
|
||||
compare_to: SchemaItem | None,
|
||||
) -> bool:
|
||||
if type_ == "table" and name in EXCLUDE_TABLES:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
"""Run migrations in 'offline' mode.
|
||||
@@ -72,11 +55,7 @@ def run_migrations_offline() -> None:
|
||||
|
||||
|
||||
def do_run_migrations(connection: Connection) -> None:
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
include_object=include_object,
|
||||
) # type: ignore
|
||||
context.configure(connection=connection, target_metadata=target_metadata) # type: ignore
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
@@ -17,11 +17,15 @@ depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column(
|
||||
"chat_session",
|
||||
sa.Column("current_alternate_model", sa.String(), nullable=True),
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column("chat_session", "current_alternate_model")
|
||||
# ### end Alembic commands ###
|
||||
|
||||
@@ -1,26 +0,0 @@
|
||||
"""add_indexing_start_to_connector
|
||||
|
||||
Revision ID: 08a1eda20fe1
|
||||
Revises: 8a87bd6ec550
|
||||
Create Date: 2024-07-23 11:12:39.462397
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "08a1eda20fe1"
|
||||
down_revision = "8a87bd6ec550"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"connector", sa.Column("indexing_start", sa.DateTime(), nullable=True)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("connector", "indexing_start")
|
||||
@@ -1,27 +0,0 @@
|
||||
"""add ccpair deletion failure message
|
||||
|
||||
Revision ID: 0ebb1d516877
|
||||
Revises: 52a219fb5233
|
||||
Create Date: 2024-09-10 15:03:48.233926
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "0ebb1d516877"
|
||||
down_revision = "52a219fb5233"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"connector_credential_pair",
|
||||
sa.Column("deletion_failure_message", sa.String(), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("connector_credential_pair", "deletion_failure_message")
|
||||
@@ -1,135 +0,0 @@
|
||||
"""embedding model -> search settings
|
||||
|
||||
Revision ID: 1f60f60c3401
|
||||
Revises: f17bf3b0d9f1
|
||||
Create Date: 2024-08-25 12:39:51.731632
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
from danswer.configs.chat_configs import NUM_POSTPROCESSED_RESULTS
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "1f60f60c3401"
|
||||
down_revision = "f17bf3b0d9f1"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.drop_constraint(
|
||||
"index_attempt__embedding_model_fk", "index_attempt", type_="foreignkey"
|
||||
)
|
||||
# Rename the table
|
||||
op.rename_table("embedding_model", "search_settings")
|
||||
|
||||
# Add new columns
|
||||
op.add_column(
|
||||
"search_settings",
|
||||
sa.Column(
|
||||
"multipass_indexing", sa.Boolean(), nullable=False, server_default="false"
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"search_settings",
|
||||
sa.Column(
|
||||
"multilingual_expansion",
|
||||
postgresql.ARRAY(sa.String()),
|
||||
nullable=False,
|
||||
server_default="{}",
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"search_settings",
|
||||
sa.Column(
|
||||
"disable_rerank_for_streaming",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default="false",
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"search_settings", sa.Column("rerank_model_name", sa.String(), nullable=True)
|
||||
)
|
||||
op.add_column(
|
||||
"search_settings", sa.Column("rerank_provider_type", sa.String(), nullable=True)
|
||||
)
|
||||
op.add_column(
|
||||
"search_settings", sa.Column("rerank_api_key", sa.String(), nullable=True)
|
||||
)
|
||||
op.add_column(
|
||||
"search_settings",
|
||||
sa.Column(
|
||||
"num_rerank",
|
||||
sa.Integer(),
|
||||
nullable=False,
|
||||
server_default=str(NUM_POSTPROCESSED_RESULTS),
|
||||
),
|
||||
)
|
||||
|
||||
# Add the new column as nullable initially
|
||||
op.add_column(
|
||||
"index_attempt", sa.Column("search_settings_id", sa.Integer(), nullable=True)
|
||||
)
|
||||
|
||||
# Populate the new column with data from the existing embedding_model_id
|
||||
op.execute("UPDATE index_attempt SET search_settings_id = embedding_model_id")
|
||||
|
||||
# Create the foreign key constraint
|
||||
op.create_foreign_key(
|
||||
"fk_index_attempt_search_settings",
|
||||
"index_attempt",
|
||||
"search_settings",
|
||||
["search_settings_id"],
|
||||
["id"],
|
||||
)
|
||||
|
||||
# Make the new column non-nullable
|
||||
op.alter_column("index_attempt", "search_settings_id", nullable=False)
|
||||
|
||||
# Drop the old embedding_model_id column
|
||||
op.drop_column("index_attempt", "embedding_model_id")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Add back the embedding_model_id column
|
||||
op.add_column(
|
||||
"index_attempt", sa.Column("embedding_model_id", sa.Integer(), nullable=True)
|
||||
)
|
||||
|
||||
# Populate the old column with data from search_settings_id
|
||||
op.execute("UPDATE index_attempt SET embedding_model_id = search_settings_id")
|
||||
|
||||
# Make the old column non-nullable
|
||||
op.alter_column("index_attempt", "embedding_model_id", nullable=False)
|
||||
|
||||
# Drop the foreign key constraint
|
||||
op.drop_constraint(
|
||||
"fk_index_attempt_search_settings", "index_attempt", type_="foreignkey"
|
||||
)
|
||||
|
||||
# Drop the new search_settings_id column
|
||||
op.drop_column("index_attempt", "search_settings_id")
|
||||
|
||||
# Rename the table back
|
||||
op.rename_table("search_settings", "embedding_model")
|
||||
|
||||
# Remove added columns
|
||||
op.drop_column("embedding_model", "num_rerank")
|
||||
op.drop_column("embedding_model", "rerank_api_key")
|
||||
op.drop_column("embedding_model", "rerank_provider_type")
|
||||
op.drop_column("embedding_model", "rerank_model_name")
|
||||
op.drop_column("embedding_model", "disable_rerank_for_streaming")
|
||||
op.drop_column("embedding_model", "multilingual_expansion")
|
||||
op.drop_column("embedding_model", "multipass_indexing")
|
||||
|
||||
op.create_foreign_key(
|
||||
"index_attempt__embedding_model_fk",
|
||||
"index_attempt",
|
||||
"embedding_model",
|
||||
["embedding_model_id"],
|
||||
["id"],
|
||||
)
|
||||
@@ -1,44 +0,0 @@
|
||||
"""notifications
|
||||
|
||||
Revision ID: 213fd978c6d8
|
||||
Revises: 5fc1f54cc252
|
||||
Create Date: 2024-08-10 11:13:36.070790
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "213fd978c6d8"
|
||||
down_revision = "5fc1f54cc252"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"notification",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column(
|
||||
"notif_type",
|
||||
sa.String(),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"user_id",
|
||||
sa.UUID(),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("dismissed", sa.Boolean(), nullable=False),
|
||||
sa.Column("last_shown", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("first_shown", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["user.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("notification")
|
||||
@@ -79,7 +79,7 @@ def downgrade() -> None:
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"document_retrieval_feedback__chat_message_fk",
|
||||
"document_retrieval_feedback",
|
||||
"document_retrieval",
|
||||
"chat_message",
|
||||
["chat_message_id"],
|
||||
["id"],
|
||||
|
||||
@@ -160,28 +160,12 @@ def downgrade() -> None:
|
||||
nullable=False,
|
||||
),
|
||||
)
|
||||
|
||||
# Check if the constraint exists before dropping
|
||||
conn = op.get_bind()
|
||||
inspector = sa.inspect(conn)
|
||||
constraints = inspector.get_foreign_keys("index_attempt")
|
||||
|
||||
if any(
|
||||
constraint["name"] == "fk_index_attempt_credential_id"
|
||||
for constraint in constraints
|
||||
):
|
||||
op.drop_constraint(
|
||||
"fk_index_attempt_credential_id", "index_attempt", type_="foreignkey"
|
||||
)
|
||||
|
||||
if any(
|
||||
constraint["name"] == "fk_index_attempt_connector_id"
|
||||
for constraint in constraints
|
||||
):
|
||||
op.drop_constraint(
|
||||
"fk_index_attempt_connector_id", "index_attempt", type_="foreignkey"
|
||||
)
|
||||
|
||||
op.drop_constraint(
|
||||
"fk_index_attempt_credential_id", "index_attempt", type_="foreignkey"
|
||||
)
|
||||
op.drop_constraint(
|
||||
"fk_index_attempt_connector_id", "index_attempt", type_="foreignkey"
|
||||
)
|
||||
op.drop_column("index_attempt", "credential_id")
|
||||
op.drop_column("index_attempt", "connector_id")
|
||||
op.drop_table("connector_credential_pair")
|
||||
|
||||
@@ -1,32 +0,0 @@
|
||||
"""Add Above Below to Persona
|
||||
|
||||
Revision ID: 2d2304e27d8c
|
||||
Revises: 4b08d97e175a
|
||||
Create Date: 2024-08-21 19:15:15.762948
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "2d2304e27d8c"
|
||||
down_revision = "4b08d97e175a"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column("persona", sa.Column("chunks_above", sa.Integer(), nullable=True))
|
||||
op.add_column("persona", sa.Column("chunks_below", sa.Integer(), nullable=True))
|
||||
|
||||
op.execute(
|
||||
"UPDATE persona SET chunks_above = 1, chunks_below = 1 WHERE chunks_above IS NULL AND chunks_below IS NULL"
|
||||
)
|
||||
|
||||
op.alter_column("persona", "chunks_above", nullable=False)
|
||||
op.alter_column("persona", "chunks_below", nullable=False)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("persona", "chunks_below")
|
||||
op.drop_column("persona", "chunks_above")
|
||||
@@ -1,70 +0,0 @@
|
||||
"""Add icon_color and icon_shape to Persona
|
||||
|
||||
Revision ID: 325975216eb3
|
||||
Revises: 91ffac7e65b3
|
||||
Create Date: 2024-07-24 21:29:31.784562
|
||||
|
||||
"""
|
||||
import random
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.sql import table, column, select
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "325975216eb3"
|
||||
down_revision = "91ffac7e65b3"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
colorOptions = [
|
||||
"#FF6FBF",
|
||||
"#6FB1FF",
|
||||
"#B76FFF",
|
||||
"#FFB56F",
|
||||
"#6FFF8D",
|
||||
"#FF6F6F",
|
||||
"#6FFFFF",
|
||||
]
|
||||
|
||||
|
||||
# Function to generate a random shape ensuring at least 3 of the middle 4 squares are filled
|
||||
def generate_random_shape() -> int:
|
||||
center_squares = [12, 10, 6, 14, 13, 11, 7, 15]
|
||||
center_fill = random.choice(center_squares)
|
||||
remaining_squares = [i for i in range(16) if not (center_fill & (1 << i))]
|
||||
random.shuffle(remaining_squares)
|
||||
for i in range(10 - bin(center_fill).count("1")):
|
||||
center_fill |= 1 << remaining_squares[i]
|
||||
return center_fill
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column("persona", sa.Column("icon_color", sa.String(), nullable=True))
|
||||
op.add_column("persona", sa.Column("icon_shape", sa.Integer(), nullable=True))
|
||||
op.add_column("persona", sa.Column("uploaded_image_id", sa.String(), nullable=True))
|
||||
|
||||
persona = table(
|
||||
"persona",
|
||||
column("id", sa.Integer),
|
||||
column("icon_color", sa.String),
|
||||
column("icon_shape", sa.Integer),
|
||||
)
|
||||
|
||||
conn = op.get_bind()
|
||||
personas = conn.execute(select(persona.c.id))
|
||||
|
||||
for persona_id in personas:
|
||||
random_color = random.choice(colorOptions)
|
||||
random_shape = generate_random_shape()
|
||||
conn.execute(
|
||||
persona.update()
|
||||
.where(persona.c.id == persona_id[0])
|
||||
.values(icon_color=random_color, icon_shape=random_shape)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("persona", "icon_shape")
|
||||
op.drop_column("persona", "uploaded_image_id")
|
||||
op.drop_column("persona", "icon_color")
|
||||
@@ -1,90 +0,0 @@
|
||||
"""Add curator fields
|
||||
|
||||
Revision ID: 351faebd379d
|
||||
Revises: ee3f4b47fad5
|
||||
Create Date: 2024-08-15 22:37:08.397052
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "351faebd379d"
|
||||
down_revision = "ee3f4b47fad5"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add is_curator column to User__UserGroup table
|
||||
op.add_column(
|
||||
"user__user_group",
|
||||
sa.Column("is_curator", sa.Boolean(), nullable=False, server_default="false"),
|
||||
)
|
||||
|
||||
# Use batch mode to modify the enum type
|
||||
with op.batch_alter_table("user", schema=None) as batch_op:
|
||||
batch_op.alter_column( # type: ignore[attr-defined]
|
||||
"role",
|
||||
type_=sa.Enum(
|
||||
"BASIC",
|
||||
"ADMIN",
|
||||
"CURATOR",
|
||||
"GLOBAL_CURATOR",
|
||||
name="userrole",
|
||||
native_enum=False,
|
||||
),
|
||||
existing_type=sa.Enum("BASIC", "ADMIN", name="userrole", native_enum=False),
|
||||
existing_nullable=False,
|
||||
)
|
||||
# Create the association table
|
||||
op.create_table(
|
||||
"credential__user_group",
|
||||
sa.Column("credential_id", sa.Integer(), nullable=False),
|
||||
sa.Column("user_group_id", sa.Integer(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["credential_id"],
|
||||
["credential.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_group_id"],
|
||||
["user_group.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("credential_id", "user_group_id"),
|
||||
)
|
||||
op.add_column(
|
||||
"credential",
|
||||
sa.Column(
|
||||
"curator_public", sa.Boolean(), nullable=False, server_default="false"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Update existing records to ensure they fit within the BASIC/ADMIN roles
|
||||
op.execute(
|
||||
"UPDATE \"user\" SET role = 'ADMIN' WHERE role IN ('CURATOR', 'GLOBAL_CURATOR')"
|
||||
)
|
||||
|
||||
# Remove is_curator column from User__UserGroup table
|
||||
op.drop_column("user__user_group", "is_curator")
|
||||
|
||||
with op.batch_alter_table("user", schema=None) as batch_op:
|
||||
batch_op.alter_column( # type: ignore[attr-defined]
|
||||
"role",
|
||||
type_=sa.Enum(
|
||||
"BASIC", "ADMIN", name="userrole", native_enum=False, length=20
|
||||
),
|
||||
existing_type=sa.Enum(
|
||||
"BASIC",
|
||||
"ADMIN",
|
||||
"CURATOR",
|
||||
"GLOBAL_CURATOR",
|
||||
name="userrole",
|
||||
native_enum=False,
|
||||
),
|
||||
existing_nullable=False,
|
||||
)
|
||||
# Drop the association table
|
||||
op.drop_table("credential__user_group")
|
||||
op.drop_column("credential", "curator_public")
|
||||
@@ -18,6 +18,7 @@ depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column(
|
||||
"chat_message", sa.Column("alternate_assistant_id", sa.Integer(), nullable=True)
|
||||
)
|
||||
@@ -28,8 +29,10 @@ def upgrade() -> None:
|
||||
["alternate_assistant_id"],
|
||||
["id"],
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_constraint("fk_chat_message_persona", "chat_message", type_="foreignkey")
|
||||
op.drop_column("chat_message", "alternate_assistant_id")
|
||||
|
||||
@@ -1,42 +0,0 @@
|
||||
"""Rename index_origin to index_recursively
|
||||
|
||||
Revision ID: 1d6ad76d1f37
|
||||
Revises: e1392f05e840
|
||||
Create Date: 2024-08-01 12:38:54.466081
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "1d6ad76d1f37"
|
||||
down_revision = "e1392f05e840"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE connector
|
||||
SET connector_specific_config = jsonb_set(
|
||||
connector_specific_config,
|
||||
'{index_recursively}',
|
||||
'true'::jsonb
|
||||
) - 'index_origin'
|
||||
WHERE connector_specific_config ? 'index_origin'
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE connector
|
||||
SET connector_specific_config = jsonb_set(
|
||||
connector_specific_config,
|
||||
'{index_origin}',
|
||||
connector_specific_config->'index_recursively'
|
||||
) - 'index_recursively'
|
||||
WHERE connector_specific_config ? 'index_recursively'
|
||||
"""
|
||||
)
|
||||
@@ -1,49 +0,0 @@
|
||||
"""Add display_model_names to llm_provider
|
||||
|
||||
Revision ID: 473a1a7ca408
|
||||
Revises: 325975216eb3
|
||||
Create Date: 2024-07-25 14:31:02.002917
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "473a1a7ca408"
|
||||
down_revision = "325975216eb3"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
default_models_by_provider = {
|
||||
"openai": ["gpt-4", "gpt-4o", "gpt-4o-mini"],
|
||||
"bedrock": [
|
||||
"meta.llama3-1-70b-instruct-v1:0",
|
||||
"meta.llama3-1-8b-instruct-v1:0",
|
||||
"anthropic.claude-3-opus-20240229-v1:0",
|
||||
"mistral.mistral-large-2402-v1:0",
|
||||
"anthropic.claude-3-5-sonnet-20240620-v1:0",
|
||||
],
|
||||
"anthropic": ["claude-3-opus-20240229", "claude-3-5-sonnet-20240620"],
|
||||
}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"llm_provider",
|
||||
sa.Column("display_model_names", postgresql.ARRAY(sa.String()), nullable=True),
|
||||
)
|
||||
|
||||
connection = op.get_bind()
|
||||
for provider, models in default_models_by_provider.items():
|
||||
connection.execute(
|
||||
sa.text(
|
||||
"UPDATE llm_provider SET display_model_names = :models WHERE provider = :provider"
|
||||
),
|
||||
{"models": models, "provider": provider},
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("llm_provider", "display_model_names")
|
||||
@@ -1,80 +0,0 @@
|
||||
"""Moved status to connector credential pair
|
||||
|
||||
Revision ID: 4a951134c801
|
||||
Revises: 7477a5f5d728
|
||||
Create Date: 2024-08-10 19:20:34.527559
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "4a951134c801"
|
||||
down_revision = "7477a5f5d728"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"connector_credential_pair",
|
||||
sa.Column(
|
||||
"status",
|
||||
sa.Enum(
|
||||
"ACTIVE",
|
||||
"PAUSED",
|
||||
"DELETING",
|
||||
name="connectorcredentialpairstatus",
|
||||
native_enum=False,
|
||||
),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
|
||||
# Update status of connector_credential_pair based on connector's disabled status
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE connector_credential_pair
|
||||
SET status = CASE
|
||||
WHEN (
|
||||
SELECT disabled
|
||||
FROM connector
|
||||
WHERE connector.id = connector_credential_pair.connector_id
|
||||
) = FALSE THEN 'ACTIVE'
|
||||
ELSE 'PAUSED'
|
||||
END
|
||||
"""
|
||||
)
|
||||
|
||||
# Make the status column not nullable after setting values
|
||||
op.alter_column("connector_credential_pair", "status", nullable=False)
|
||||
|
||||
op.drop_column("connector", "disabled")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.add_column(
|
||||
"connector",
|
||||
sa.Column("disabled", sa.BOOLEAN(), autoincrement=False, nullable=True),
|
||||
)
|
||||
|
||||
# Update disabled status of connector based on connector_credential_pair's status
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE connector
|
||||
SET disabled = CASE
|
||||
WHEN EXISTS (
|
||||
SELECT 1
|
||||
FROM connector_credential_pair
|
||||
WHERE connector_credential_pair.connector_id = connector.id
|
||||
AND connector_credential_pair.status = 'ACTIVE'
|
||||
) THEN FALSE
|
||||
ELSE TRUE
|
||||
END
|
||||
"""
|
||||
)
|
||||
|
||||
# Make the disabled column not nullable after setting values
|
||||
op.alter_column("connector", "disabled", nullable=False)
|
||||
|
||||
op.drop_column("connector_credential_pair", "status")
|
||||
@@ -1,34 +0,0 @@
|
||||
"""change default prune_freq
|
||||
|
||||
Revision ID: 4b08d97e175a
|
||||
Revises: d9ec13955951
|
||||
Create Date: 2024-08-20 15:28:52.993827
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "4b08d97e175a"
|
||||
down_revision = "d9ec13955951"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE connector
|
||||
SET prune_freq = 2592000
|
||||
WHERE prune_freq = 86400
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE connector
|
||||
SET prune_freq = 86400
|
||||
WHERE prune_freq = 2592000
|
||||
"""
|
||||
)
|
||||
@@ -1,72 +0,0 @@
|
||||
"""Add type to credentials
|
||||
|
||||
Revision ID: 4ea2c93919c1
|
||||
Revises: 473a1a7ca408
|
||||
Create Date: 2024-07-18 13:07:13.655895
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "4ea2c93919c1"
|
||||
down_revision = "473a1a7ca408"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add the new 'source' column to the 'credential' table
|
||||
op.add_column(
|
||||
"credential",
|
||||
sa.Column(
|
||||
"source",
|
||||
sa.String(length=100), # Use String instead of Enum
|
||||
nullable=True, # Initially allow NULL values
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"credential",
|
||||
sa.Column(
|
||||
"name",
|
||||
sa.String(),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
|
||||
# Create a temporary table that maps each credential to a single connector source.
|
||||
# This is needed because a credential can be associated with multiple connectors,
|
||||
# but we want to assign a single source to each credential.
|
||||
# We use DISTINCT ON to ensure we only get one row per credential_id.
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TEMPORARY TABLE temp_connector_credential AS
|
||||
SELECT DISTINCT ON (cc.credential_id)
|
||||
cc.credential_id,
|
||||
c.source AS connector_source
|
||||
FROM connector_credential_pair cc
|
||||
JOIN connector c ON cc.connector_id = c.id
|
||||
"""
|
||||
)
|
||||
|
||||
# Update the 'source' column in the 'credential' table
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE credential cred
|
||||
SET source = COALESCE(
|
||||
(SELECT connector_source
|
||||
FROM temp_connector_credential temp
|
||||
WHERE cred.id = temp.credential_id),
|
||||
'NOT_APPLICABLE'
|
||||
)
|
||||
"""
|
||||
)
|
||||
# If no exception was raised, alter the column
|
||||
op.alter_column("credential", "source", nullable=True) # TODO modify
|
||||
# # ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("credential", "source")
|
||||
op.drop_column("credential", "name")
|
||||
@@ -1,66 +0,0 @@
|
||||
"""Add last synced and last modified to document table
|
||||
|
||||
Revision ID: 52a219fb5233
|
||||
Revises: f17bf3b0d9f1
|
||||
Create Date: 2024-08-28 17:40:46.077470
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "52a219fb5233"
|
||||
down_revision = "f7e58d357687"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# last modified represents the last time anything needing syncing to vespa changed
|
||||
# including row metadata and the document itself. This obviously does not include
|
||||
# the last_synced column.
|
||||
op.add_column(
|
||||
"document",
|
||||
sa.Column(
|
||||
"last_modified",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=func.now(),
|
||||
),
|
||||
)
|
||||
|
||||
# last synced represents the last time this document was synced to Vespa
|
||||
op.add_column(
|
||||
"document",
|
||||
sa.Column("last_synced", sa.DateTime(timezone=True), nullable=True),
|
||||
)
|
||||
|
||||
# Set last_synced to the same value as last_modified for existing rows
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE document
|
||||
SET last_synced = last_modified
|
||||
"""
|
||||
)
|
||||
|
||||
op.create_index(
|
||||
op.f("ix_document_last_modified"),
|
||||
"document",
|
||||
["last_modified"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
op.create_index(
|
||||
op.f("ix_document_last_synced"),
|
||||
"document",
|
||||
["last_synced"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index(op.f("ix_document_last_synced"), table_name="document")
|
||||
op.drop_index(op.f("ix_document_last_modified"), table_name="document")
|
||||
op.drop_column("document", "last_synced")
|
||||
op.drop_column("document", "last_modified")
|
||||
@@ -1,25 +0,0 @@
|
||||
"""hybrid-enum
|
||||
|
||||
Revision ID: 5fc1f54cc252
|
||||
Revises: 1d6ad76d1f37
|
||||
Create Date: 2024-08-06 15:35:40.278485
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "5fc1f54cc252"
|
||||
down_revision = "1d6ad76d1f37"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.drop_column("persona", "search_type")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.add_column("persona", sa.Column("search_type", sa.String(), nullable=True))
|
||||
op.execute("UPDATE persona SET search_type = 'SEMANTIC'")
|
||||
op.alter_column("persona", "search_type", nullable=False)
|
||||
@@ -1,24 +0,0 @@
|
||||
"""Added model defaults for users
|
||||
|
||||
Revision ID: 7477a5f5d728
|
||||
Revises: 213fd978c6d8
|
||||
Create Date: 2024-08-04 19:00:04.512634
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "7477a5f5d728"
|
||||
down_revision = "213fd978c6d8"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column("user", sa.Column("default_model", sa.Text(), nullable=True))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("user", "default_model")
|
||||
@@ -28,9 +28,5 @@ def upgrade() -> None:
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.create_unique_constraint(
|
||||
"connector_credential_pair__name__key", "connector_credential_pair", ["name"]
|
||||
)
|
||||
op.alter_column(
|
||||
"connector_credential_pair", "name", existing_type=sa.String(), nullable=True
|
||||
)
|
||||
# This wasn't really required by the code either, no good reason to make it unique again
|
||||
pass
|
||||
|
||||
@@ -10,7 +10,7 @@ import sqlalchemy as sa
|
||||
|
||||
from danswer.db.models import IndexModelStatus
|
||||
from danswer.search.enums import RecencyBiasSetting
|
||||
from danswer.search.enums import SearchType
|
||||
from danswer.search.models import SearchType
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "776b3bbe9092"
|
||||
|
||||
@@ -1,41 +0,0 @@
|
||||
"""add_llm_group_permissions_control
|
||||
|
||||
Revision ID: 795b20b85b4b
|
||||
Revises: 05c07bf07c00
|
||||
Create Date: 2024-07-19 11:54:35.701558
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
revision = "795b20b85b4b"
|
||||
down_revision = "05c07bf07c00"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"llm_provider__user_group",
|
||||
sa.Column("llm_provider_id", sa.Integer(), nullable=False),
|
||||
sa.Column("user_group_id", sa.Integer(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["llm_provider_id"],
|
||||
["llm_provider.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_group_id"],
|
||||
["user_group.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("llm_provider_id", "user_group_id"),
|
||||
)
|
||||
op.add_column(
|
||||
"llm_provider",
|
||||
sa.Column("is_public", sa.Boolean(), nullable=False, server_default="true"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("llm_provider__user_group")
|
||||
op.drop_column("llm_provider", "is_public")
|
||||
@@ -1,107 +0,0 @@
|
||||
"""associate index attempts with ccpair
|
||||
|
||||
Revision ID: 8a87bd6ec550
|
||||
Revises: 4ea2c93919c1
|
||||
Create Date: 2024-07-22 15:15:52.558451
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "8a87bd6ec550"
|
||||
down_revision = "4ea2c93919c1"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add the new connector_credential_pair_id column
|
||||
op.add_column(
|
||||
"index_attempt",
|
||||
sa.Column("connector_credential_pair_id", sa.Integer(), nullable=True),
|
||||
)
|
||||
|
||||
# Create a foreign key constraint to the connector_credential_pair table
|
||||
op.create_foreign_key(
|
||||
"fk_index_attempt_connector_credential_pair_id",
|
||||
"index_attempt",
|
||||
"connector_credential_pair",
|
||||
["connector_credential_pair_id"],
|
||||
["id"],
|
||||
)
|
||||
|
||||
# Populate the new connector_credential_pair_id column using existing connector_id and credential_id
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE index_attempt ia
|
||||
SET connector_credential_pair_id = (
|
||||
SELECT id FROM connector_credential_pair ccp
|
||||
WHERE
|
||||
(ia.connector_id IS NULL OR ccp.connector_id = ia.connector_id)
|
||||
AND (ia.credential_id IS NULL OR ccp.credential_id = ia.credential_id)
|
||||
LIMIT 1
|
||||
)
|
||||
WHERE ia.connector_id IS NOT NULL OR ia.credential_id IS NOT NULL
|
||||
"""
|
||||
)
|
||||
|
||||
# For good measure
|
||||
op.execute(
|
||||
"""
|
||||
DELETE FROM index_attempt
|
||||
WHERE connector_credential_pair_id IS NULL
|
||||
"""
|
||||
)
|
||||
|
||||
# Make the new connector_credential_pair_id column non-nullable
|
||||
op.alter_column("index_attempt", "connector_credential_pair_id", nullable=False)
|
||||
|
||||
# Drop the old connector_id and credential_id columns
|
||||
op.drop_column("index_attempt", "connector_id")
|
||||
op.drop_column("index_attempt", "credential_id")
|
||||
|
||||
# Update the index to use connector_credential_pair_id
|
||||
op.create_index(
|
||||
"ix_index_attempt_latest_for_connector_credential_pair",
|
||||
"index_attempt",
|
||||
["connector_credential_pair_id", "time_created"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Add back the old connector_id and credential_id columns
|
||||
op.add_column(
|
||||
"index_attempt", sa.Column("connector_id", sa.Integer(), nullable=True)
|
||||
)
|
||||
op.add_column(
|
||||
"index_attempt", sa.Column("credential_id", sa.Integer(), nullable=True)
|
||||
)
|
||||
|
||||
# Populate the old connector_id and credential_id columns using the connector_credential_pair_id
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE index_attempt ia
|
||||
SET connector_id = ccp.connector_id, credential_id = ccp.credential_id
|
||||
FROM connector_credential_pair ccp
|
||||
WHERE ia.connector_credential_pair_id = ccp.id
|
||||
"""
|
||||
)
|
||||
|
||||
# Make the old connector_id and credential_id columns non-nullable
|
||||
op.alter_column("index_attempt", "connector_id", nullable=False)
|
||||
op.alter_column("index_attempt", "credential_id", nullable=False)
|
||||
|
||||
# Drop the new connector_credential_pair_id column
|
||||
op.drop_constraint(
|
||||
"fk_index_attempt_connector_credential_pair_id",
|
||||
"index_attempt",
|
||||
type_="foreignkey",
|
||||
)
|
||||
op.drop_column("index_attempt", "connector_credential_pair_id")
|
||||
|
||||
op.create_index(
|
||||
"ix_index_attempt_latest_for_connector_credential_pair",
|
||||
"index_attempt",
|
||||
["connector_id", "credential_id", "time_created"],
|
||||
)
|
||||
@@ -1,26 +0,0 @@
|
||||
"""add expiry time
|
||||
|
||||
Revision ID: 91ffac7e65b3
|
||||
Revises: bc9771dccadf
|
||||
Create Date: 2024-06-24 09:39:56.462242
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "91ffac7e65b3"
|
||||
down_revision = "795b20b85b4b"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"user", sa.Column("oidc_expiry", sa.DateTime(timezone=True), nullable=True)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("user", "oidc_expiry")
|
||||
@@ -1,158 +0,0 @@
|
||||
"""migration confluence to be explicit
|
||||
|
||||
Revision ID: a3795dce87be
|
||||
Revises: 1f60f60c3401
|
||||
Create Date: 2024-09-01 13:52:12.006740
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from sqlalchemy.sql import table, column
|
||||
|
||||
revision = "a3795dce87be"
|
||||
down_revision = "1f60f60c3401"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def extract_confluence_keys_from_url(wiki_url: str) -> tuple[str, str, str, bool]:
|
||||
from urllib.parse import urlparse
|
||||
|
||||
def _extract_confluence_keys_from_cloud_url(wiki_url: str) -> tuple[str, str, str]:
|
||||
parsed_url = urlparse(wiki_url)
|
||||
wiki_base = f"{parsed_url.scheme}://{parsed_url.netloc}{parsed_url.path.split('/spaces')[0]}"
|
||||
path_parts = parsed_url.path.split("/")
|
||||
space = path_parts[3]
|
||||
page_id = path_parts[5] if len(path_parts) > 5 else ""
|
||||
return wiki_base, space, page_id
|
||||
|
||||
def _extract_confluence_keys_from_datacenter_url(
|
||||
wiki_url: str,
|
||||
) -> tuple[str, str, str]:
|
||||
DISPLAY = "/display/"
|
||||
PAGE = "/pages/"
|
||||
parsed_url = urlparse(wiki_url)
|
||||
wiki_base = f"{parsed_url.scheme}://{parsed_url.netloc}{parsed_url.path.split(DISPLAY)[0]}"
|
||||
space = DISPLAY.join(parsed_url.path.split(DISPLAY)[1:]).split("/")[0]
|
||||
page_id = ""
|
||||
if (content := parsed_url.path.split(PAGE)) and len(content) > 1:
|
||||
page_id = content[1]
|
||||
return wiki_base, space, page_id
|
||||
|
||||
is_confluence_cloud = (
|
||||
".atlassian.net/wiki/spaces/" in wiki_url
|
||||
or ".jira.com/wiki/spaces/" in wiki_url
|
||||
)
|
||||
|
||||
if is_confluence_cloud:
|
||||
wiki_base, space, page_id = _extract_confluence_keys_from_cloud_url(wiki_url)
|
||||
else:
|
||||
wiki_base, space, page_id = _extract_confluence_keys_from_datacenter_url(
|
||||
wiki_url
|
||||
)
|
||||
|
||||
return wiki_base, space, page_id, is_confluence_cloud
|
||||
|
||||
|
||||
def reconstruct_confluence_url(
|
||||
wiki_base: str, space: str, page_id: str, is_cloud: bool
|
||||
) -> str:
|
||||
if is_cloud:
|
||||
url = f"{wiki_base}/spaces/{space}"
|
||||
if page_id:
|
||||
url += f"/pages/{page_id}"
|
||||
else:
|
||||
url = f"{wiki_base}/display/{space}"
|
||||
if page_id:
|
||||
url += f"/pages/{page_id}"
|
||||
return url
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
connector = table(
|
||||
"connector",
|
||||
column("id", sa.Integer),
|
||||
column("source", sa.String()),
|
||||
column("input_type", sa.String()),
|
||||
column("connector_specific_config", postgresql.JSONB),
|
||||
)
|
||||
|
||||
# Fetch all Confluence connectors
|
||||
connection = op.get_bind()
|
||||
confluence_connectors = connection.execute(
|
||||
sa.select(connector).where(
|
||||
sa.and_(
|
||||
connector.c.source == "CONFLUENCE", connector.c.input_type == "POLL"
|
||||
)
|
||||
)
|
||||
).fetchall()
|
||||
|
||||
for row in confluence_connectors:
|
||||
config = row.connector_specific_config
|
||||
wiki_page_url = config["wiki_page_url"]
|
||||
wiki_base, space, page_id, is_cloud = extract_confluence_keys_from_url(
|
||||
wiki_page_url
|
||||
)
|
||||
|
||||
new_config = {
|
||||
"wiki_base": wiki_base,
|
||||
"space": space,
|
||||
"page_id": page_id,
|
||||
"is_cloud": is_cloud,
|
||||
}
|
||||
|
||||
for key, value in config.items():
|
||||
if key not in ["wiki_page_url"]:
|
||||
new_config[key] = value
|
||||
|
||||
op.execute(
|
||||
connector.update()
|
||||
.where(connector.c.id == row.id)
|
||||
.values(connector_specific_config=new_config)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
connector = table(
|
||||
"connector",
|
||||
column("id", sa.Integer),
|
||||
column("source", sa.String()),
|
||||
column("input_type", sa.String()),
|
||||
column("connector_specific_config", postgresql.JSONB),
|
||||
)
|
||||
|
||||
confluence_connectors = (
|
||||
op.get_bind()
|
||||
.execute(
|
||||
sa.select(connector).where(
|
||||
connector.c.source == "CONFLUENCE", connector.c.input_type == "POLL"
|
||||
)
|
||||
)
|
||||
.fetchall()
|
||||
)
|
||||
|
||||
for row in confluence_connectors:
|
||||
config = row.connector_specific_config
|
||||
if all(key in config for key in ["wiki_base", "space", "is_cloud"]):
|
||||
wiki_page_url = reconstruct_confluence_url(
|
||||
config["wiki_base"],
|
||||
config["space"],
|
||||
config.get("page_id", ""),
|
||||
config["is_cloud"],
|
||||
)
|
||||
|
||||
new_config = {"wiki_page_url": wiki_page_url}
|
||||
new_config.update(
|
||||
{
|
||||
k: v
|
||||
for k, v in config.items()
|
||||
if k not in ["wiki_base", "space", "page_id", "is_cloud"]
|
||||
}
|
||||
)
|
||||
|
||||
op.execute(
|
||||
connector.update()
|
||||
.where(connector.c.id == row.id)
|
||||
.values(connector_specific_config=new_config)
|
||||
)
|
||||
@@ -16,6 +16,7 @@ depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.alter_column(
|
||||
"connector_credential_pair",
|
||||
"last_attempt_status",
|
||||
@@ -28,9 +29,11 @@ def upgrade() -> None:
|
||||
),
|
||||
nullable=True,
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.alter_column(
|
||||
"connector_credential_pair",
|
||||
"last_attempt_status",
|
||||
@@ -43,3 +46,4 @@ def downgrade() -> None:
|
||||
),
|
||||
nullable=False,
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
@@ -1,26 +0,0 @@
|
||||
"""add support for litellm proxy in reranking
|
||||
|
||||
Revision ID: ba98eba0f66a
|
||||
Revises: bceb1e139447
|
||||
Create Date: 2024-09-06 10:36:04.507332
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "ba98eba0f66a"
|
||||
down_revision = "bceb1e139447"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"search_settings", sa.Column("rerank_api_url", sa.String(), nullable=True)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("search_settings", "rerank_api_url")
|
||||
@@ -1,26 +0,0 @@
|
||||
"""Add base_url to CloudEmbeddingProvider
|
||||
|
||||
Revision ID: bceb1e139447
|
||||
Revises: a3795dce87be
|
||||
Create Date: 2024-08-28 17:00:52.554580
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "bceb1e139447"
|
||||
down_revision = "a3795dce87be"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"embedding_provider", sa.Column("api_url", sa.String(), nullable=True)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("embedding_provider", "api_url")
|
||||
@@ -1,57 +0,0 @@
|
||||
"""Add index_attempt_errors table
|
||||
|
||||
Revision ID: c5b692fa265c
|
||||
Revises: 4a951134c801
|
||||
Create Date: 2024-08-08 14:06:39.581972
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "c5b692fa265c"
|
||||
down_revision = "4a951134c801"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"index_attempt_errors",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("index_attempt_id", sa.Integer(), nullable=True),
|
||||
sa.Column("batch", sa.Integer(), nullable=True),
|
||||
sa.Column(
|
||||
"doc_summaries",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("error_msg", sa.Text(), nullable=True),
|
||||
sa.Column("traceback", sa.Text(), nullable=True),
|
||||
sa.Column(
|
||||
"time_created",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["index_attempt_id"],
|
||||
["index_attempt.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
"index_attempt_id",
|
||||
"index_attempt_errors",
|
||||
["time_created"],
|
||||
unique=False,
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index("index_attempt_id", table_name="index_attempt_errors")
|
||||
op.drop_table("index_attempt_errors")
|
||||
# ### end Alembic commands ###
|
||||
@@ -19,9 +19,6 @@ depends_on: None = None
|
||||
def upgrade() -> None:
|
||||
op.drop_table("deletion_attempt")
|
||||
|
||||
# Remove the DeletionStatus enum
|
||||
op.execute("DROP TYPE IF EXISTS deletionstatus;")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.create_table(
|
||||
|
||||
@@ -1,31 +0,0 @@
|
||||
"""Remove _alt suffix from model_name
|
||||
|
||||
Revision ID: d9ec13955951
|
||||
Revises: da4c21c69164
|
||||
Create Date: 2024-08-20 16:31:32.955686
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "d9ec13955951"
|
||||
down_revision = "da4c21c69164"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE embedding_model
|
||||
SET model_name = regexp_replace(model_name, '__danswer_alt_index$', '')
|
||||
WHERE model_name LIKE '%__danswer_alt_index'
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# We can't reliably add the __danswer_alt_index suffix back, so we'll leave this empty
|
||||
pass
|
||||
@@ -1,65 +0,0 @@
|
||||
"""chosen_assistants changed to jsonb
|
||||
|
||||
Revision ID: da4c21c69164
|
||||
Revises: c5b692fa265c
|
||||
Create Date: 2024-08-18 19:06:47.291491
|
||||
|
||||
"""
|
||||
import json
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "da4c21c69164"
|
||||
down_revision = "c5b692fa265c"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
existing_ids_and_chosen_assistants = conn.execute(
|
||||
sa.text("select id, chosen_assistants from public.user")
|
||||
)
|
||||
op.drop_column(
|
||||
"user",
|
||||
"chosen_assistants",
|
||||
)
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"chosen_assistants",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
for id, chosen_assistants in existing_ids_and_chosen_assistants:
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"update public.user set chosen_assistants = :chosen_assistants where id = :id"
|
||||
),
|
||||
{"chosen_assistants": json.dumps(chosen_assistants), "id": id},
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
existing_ids_and_chosen_assistants = conn.execute(
|
||||
sa.text("select id, chosen_assistants from public.user")
|
||||
)
|
||||
op.drop_column(
|
||||
"user",
|
||||
"chosen_assistants",
|
||||
)
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column("chosen_assistants", postgresql.ARRAY(sa.Integer()), nullable=True),
|
||||
)
|
||||
for id, chosen_assistants in existing_ids_and_chosen_assistants:
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"update public.user set chosen_assistants = :chosen_assistants where id = :id"
|
||||
),
|
||||
{"chosen_assistants": chosen_assistants, "id": id},
|
||||
)
|
||||
@@ -9,7 +9,7 @@ from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import table, column, String, Integer, Boolean
|
||||
|
||||
from danswer.db.search_settings import (
|
||||
from danswer.db.embedding_model import (
|
||||
get_new_default_embedding_model,
|
||||
get_old_default_embedding_model,
|
||||
user_has_overridden_embedding_model,
|
||||
@@ -71,14 +71,14 @@ def upgrade() -> None:
|
||||
"query_prefix": old_embedding_model.query_prefix,
|
||||
"passage_prefix": old_embedding_model.passage_prefix,
|
||||
"index_name": old_embedding_model.index_name,
|
||||
"status": IndexModelStatus.PRESENT,
|
||||
"status": old_embedding_model.status,
|
||||
}
|
||||
],
|
||||
)
|
||||
# if the user has not overridden the default embedding model via env variables,
|
||||
# insert the new default model into the database to auto-upgrade them
|
||||
if not user_has_overridden_embedding_model():
|
||||
new_embedding_model = get_new_default_embedding_model()
|
||||
new_embedding_model = get_new_default_embedding_model(is_present=False)
|
||||
op.bulk_insert(
|
||||
EmbeddingModel,
|
||||
[
|
||||
@@ -136,4 +136,4 @@ def downgrade() -> None:
|
||||
)
|
||||
op.drop_column("index_attempt", "embedding_model_id")
|
||||
op.drop_table("embedding_model")
|
||||
op.execute("DROP TYPE IF EXISTS indexmodelstatus;")
|
||||
op.execute("DROP TYPE indexmodelstatus;")
|
||||
|
||||
@@ -1,58 +0,0 @@
|
||||
"""Added input prompts
|
||||
|
||||
Revision ID: e1392f05e840
|
||||
Revises: 08a1eda20fe1
|
||||
Create Date: 2024-07-13 19:09:22.556224
|
||||
|
||||
"""
|
||||
|
||||
import fastapi_users_db_sqlalchemy
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "e1392f05e840"
|
||||
down_revision = "08a1eda20fe1"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"inputprompt",
|
||||
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
|
||||
sa.Column("prompt", sa.String(), nullable=False),
|
||||
sa.Column("content", sa.String(), nullable=False),
|
||||
sa.Column("active", sa.Boolean(), nullable=False),
|
||||
sa.Column("is_public", sa.Boolean(), nullable=False),
|
||||
sa.Column(
|
||||
"user_id",
|
||||
fastapi_users_db_sqlalchemy.generics.GUID(),
|
||||
nullable=True,
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["user.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_table(
|
||||
"inputprompt__user",
|
||||
sa.Column("input_prompt_id", sa.Integer(), nullable=False),
|
||||
sa.Column("user_id", sa.Integer(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["input_prompt_id"],
|
||||
["inputprompt.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["inputprompt.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("input_prompt_id", "user_id"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("inputprompt__user")
|
||||
op.drop_table("inputprompt")
|
||||
@@ -1,28 +0,0 @@
|
||||
"""Added alternate model to chat message
|
||||
|
||||
Revision ID: ee3f4b47fad5
|
||||
Revises: 2d2304e27d8c
|
||||
Create Date: 2024-08-12 00:11:50.915845
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "ee3f4b47fad5"
|
||||
down_revision = "2d2304e27d8c"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"chat_message",
|
||||
sa.Column("overridden_model", sa.String(length=255), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("chat_message", "overridden_model")
|
||||
@@ -1,30 +0,0 @@
|
||||
"""standard answer match_regex flag
|
||||
|
||||
Revision ID: efb35676026c
|
||||
Revises: 52a219fb5233
|
||||
Create Date: 2024-09-11 13:55:46.101149
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "efb35676026c"
|
||||
down_revision = "0ebb1d516877"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column(
|
||||
"standard_answer",
|
||||
sa.Column("match_regex", sa.Boolean(), nullable=False, default=False),
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column("standard_answer", "match_regex")
|
||||
# ### end Alembic commands ###
|
||||
@@ -1,172 +0,0 @@
|
||||
"""embedding provider by provider type
|
||||
|
||||
Revision ID: f17bf3b0d9f1
|
||||
Revises: 351faebd379d
|
||||
Create Date: 2024-08-21 13:13:31.120460
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "f17bf3b0d9f1"
|
||||
down_revision = "351faebd379d"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add provider_type column to embedding_provider
|
||||
op.add_column(
|
||||
"embedding_provider",
|
||||
sa.Column("provider_type", sa.String(50), nullable=True),
|
||||
)
|
||||
|
||||
# Update provider_type with existing name values
|
||||
op.execute("UPDATE embedding_provider SET provider_type = UPPER(name)")
|
||||
|
||||
# Make provider_type not nullable
|
||||
op.alter_column("embedding_provider", "provider_type", nullable=False)
|
||||
|
||||
# Drop the foreign key constraint in embedding_model table
|
||||
op.drop_constraint(
|
||||
"fk_embedding_model_cloud_provider", "embedding_model", type_="foreignkey"
|
||||
)
|
||||
|
||||
# Drop the existing primary key constraint
|
||||
op.drop_constraint("embedding_provider_pkey", "embedding_provider", type_="primary")
|
||||
|
||||
# Create a new primary key constraint on provider_type
|
||||
op.create_primary_key(
|
||||
"embedding_provider_pkey", "embedding_provider", ["provider_type"]
|
||||
)
|
||||
|
||||
# Add provider_type column to embedding_model
|
||||
op.add_column(
|
||||
"embedding_model",
|
||||
sa.Column("provider_type", sa.String(50), nullable=True),
|
||||
)
|
||||
|
||||
# Update provider_type for existing embedding models
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE embedding_model
|
||||
SET provider_type = (
|
||||
SELECT provider_type
|
||||
FROM embedding_provider
|
||||
WHERE embedding_provider.id = embedding_model.cloud_provider_id
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
# Drop the old id column from embedding_provider
|
||||
op.drop_column("embedding_provider", "id")
|
||||
|
||||
# Drop the name column from embedding_provider
|
||||
op.drop_column("embedding_provider", "name")
|
||||
|
||||
# Drop the default_model_id column from embedding_provider
|
||||
op.drop_column("embedding_provider", "default_model_id")
|
||||
|
||||
# Drop the old cloud_provider_id column from embedding_model
|
||||
op.drop_column("embedding_model", "cloud_provider_id")
|
||||
|
||||
# Create the new foreign key constraint
|
||||
op.create_foreign_key(
|
||||
"fk_embedding_model_cloud_provider",
|
||||
"embedding_model",
|
||||
"embedding_provider",
|
||||
["provider_type"],
|
||||
["provider_type"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop the foreign key constraint in embedding_model table
|
||||
op.drop_constraint(
|
||||
"fk_embedding_model_cloud_provider", "embedding_model", type_="foreignkey"
|
||||
)
|
||||
|
||||
# Add back the cloud_provider_id column to embedding_model
|
||||
op.add_column(
|
||||
"embedding_model", sa.Column("cloud_provider_id", sa.Integer(), nullable=True)
|
||||
)
|
||||
op.add_column("embedding_provider", sa.Column("id", sa.Integer(), nullable=True))
|
||||
|
||||
# Assign incrementing IDs to embedding providers
|
||||
op.execute(
|
||||
"""
|
||||
CREATE SEQUENCE IF NOT EXISTS embedding_provider_id_seq;"""
|
||||
)
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE embedding_provider SET id = nextval('embedding_provider_id_seq');
|
||||
"""
|
||||
)
|
||||
|
||||
# Update cloud_provider_id based on provider_type
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE embedding_model
|
||||
SET cloud_provider_id = CASE
|
||||
WHEN provider_type IS NULL THEN NULL
|
||||
ELSE (
|
||||
SELECT id
|
||||
FROM embedding_provider
|
||||
WHERE embedding_provider.provider_type = embedding_model.provider_type
|
||||
)
|
||||
END
|
||||
"""
|
||||
)
|
||||
|
||||
# Drop the provider_type column from embedding_model
|
||||
op.drop_column("embedding_model", "provider_type")
|
||||
|
||||
# Add back the columns to embedding_provider
|
||||
op.add_column("embedding_provider", sa.Column("name", sa.String(50), nullable=True))
|
||||
op.add_column(
|
||||
"embedding_provider", sa.Column("default_model_id", sa.Integer(), nullable=True)
|
||||
)
|
||||
|
||||
# Drop the existing primary key constraint on provider_type
|
||||
op.drop_constraint("embedding_provider_pkey", "embedding_provider", type_="primary")
|
||||
|
||||
# Create the original primary key constraint on id
|
||||
op.create_primary_key("embedding_provider_pkey", "embedding_provider", ["id"])
|
||||
|
||||
# Update name with existing provider_type values
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE embedding_provider
|
||||
SET name = CASE
|
||||
WHEN provider_type = 'OPENAI' THEN 'OpenAI'
|
||||
WHEN provider_type = 'COHERE' THEN 'Cohere'
|
||||
WHEN provider_type = 'GOOGLE' THEN 'Google'
|
||||
WHEN provider_type = 'VOYAGE' THEN 'Voyage'
|
||||
ELSE provider_type
|
||||
END
|
||||
"""
|
||||
)
|
||||
|
||||
# Drop the provider_type column from embedding_provider
|
||||
op.drop_column("embedding_provider", "provider_type")
|
||||
|
||||
# Recreate the foreign key constraint in embedding_model table
|
||||
op.create_foreign_key(
|
||||
"fk_embedding_model_cloud_provider",
|
||||
"embedding_model",
|
||||
"embedding_provider",
|
||||
["cloud_provider_id"],
|
||||
["id"],
|
||||
)
|
||||
|
||||
# Recreate the foreign key constraint in embedding_model table
|
||||
op.create_foreign_key(
|
||||
"fk_embedding_provider_default_model",
|
||||
"embedding_provider",
|
||||
"embedding_model",
|
||||
["default_model_id"],
|
||||
["id"],
|
||||
)
|
||||
@@ -1,26 +0,0 @@
|
||||
"""add has_web_login column to user
|
||||
|
||||
Revision ID: f7e58d357687
|
||||
Revises: bceb1e139447
|
||||
Create Date: 2024-09-07 20:20:54.522620
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "f7e58d357687"
|
||||
down_revision = "ba98eba0f66a"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column("has_web_login", sa.Boolean(), nullable=False, server_default="true"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("user", "has_web_login")
|
||||
@@ -3,49 +3,24 @@ from sqlalchemy.orm import Session
|
||||
from danswer.access.models import DocumentAccess
|
||||
from danswer.access.utils import prefix_user
|
||||
from danswer.configs.constants import PUBLIC_DOC_PAT
|
||||
from danswer.db.document import get_access_info_for_document
|
||||
from danswer.db.document import get_access_info_for_documents
|
||||
from danswer.db.document import get_acccess_info_for_documents
|
||||
from danswer.db.models import User
|
||||
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
|
||||
|
||||
def _get_access_for_document(
|
||||
document_id: str,
|
||||
db_session: Session,
|
||||
) -> DocumentAccess:
|
||||
info = get_access_info_for_document(
|
||||
db_session=db_session,
|
||||
document_id=document_id,
|
||||
)
|
||||
|
||||
if not info:
|
||||
return DocumentAccess.build(user_ids=[], user_groups=[], is_public=False)
|
||||
|
||||
return DocumentAccess.build(user_ids=info[1], user_groups=[], is_public=info[2])
|
||||
|
||||
|
||||
def get_access_for_document(
|
||||
document_id: str,
|
||||
db_session: Session,
|
||||
) -> DocumentAccess:
|
||||
versioned_get_access_for_document_fn = fetch_versioned_implementation(
|
||||
"danswer.access.access", "_get_access_for_document"
|
||||
)
|
||||
return versioned_get_access_for_document_fn(document_id, db_session) # type: ignore
|
||||
|
||||
|
||||
def _get_access_for_documents(
|
||||
document_ids: list[str],
|
||||
db_session: Session,
|
||||
cc_pair_to_delete: ConnectorCredentialPairIdentifier | None = None,
|
||||
) -> dict[str, DocumentAccess]:
|
||||
document_access_info = get_access_info_for_documents(
|
||||
document_access_info = get_acccess_info_for_documents(
|
||||
db_session=db_session,
|
||||
document_ids=document_ids,
|
||||
cc_pair_to_delete=cc_pair_to_delete,
|
||||
)
|
||||
return {
|
||||
document_id: DocumentAccess.build(
|
||||
user_ids=user_ids, user_groups=[], is_public=is_public
|
||||
)
|
||||
document_id: DocumentAccess.build(user_ids, [], is_public)
|
||||
for document_id, user_ids, is_public in document_access_info
|
||||
}
|
||||
|
||||
@@ -53,13 +28,14 @@ def _get_access_for_documents(
|
||||
def get_access_for_documents(
|
||||
document_ids: list[str],
|
||||
db_session: Session,
|
||||
cc_pair_to_delete: ConnectorCredentialPairIdentifier | None = None,
|
||||
) -> dict[str, DocumentAccess]:
|
||||
"""Fetches all access information for the given documents."""
|
||||
versioned_get_access_for_documents_fn = fetch_versioned_implementation(
|
||||
"danswer.access.access", "_get_access_for_documents"
|
||||
)
|
||||
return versioned_get_access_for_documents_fn(
|
||||
document_ids, db_session
|
||||
document_ids, db_session, cc_pair_to_delete
|
||||
) # type: ignore
|
||||
|
||||
|
||||
|
||||
@@ -1,20 +1,21 @@
|
||||
from typing import cast
|
||||
|
||||
from danswer.configs.constants import KV_USER_STORE_KEY
|
||||
from danswer.dynamic_configs.factory import get_dynamic_config_store
|
||||
from danswer.dynamic_configs.interface import ConfigNotFoundError
|
||||
from danswer.dynamic_configs.interface import JSON_ro
|
||||
|
||||
USER_STORE_KEY = "INVITED_USERS"
|
||||
|
||||
|
||||
def get_invited_users() -> list[str]:
|
||||
try:
|
||||
store = get_dynamic_config_store()
|
||||
return cast(list, store.load(KV_USER_STORE_KEY))
|
||||
return cast(list, store.load(USER_STORE_KEY))
|
||||
except ConfigNotFoundError:
|
||||
return list()
|
||||
|
||||
|
||||
def write_invited_users(emails: list[str]) -> int:
|
||||
store = get_dynamic_config_store()
|
||||
store.store(KV_USER_STORE_KEY, cast(JSON_ro, emails))
|
||||
store.store(USER_STORE_KEY, cast(JSON_ro, emails))
|
||||
return len(emails)
|
||||
|
||||
@@ -3,27 +3,29 @@ from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from danswer.auth.schemas import UserRole
|
||||
from danswer.configs.constants import KV_NO_AUTH_USER_PREFERENCES_KEY
|
||||
from danswer.dynamic_configs.store import ConfigNotFoundError
|
||||
from danswer.dynamic_configs.store import DynamicConfigStore
|
||||
from danswer.server.manage.models import UserInfo
|
||||
from danswer.server.manage.models import UserPreferences
|
||||
|
||||
|
||||
NO_AUTH_USER_PREFERENCES_KEY = "no_auth_user_preferences"
|
||||
|
||||
|
||||
def set_no_auth_user_preferences(
|
||||
store: DynamicConfigStore, preferences: UserPreferences
|
||||
) -> None:
|
||||
store.store(KV_NO_AUTH_USER_PREFERENCES_KEY, preferences.model_dump())
|
||||
store.store(NO_AUTH_USER_PREFERENCES_KEY, preferences.dict())
|
||||
|
||||
|
||||
def load_no_auth_user_preferences(store: DynamicConfigStore) -> UserPreferences:
|
||||
try:
|
||||
preferences_data = cast(
|
||||
Mapping[str, Any], store.load(KV_NO_AUTH_USER_PREFERENCES_KEY)
|
||||
Mapping[str, Any], store.load(NO_AUTH_USER_PREFERENCES_KEY)
|
||||
)
|
||||
return UserPreferences(**preferences_data)
|
||||
except ConfigNotFoundError:
|
||||
return UserPreferences(chosen_assistants=None, default_model=None)
|
||||
return UserPreferences(chosen_assistants=None)
|
||||
|
||||
|
||||
def fetch_no_auth_user(store: DynamicConfigStore) -> UserInfo:
|
||||
|
||||
@@ -5,20 +5,8 @@ from fastapi_users import schemas
|
||||
|
||||
|
||||
class UserRole(str, Enum):
|
||||
"""
|
||||
User roles
|
||||
- Basic can't perform any admin actions
|
||||
- Admin can perform all admin actions
|
||||
- Curator can perform admin actions for
|
||||
groups they are curators of
|
||||
- Global Curator can perform admin actions
|
||||
for all groups they are a member of
|
||||
"""
|
||||
|
||||
BASIC = "basic"
|
||||
ADMIN = "admin"
|
||||
CURATOR = "curator"
|
||||
GLOBAL_CURATOR = "global_curator"
|
||||
|
||||
|
||||
class UserStatus(str, Enum):
|
||||
@@ -33,9 +21,7 @@ class UserRead(schemas.BaseUser[uuid.UUID]):
|
||||
|
||||
class UserCreate(schemas.BaseUserCreate):
|
||||
role: UserRole = UserRole.BASIC
|
||||
has_web_login: bool | None = True
|
||||
|
||||
|
||||
class UserUpdate(schemas.BaseUserUpdate):
|
||||
role: UserRole
|
||||
has_web_login: bool | None = True
|
||||
|
||||
@@ -1,24 +1,18 @@
|
||||
import smtplib
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from email.mime.text import MIMEText
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
|
||||
from email_validator import EmailNotValidError
|
||||
from email_validator import validate_email
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
from fastapi import Response
|
||||
from fastapi import status
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from fastapi_users import BaseUserManager
|
||||
from fastapi_users import exceptions
|
||||
from fastapi_users import FastAPIUsers
|
||||
from fastapi_users import models
|
||||
from fastapi_users import schemas
|
||||
@@ -35,7 +29,6 @@ from sqlalchemy.orm import Session
|
||||
from danswer.auth.invited_users import get_invited_users
|
||||
from danswer.auth.schemas import UserCreate
|
||||
from danswer.auth.schemas import UserRole
|
||||
from danswer.auth.schemas import UserUpdate
|
||||
from danswer.configs.app_configs import AUTH_TYPE
|
||||
from danswer.configs.app_configs import DISABLE_AUTH
|
||||
from danswer.configs.app_configs import EMAIL_FROM
|
||||
@@ -45,7 +38,6 @@ from danswer.configs.app_configs import SMTP_PASS
|
||||
from danswer.configs.app_configs import SMTP_PORT
|
||||
from danswer.configs.app_configs import SMTP_SERVER
|
||||
from danswer.configs.app_configs import SMTP_USER
|
||||
from danswer.configs.app_configs import TRACK_EXTERNAL_IDP_EXPIRY
|
||||
from danswer.configs.app_configs import USER_AUTH_SECRET
|
||||
from danswer.configs.app_configs import VALID_EMAIL_DOMAINS
|
||||
from danswer.configs.app_configs import WEB_DOMAIN
|
||||
@@ -58,33 +50,26 @@ from danswer.db.auth import get_default_admin_user_emails
|
||||
from danswer.db.auth import get_user_count
|
||||
from danswer.db.auth import get_user_db
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.models import AccessToken
|
||||
from danswer.db.models import User
|
||||
from danswer.db.users import get_user_by_email
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.telemetry import optional_telemetry
|
||||
from danswer.utils.telemetry import RecordType
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
from danswer.utils.variable_functionality import (
|
||||
fetch_versioned_implementation,
|
||||
)
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def is_user_admin(user: User | None) -> bool:
|
||||
if AUTH_TYPE == AuthType.DISABLED:
|
||||
return True
|
||||
if user and user.role == UserRole.ADMIN:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def verify_auth_setting() -> None:
|
||||
if AUTH_TYPE not in [AuthType.DISABLED, AuthType.BASIC, AuthType.GOOGLE_OAUTH]:
|
||||
raise ValueError(
|
||||
"User must choose a valid user authentication method: "
|
||||
"disabled, basic, or google_oauth"
|
||||
)
|
||||
logger.notice(f"Using Auth Type: {AUTH_TYPE.value}")
|
||||
logger.info(f"Using Auth Type: {AUTH_TYPE.value}")
|
||||
|
||||
|
||||
def get_display_email(email: str | None, space_less: bool = False) -> str:
|
||||
@@ -107,36 +92,10 @@ def user_needs_to_be_verified() -> bool:
|
||||
return AUTH_TYPE != AuthType.BASIC or REQUIRE_EMAIL_VERIFICATION
|
||||
|
||||
|
||||
def verify_email_is_invited(email: str) -> None:
|
||||
whitelist = get_invited_users()
|
||||
if not whitelist:
|
||||
return
|
||||
|
||||
if not email:
|
||||
raise PermissionError("Email must be specified")
|
||||
|
||||
email_info = validate_email(email) # can raise EmailNotValidError
|
||||
|
||||
for email_whitelist in whitelist:
|
||||
try:
|
||||
# normalized emails are now being inserted into the db
|
||||
# we can remove this normalization on read after some time has passed
|
||||
email_info_whitelist = validate_email(email_whitelist)
|
||||
except EmailNotValidError:
|
||||
continue
|
||||
|
||||
# oddly, normalization does not include lowercasing the user part of the
|
||||
# email address ... which we want to allow
|
||||
if email_info.normalized.lower() == email_info_whitelist.normalized.lower():
|
||||
return
|
||||
|
||||
raise PermissionError("User not on allowed user whitelist")
|
||||
|
||||
|
||||
def verify_email_in_whitelist(email: str) -> None:
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
if not get_user_by_email(email, db_session):
|
||||
verify_email_is_invited(email)
|
||||
whitelist = get_invited_users()
|
||||
if (whitelist and email not in whitelist) or not email:
|
||||
raise PermissionError("User not on allowed user whitelist")
|
||||
|
||||
|
||||
def verify_email_domain(email: str) -> None:
|
||||
@@ -187,8 +146,8 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
user_create: schemas.UC | UserCreate,
|
||||
safe: bool = False,
|
||||
request: Optional[Request] = None,
|
||||
) -> User:
|
||||
verify_email_is_invited(user_create.email)
|
||||
) -> models.UP:
|
||||
verify_email_in_whitelist(user_create.email)
|
||||
verify_email_domain(user_create.email)
|
||||
if hasattr(user_create, "role"):
|
||||
user_count = await get_user_count()
|
||||
@@ -196,27 +155,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
user_create.role = UserRole.ADMIN
|
||||
else:
|
||||
user_create.role = UserRole.BASIC
|
||||
user = None
|
||||
try:
|
||||
user = await super().create(user_create, safe=safe, request=request) # type: ignore
|
||||
except exceptions.UserAlreadyExists:
|
||||
user = await self.get_by_email(user_create.email)
|
||||
# Handle case where user has used product outside of web and is now creating an account through web
|
||||
if (
|
||||
not user.has_web_login
|
||||
and hasattr(user_create, "has_web_login")
|
||||
and user_create.has_web_login
|
||||
):
|
||||
user_update = UserUpdate(
|
||||
password=user_create.password,
|
||||
has_web_login=True,
|
||||
role=user_create.role,
|
||||
is_verified=user_create.is_verified,
|
||||
)
|
||||
user = await self.update(user_update, user)
|
||||
else:
|
||||
raise exceptions.UserAlreadyExists()
|
||||
return user
|
||||
return await super().create(user_create, safe=safe, request=request) # type: ignore
|
||||
|
||||
async def oauth_callback(
|
||||
self: "BaseUserManager[models.UOAP, models.ID]",
|
||||
@@ -234,7 +173,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
verify_email_in_whitelist(account_email)
|
||||
verify_email_domain(account_email)
|
||||
|
||||
user = await super().oauth_callback( # type: ignore
|
||||
return await super().oauth_callback( # type: ignore
|
||||
oauth_name=oauth_name,
|
||||
access_token=access_token,
|
||||
account_id=account_id,
|
||||
@@ -246,35 +185,10 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
is_verified_by_default=is_verified_by_default,
|
||||
)
|
||||
|
||||
# NOTE: Most IdPs have very short expiry times, and we don't want to force the user to
|
||||
# re-authenticate that frequently, so by default this is disabled
|
||||
if expires_at and TRACK_EXTERNAL_IDP_EXPIRY:
|
||||
oidc_expiry = datetime.fromtimestamp(expires_at, tz=timezone.utc)
|
||||
await self.user_db.update(user, update_dict={"oidc_expiry": oidc_expiry})
|
||||
|
||||
# this is needed if an organization goes from `TRACK_EXTERNAL_IDP_EXPIRY=true` to `false`
|
||||
# otherwise, the oidc expiry will always be old, and the user will never be able to login
|
||||
if user.oidc_expiry and not TRACK_EXTERNAL_IDP_EXPIRY:
|
||||
await self.user_db.update(user, update_dict={"oidc_expiry": None})
|
||||
|
||||
# Handle case where user has used product outside of web and is now creating an account through web
|
||||
if not user.has_web_login:
|
||||
await self.user_db.update(
|
||||
user,
|
||||
update_dict={
|
||||
"is_verified": is_verified_by_default,
|
||||
"has_web_login": True,
|
||||
},
|
||||
)
|
||||
user.is_verified = is_verified_by_default
|
||||
user.has_web_login = True
|
||||
|
||||
return user
|
||||
|
||||
async def on_after_register(
|
||||
self, user: User, request: Optional[Request] = None
|
||||
) -> None:
|
||||
logger.notice(f"User {user.id} has registered.")
|
||||
logger.info(f"User {user.id} has registered.")
|
||||
optional_telemetry(
|
||||
record_type=RecordType.SIGN_UP,
|
||||
data={"action": "create"},
|
||||
@@ -284,35 +198,19 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
async def on_after_forgot_password(
|
||||
self, user: User, token: str, request: Optional[Request] = None
|
||||
) -> None:
|
||||
logger.notice(f"User {user.id} has forgot their password. Reset token: {token}")
|
||||
logger.info(f"User {user.id} has forgot their password. Reset token: {token}")
|
||||
|
||||
async def on_after_request_verify(
|
||||
self, user: User, token: str, request: Optional[Request] = None
|
||||
) -> None:
|
||||
verify_email_domain(user.email)
|
||||
|
||||
logger.notice(
|
||||
logger.info(
|
||||
f"Verification requested for user {user.id}. Verification token: {token}"
|
||||
)
|
||||
|
||||
send_user_verification_email(user.email, token)
|
||||
|
||||
async def authenticate(
|
||||
self, credentials: OAuth2PasswordRequestForm
|
||||
) -> Optional[User]:
|
||||
user = await super().authenticate(credentials)
|
||||
if user is None:
|
||||
try:
|
||||
user = await self.get_by_email(credentials.username)
|
||||
if not user.has_web_login:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="NO_WEB_LOGIN_AND_HAS_NO_PASSWORD",
|
||||
)
|
||||
except exceptions.UserNotExists:
|
||||
pass
|
||||
return user
|
||||
|
||||
|
||||
async def get_user_manager(
|
||||
user_db: SQLAlchemyUserDatabase = Depends(get_user_db),
|
||||
@@ -329,12 +227,10 @@ cookie_transport = CookieTransport(
|
||||
def get_database_strategy(
|
||||
access_token_db: AccessTokenDatabase[AccessToken] = Depends(get_access_token_db),
|
||||
) -> DatabaseStrategy:
|
||||
strategy = DatabaseStrategy(
|
||||
return DatabaseStrategy(
|
||||
access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS # type: ignore
|
||||
)
|
||||
|
||||
return strategy
|
||||
|
||||
|
||||
auth_backend = AuthenticationBackend(
|
||||
name="database",
|
||||
@@ -415,7 +311,6 @@ async def optional_user(
|
||||
async def double_check_user(
|
||||
user: User | None,
|
||||
optional: bool = DISABLE_AUTH,
|
||||
include_expired: bool = False,
|
||||
) -> User | None:
|
||||
if optional:
|
||||
return None
|
||||
@@ -432,53 +327,15 @@ async def double_check_user(
|
||||
detail="Access denied. User is not verified.",
|
||||
)
|
||||
|
||||
if (
|
||||
user.oidc_expiry
|
||||
and user.oidc_expiry < datetime.now(timezone.utc)
|
||||
and not include_expired
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied. User's OIDC token has expired.",
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
async def current_user_with_expired_token(
|
||||
user: User | None = Depends(optional_user),
|
||||
) -> User | None:
|
||||
return await double_check_user(user, include_expired=True)
|
||||
|
||||
|
||||
async def current_user(
|
||||
user: User | None = Depends(optional_user),
|
||||
) -> User | None:
|
||||
return await double_check_user(user)
|
||||
|
||||
|
||||
async def current_curator_or_admin_user(
|
||||
user: User | None = Depends(current_user),
|
||||
) -> User | None:
|
||||
if DISABLE_AUTH:
|
||||
return None
|
||||
|
||||
if not user or not hasattr(user, "role"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied. User is not authenticated or lacks role information.",
|
||||
)
|
||||
|
||||
allowed_roles = {UserRole.GLOBAL_CURATOR, UserRole.CURATOR, UserRole.ADMIN}
|
||||
if user.role not in allowed_roles:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied. User is not a curator or admin.",
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
async def current_admin_user(user: User | None = Depends(current_user)) -> User | None:
|
||||
if DISABLE_AUTH:
|
||||
return None
|
||||
@@ -486,12 +343,6 @@ async def current_admin_user(user: User | None = Depends(current_user)) -> User
|
||||
if not user or not hasattr(user, "role") or user.role != UserRole.ADMIN:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied. User must be an admin to perform this action.",
|
||||
detail="Access denied. User is not an admin.",
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
def get_default_admin_user_emails_() -> list[str]:
|
||||
# No default seeding available for Danswer MIT
|
||||
return []
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,299 +0,0 @@
|
||||
# These are helper objects for tracking the keys we need to write in redis
|
||||
import time
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
import redis
|
||||
from celery import Celery
|
||||
from redis import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.celeryconfig import CELERY_SEPARATOR
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.document import (
|
||||
construct_document_select_for_connector_credential_pair_by_needs_sync,
|
||||
)
|
||||
from danswer.db.document_set import construct_document_select_by_docset
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
|
||||
|
||||
class RedisObjectHelper(ABC):
|
||||
PREFIX = "base"
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def __init__(self, id: int):
|
||||
self._id: int = id
|
||||
|
||||
@property
|
||||
def task_id_prefix(self) -> str:
|
||||
return f"{self.PREFIX}_{self._id}"
|
||||
|
||||
@property
|
||||
def fence_key(self) -> str:
|
||||
# example: documentset_fence_1
|
||||
return f"{self.FENCE_PREFIX}_{self._id}"
|
||||
|
||||
@property
|
||||
def taskset_key(self) -> str:
|
||||
# example: documentset_taskset_1
|
||||
return f"{self.TASKSET_PREFIX}_{self._id}"
|
||||
|
||||
@staticmethod
|
||||
def get_id_from_fence_key(key: str) -> int | None:
|
||||
"""
|
||||
Extracts the object ID from a fence key in the format `PREFIX_fence_X`.
|
||||
|
||||
Args:
|
||||
key (str): The fence key string.
|
||||
|
||||
Returns:
|
||||
Optional[int]: The extracted ID if the key is in the correct format, otherwise None.
|
||||
"""
|
||||
parts = key.split("_")
|
||||
if len(parts) != 3:
|
||||
return None
|
||||
|
||||
try:
|
||||
object_id = int(parts[2])
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
return object_id
|
||||
|
||||
@staticmethod
|
||||
def get_id_from_task_id(task_id: str) -> int | None:
|
||||
"""
|
||||
Extracts the object ID from a task ID string.
|
||||
|
||||
This method assumes the task ID is formatted as `prefix_objectid_suffix`, where:
|
||||
- `prefix` is an arbitrary string (e.g., the name of the task or entity),
|
||||
- `objectid` is the ID you want to extract,
|
||||
- `suffix` is another arbitrary string (e.g., a UUID).
|
||||
|
||||
Example:
|
||||
If the input `task_id` is `documentset_1_cbfdc96a-80ca-4312-a242-0bb68da3c1dc`,
|
||||
this method will return the string `"1"`.
|
||||
|
||||
Args:
|
||||
task_id (str): The task ID string from which to extract the object ID.
|
||||
|
||||
Returns:
|
||||
str | None: The extracted object ID if the task ID is in the correct format, otherwise None.
|
||||
"""
|
||||
# example: task_id=documentset_1_cbfdc96a-80ca-4312-a242-0bb68da3c1dc
|
||||
parts = task_id.split("_")
|
||||
if len(parts) != 3:
|
||||
return None
|
||||
|
||||
try:
|
||||
object_id = int(parts[1])
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
return object_id
|
||||
|
||||
@abstractmethod
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock,
|
||||
) -> int | None:
|
||||
pass
|
||||
|
||||
|
||||
class RedisDocumentSet(RedisObjectHelper):
|
||||
PREFIX = "documentset"
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock,
|
||||
) -> int | None:
|
||||
last_lock_time = time.monotonic()
|
||||
|
||||
async_results = []
|
||||
stmt = construct_document_select_by_docset(self._id)
|
||||
for doc in db_session.scalars(stmt).yield_per(1):
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
|
||||
):
|
||||
lock.reacquire()
|
||||
last_lock_time = current_time
|
||||
|
||||
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# we prefix the task id so it's easier to keep track of who created the task
|
||||
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
|
||||
|
||||
# add to the set BEFORE creating the task.
|
||||
redis_client.sadd(self.taskset_key, custom_task_id)
|
||||
|
||||
result = celery_app.send_task(
|
||||
"vespa_metadata_sync_task",
|
||||
kwargs=dict(document_id=doc.id),
|
||||
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.LOW,
|
||||
)
|
||||
|
||||
async_results.append(result)
|
||||
|
||||
return len(async_results)
|
||||
|
||||
|
||||
class RedisUserGroup(RedisObjectHelper):
|
||||
PREFIX = "usergroup"
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock,
|
||||
) -> int | None:
|
||||
last_lock_time = time.monotonic()
|
||||
|
||||
async_results = []
|
||||
|
||||
try:
|
||||
construct_document_select_by_usergroup = fetch_versioned_implementation(
|
||||
"danswer.db.user_group",
|
||||
"construct_document_select_by_usergroup",
|
||||
)
|
||||
except ModuleNotFoundError:
|
||||
return 0
|
||||
|
||||
stmt = construct_document_select_by_usergroup(self._id)
|
||||
for doc in db_session.scalars(stmt).yield_per(1):
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
|
||||
):
|
||||
lock.reacquire()
|
||||
last_lock_time = current_time
|
||||
|
||||
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# we prefix the task id so it's easier to keep track of who created the task
|
||||
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
|
||||
|
||||
# add to the set BEFORE creating the task.
|
||||
redis_client.sadd(self.taskset_key, custom_task_id)
|
||||
|
||||
result = celery_app.send_task(
|
||||
"vespa_metadata_sync_task",
|
||||
kwargs=dict(document_id=doc.id),
|
||||
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.LOW,
|
||||
)
|
||||
|
||||
async_results.append(result)
|
||||
|
||||
return len(async_results)
|
||||
|
||||
|
||||
class RedisConnectorCredentialPair(RedisObjectHelper):
|
||||
PREFIX = "connectorsync"
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
@classmethod
|
||||
def get_fence_key(cls) -> str:
|
||||
return RedisConnectorCredentialPair.FENCE_PREFIX
|
||||
|
||||
@classmethod
|
||||
def get_taskset_key(cls) -> str:
|
||||
return RedisConnectorCredentialPair.TASKSET_PREFIX
|
||||
|
||||
@property
|
||||
def taskset_key(self) -> str:
|
||||
"""Notice that this is intentionally reusing the same taskset for all
|
||||
connector syncs"""
|
||||
# example: connector_taskset
|
||||
return f"{self.TASKSET_PREFIX}"
|
||||
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock,
|
||||
) -> int | None:
|
||||
last_lock_time = time.monotonic()
|
||||
|
||||
async_results = []
|
||||
cc_pair = get_connector_credential_pair_from_id(self._id, db_session)
|
||||
if not cc_pair:
|
||||
return None
|
||||
|
||||
stmt = construct_document_select_for_connector_credential_pair_by_needs_sync(
|
||||
cc_pair.connector_id, cc_pair.credential_id
|
||||
)
|
||||
for doc in db_session.scalars(stmt).yield_per(1):
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
|
||||
):
|
||||
lock.reacquire()
|
||||
last_lock_time = current_time
|
||||
|
||||
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# we prefix the task id so it's easier to keep track of who created the task
|
||||
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
|
||||
|
||||
# add to the tracking taskset in redis BEFORE creating the celery task.
|
||||
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
|
||||
redis_client.sadd(
|
||||
RedisConnectorCredentialPair.get_taskset_key(), custom_task_id
|
||||
)
|
||||
|
||||
# Priority on sync's triggered by new indexing should be medium
|
||||
result = celery_app.send_task(
|
||||
"vespa_metadata_sync_task",
|
||||
kwargs=dict(document_id=doc.id),
|
||||
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.MEDIUM,
|
||||
)
|
||||
|
||||
async_results.append(result)
|
||||
|
||||
return len(async_results)
|
||||
|
||||
|
||||
def celery_get_queue_length(queue: str, r: Redis) -> int:
|
||||
"""This is a redis specific way to get the length of a celery queue.
|
||||
It is priority aware and knows how to count across the multiple redis lists
|
||||
used to implement task prioritization.
|
||||
This operation is not atomic."""
|
||||
total_length = 0
|
||||
for i in range(len(DanswerCeleryPriority)):
|
||||
queue_name = queue
|
||||
if i > 0:
|
||||
queue_name += CELERY_SEPARATOR
|
||||
queue_name += str(i)
|
||||
|
||||
length = r.llen(queue_name)
|
||||
total_length += cast(int, length)
|
||||
|
||||
return total_length
|
||||
@@ -5,8 +5,9 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.task_utils import name_cc_cleanup_task
|
||||
from danswer.background.task_utils import name_cc_prune_task
|
||||
from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
|
||||
from danswer.background.task_utils import name_document_set_sync_task
|
||||
from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
|
||||
from danswer.configs.app_configs import PREVENT_SIMULTANEOUS_PRUNING
|
||||
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||
rate_limit_builder,
|
||||
)
|
||||
@@ -15,13 +16,10 @@ from danswer.connectors.interfaces import IdConnector
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
|
||||
from danswer.db.engine import get_db_current_time
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.models import Connector
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.models import Credential
|
||||
from danswer.db.models import TaskQueueState
|
||||
from danswer.db.models import DocumentSet
|
||||
from danswer.db.tasks import check_task_is_live_and_not_timed_out
|
||||
from danswer.db.tasks import get_latest_task
|
||||
from danswer.db.tasks import get_latest_task_by_type
|
||||
@@ -31,51 +29,36 @@ from danswer.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _get_deletion_status(
|
||||
def get_deletion_status(
|
||||
connector_id: int, credential_id: int, db_session: Session
|
||||
) -> TaskQueueState | None:
|
||||
) -> DeletionAttemptSnapshot | None:
|
||||
cleanup_task_name = name_cc_cleanup_task(
|
||||
connector_id=connector_id, credential_id=credential_id
|
||||
)
|
||||
return get_latest_task(task_name=cleanup_task_name, db_session=db_session)
|
||||
task_state = get_latest_task(task_name=cleanup_task_name, db_session=db_session)
|
||||
|
||||
|
||||
def get_deletion_attempt_snapshot(
|
||||
connector_id: int, credential_id: int, db_session: Session
|
||||
) -> DeletionAttemptSnapshot | None:
|
||||
deletion_task = _get_deletion_status(connector_id, credential_id, db_session)
|
||||
if not deletion_task:
|
||||
if not task_state:
|
||||
return None
|
||||
|
||||
return DeletionAttemptSnapshot(
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
status=deletion_task.status,
|
||||
status=task_state.status,
|
||||
)
|
||||
|
||||
|
||||
def should_kick_off_deletion_of_cc_pair(
|
||||
cc_pair: ConnectorCredentialPair, db_session: Session
|
||||
) -> bool:
|
||||
if cc_pair.status != ConnectorCredentialPairStatus.DELETING:
|
||||
def should_sync_doc_set(document_set: DocumentSet, db_session: Session) -> bool:
|
||||
if document_set.is_up_to_date:
|
||||
return False
|
||||
|
||||
if check_deletion_attempt_is_allowed(cc_pair, db_session):
|
||||
return False
|
||||
|
||||
deletion_task = _get_deletion_status(
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
if deletion_task and check_task_is_live_and_not_timed_out(
|
||||
deletion_task,
|
||||
db_session,
|
||||
# 1 hour timeout
|
||||
timeout=60 * 60,
|
||||
):
|
||||
task_name = name_document_set_sync_task(document_set.id)
|
||||
latest_sync = get_latest_task(task_name, db_session)
|
||||
|
||||
if latest_sync and check_task_is_live_and_not_timed_out(latest_sync, db_session):
|
||||
logger.info(f"Document set '{document_set.id}' is already syncing. Skipping.")
|
||||
return False
|
||||
|
||||
logger.info(f"Document set {document_set.id} syncing now!")
|
||||
return True
|
||||
|
||||
|
||||
@@ -97,7 +80,7 @@ def should_prune_cc_pair(
|
||||
return True
|
||||
return False
|
||||
|
||||
if not ALLOW_SIMULTANEOUS_PRUNING:
|
||||
if PREVENT_SIMULTANEOUS_PRUNING:
|
||||
pruning_type_task_name = name_cc_prune_task()
|
||||
last_pruning_type_task = get_latest_task_by_type(
|
||||
pruning_type_task_name, db_session
|
||||
@@ -106,9 +89,11 @@ def should_prune_cc_pair(
|
||||
if last_pruning_type_task and check_task_is_live_and_not_timed_out(
|
||||
last_pruning_type_task, db_session
|
||||
):
|
||||
logger.info("Another Connector is already pruning. Skipping.")
|
||||
return False
|
||||
|
||||
if check_task_is_live_and_not_timed_out(last_pruning_task, db_session):
|
||||
logger.info(f"Connector '{connector.name}' is already pruning. Skipping.")
|
||||
return False
|
||||
|
||||
if not last_pruning_task.start_time:
|
||||
|
||||
@@ -1,44 +0,0 @@
|
||||
# docs: https://docs.celeryq.dev/en/stable/userguide/configuration.html
|
||||
from danswer.configs.app_configs import REDIS_DB_NUMBER_CELERY
|
||||
from danswer.configs.app_configs import REDIS_HOST
|
||||
from danswer.configs.app_configs import REDIS_PASSWORD
|
||||
from danswer.configs.app_configs import REDIS_PORT
|
||||
from danswer.configs.app_configs import REDIS_SSL
|
||||
from danswer.configs.app_configs import REDIS_SSL_CA_CERTS
|
||||
from danswer.configs.app_configs import REDIS_SSL_CERT_REQS
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
|
||||
CELERY_SEPARATOR = ":"
|
||||
|
||||
CELERY_PASSWORD_PART = ""
|
||||
if REDIS_PASSWORD:
|
||||
CELERY_PASSWORD_PART = f":{REDIS_PASSWORD}@"
|
||||
|
||||
REDIS_SCHEME = "redis"
|
||||
|
||||
# SSL-specific query parameters for Redis URL
|
||||
SSL_QUERY_PARAMS = ""
|
||||
if REDIS_SSL:
|
||||
REDIS_SCHEME = "rediss"
|
||||
SSL_QUERY_PARAMS = f"?ssl_cert_reqs={REDIS_SSL_CERT_REQS}"
|
||||
if REDIS_SSL_CA_CERTS:
|
||||
SSL_QUERY_PARAMS += f"&ssl_ca_certs={REDIS_SSL_CA_CERTS}"
|
||||
|
||||
# example celery_broker_url: "redis://:password@localhost:6379/15"
|
||||
broker_url = f"{REDIS_SCHEME}://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY}{SSL_QUERY_PARAMS}"
|
||||
|
||||
result_backend = f"{REDIS_SCHEME}://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY}{SSL_QUERY_PARAMS}"
|
||||
|
||||
# NOTE: prefetch 4 is significantly faster than prefetch 1 for small tasks
|
||||
# however, prefetching is bad when tasks are lengthy as those tasks
|
||||
# can stall other tasks.
|
||||
worker_prefetch_multiplier = 4
|
||||
|
||||
broker_transport_options = {
|
||||
"priority_steps": list(range(len(DanswerCeleryPriority))),
|
||||
"sep": CELERY_SEPARATOR,
|
||||
"queue_order_strategy": "priority",
|
||||
}
|
||||
|
||||
task_default_priority = DanswerCeleryPriority.MEDIUM
|
||||
task_acks_late = True
|
||||
@@ -10,6 +10,8 @@ are multiple connector / credential pairs that have indexed it
|
||||
connector / credential pair from the access list
|
||||
(6) delete all relevant entries from postgres
|
||||
"""
|
||||
import time
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.access.access import get_access_for_documents
|
||||
@@ -22,8 +24,10 @@ from danswer.db.document import delete_documents_complete__no_commit
|
||||
from danswer.db.document import get_document_connector_cnts
|
||||
from danswer.db.document import get_documents_for_connector_credential_pair
|
||||
from danswer.db.document import prepare_to_modify_documents
|
||||
from danswer.db.document_set import delete_document_set_cc_pair_relationship__no_commit
|
||||
from danswer.db.document_set import fetch_document_sets_for_documents
|
||||
from danswer.db.document_set import get_document_sets_by_ids
|
||||
from danswer.db.document_set import (
|
||||
mark_cc_pair__document_set_relationships_to_be_deleted__no_commit,
|
||||
)
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.index_attempt import delete_index_attempts
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
@@ -31,10 +35,6 @@ from danswer.document_index.interfaces import DocumentIndex
|
||||
from danswer.document_index.interfaces import UpdateRequest
|
||||
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import (
|
||||
fetch_versioned_implementation_with_fallback,
|
||||
)
|
||||
from danswer.utils.variable_functionality import noop_fallback
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -78,37 +78,25 @@ def delete_connector_credential_pair_batch(
|
||||
document_ids_to_update = [
|
||||
document_id for document_id, cnt in document_connector_cnts if cnt > 1
|
||||
]
|
||||
|
||||
# maps document id to list of document set names
|
||||
new_doc_sets_for_documents: dict[str, set[str]] = {
|
||||
document_id_and_document_set_names_tuple[0]: set(
|
||||
document_id_and_document_set_names_tuple[1]
|
||||
)
|
||||
for document_id_and_document_set_names_tuple in fetch_document_sets_for_documents(
|
||||
db_session=db_session,
|
||||
document_ids=document_ids_to_update,
|
||||
)
|
||||
}
|
||||
|
||||
# determine future ACLs for documents in batch
|
||||
access_for_documents = get_access_for_documents(
|
||||
document_ids=document_ids_to_update,
|
||||
db_session=db_session,
|
||||
cc_pair_to_delete=ConnectorCredentialPairIdentifier(
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
),
|
||||
)
|
||||
|
||||
# update Vespa
|
||||
logger.debug(f"Updating documents: {document_ids_to_update}")
|
||||
update_requests = [
|
||||
UpdateRequest(
|
||||
document_ids=[document_id],
|
||||
access=access,
|
||||
document_sets=new_doc_sets_for_documents[document_id],
|
||||
)
|
||||
for document_id, access in access_for_documents.items()
|
||||
]
|
||||
logger.debug(f"Updating documents: {document_ids_to_update}")
|
||||
|
||||
document_index.update(update_requests=update_requests)
|
||||
|
||||
# clean up Postgres
|
||||
delete_document_by_connector_credential_pair__no_commit(
|
||||
db_session=db_session,
|
||||
document_ids=document_ids_to_update,
|
||||
@@ -120,6 +108,48 @@ def delete_connector_credential_pair_batch(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def cleanup_synced_entities(
|
||||
cc_pair: ConnectorCredentialPair, db_session: Session
|
||||
) -> None:
|
||||
"""Updates the document sets associated with the connector / credential pair,
|
||||
then relies on the document set sync script to kick off Celery jobs which will
|
||||
sync these updates to Vespa.
|
||||
|
||||
Waits until the document sets are synced before returning."""
|
||||
logger.info(f"Cleaning up Document Sets for CC Pair with ID: '{cc_pair.id}'")
|
||||
document_sets_ids_to_sync = list(
|
||||
mark_cc_pair__document_set_relationships_to_be_deleted__no_commit(
|
||||
cc_pair_id=cc_pair.id,
|
||||
db_session=db_session,
|
||||
)
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
# wait till all document sets are synced before continuing
|
||||
while True:
|
||||
all_synced = True
|
||||
document_sets = get_document_sets_by_ids(
|
||||
db_session=db_session, document_set_ids=document_sets_ids_to_sync
|
||||
)
|
||||
for document_set in document_sets:
|
||||
if not document_set.is_up_to_date:
|
||||
all_synced = False
|
||||
|
||||
if all_synced:
|
||||
break
|
||||
|
||||
# wait for 30 seconds before checking again
|
||||
db_session.commit() # end transaction
|
||||
logger.info(
|
||||
f"Document sets '{document_sets_ids_to_sync}' not synced yet, waiting 30s"
|
||||
)
|
||||
time.sleep(30)
|
||||
|
||||
logger.info(
|
||||
f"Finished cleaning up Document Sets for CC Pair with ID: '{cc_pair.id}'"
|
||||
)
|
||||
|
||||
|
||||
def delete_connector_credential_pair(
|
||||
db_session: Session,
|
||||
document_index: DocumentIndex,
|
||||
@@ -147,32 +177,17 @@ def delete_connector_credential_pair(
|
||||
)
|
||||
num_docs_deleted += len(documents)
|
||||
|
||||
# clean up the rest of the related Postgres entities
|
||||
# index attempts
|
||||
delete_index_attempts(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair.id,
|
||||
)
|
||||
# Clean up document sets / access information from Postgres
|
||||
# and sync these updates to Vespa
|
||||
# TODO: add user group cleanup with `fetch_versioned_implementation`
|
||||
cleanup_synced_entities(cc_pair, db_session)
|
||||
|
||||
# document sets
|
||||
delete_document_set_cc_pair_relationship__no_commit(
|
||||
# clean up the rest of the related Postgres entities
|
||||
delete_index_attempts(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
|
||||
# user groups
|
||||
cleanup_user_groups = fetch_versioned_implementation_with_fallback(
|
||||
"danswer.db.user_group",
|
||||
"delete_user_group_cc_pair_relationship__no_commit",
|
||||
noop_fallback,
|
||||
)
|
||||
cleanup_user_groups(
|
||||
cc_pair_id=cc_pair.id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# finally, delete the cc-pair
|
||||
delete_connector_credential_pair__no_commit(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
@@ -184,11 +199,11 @@ def delete_connector_credential_pair(
|
||||
connector_id=connector_id,
|
||||
)
|
||||
if not connector or not len(connector.credentials):
|
||||
logger.info("Found no credentials left for connector, deleting connector")
|
||||
logger.debug("Found no credentials left for connector, deleting connector")
|
||||
db_session.delete(connector)
|
||||
db_session.commit()
|
||||
|
||||
logger.notice(
|
||||
logger.info(
|
||||
"Successfully deleted connector_credential_pair with connector_id:"
|
||||
f" '{connector_id}' and credential_id: '{credential_id}'. Deleted {num_docs_deleted} docs."
|
||||
)
|
||||
|
||||
@@ -41,12 +41,6 @@ def _initializer(
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
def _run_in_process(
|
||||
func: Callable, args: list | tuple, kwargs: dict[str, Any] | None = None
|
||||
) -> None:
|
||||
_initializer(func, args, kwargs)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SimpleJob:
|
||||
"""Drop in replacement for `dask.distributed.Future`"""
|
||||
@@ -119,7 +113,7 @@ class SimpleJobClient:
|
||||
job_id = self.job_id_counter
|
||||
self.job_id_counter += 1
|
||||
|
||||
process = Process(target=_run_in_process, args=(func, args), daemon=True)
|
||||
process = Process(target=_initializer(func=func, args=args), daemon=True)
|
||||
job = SimpleJob(id=job_id, process=process)
|
||||
process.start()
|
||||
|
||||
|
||||
@@ -7,21 +7,20 @@ from datetime import timezone
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.indexing.checkpointing import get_time_windows_for_index_attempt
|
||||
from danswer.background.indexing.tracer import DanswerTracer
|
||||
from danswer.configs.app_configs import INDEXING_SIZE_WARNING_THRESHOLD
|
||||
from danswer.configs.app_configs import INDEXING_TRACER_INTERVAL
|
||||
from danswer.configs.app_configs import POLL_CONNECTOR_OFFSET
|
||||
from danswer.connectors.connector_runner import ConnectorRunner
|
||||
from danswer.connectors.factory import instantiate_connector
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.models import IndexAttemptMetadata
|
||||
from danswer.connectors.models import InputType
|
||||
from danswer.db.connector import disable_connector
|
||||
from danswer.db.connector_credential_pair import get_last_successful_attempt_time
|
||||
from danswer.db.connector_credential_pair import update_connector_credential_pair
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.index_attempt import get_index_attempt
|
||||
from danswer.db.index_attempt import mark_attempt_failed
|
||||
from danswer.db.index_attempt import mark_attempt_in_progress
|
||||
from danswer.db.index_attempt import mark_attempt_partially_succeeded
|
||||
from danswer.db.index_attempt import mark_attempt_in_progress__no_commit
|
||||
from danswer.db.index_attempt import mark_attempt_succeeded
|
||||
from danswer.db.index_attempt import update_docs_indexed
|
||||
from danswer.db.models import IndexAttempt
|
||||
@@ -36,15 +35,13 @@ from danswer.utils.variable_functionality import global_version
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
INDEXING_TRACER_NUM_PRINT_ENTRIES = 5
|
||||
|
||||
|
||||
def _get_connector_runner(
|
||||
def _get_document_generator(
|
||||
db_session: Session,
|
||||
attempt: IndexAttempt,
|
||||
start_time: datetime,
|
||||
end_time: datetime,
|
||||
) -> ConnectorRunner:
|
||||
) -> GenerateDocumentsOutput:
|
||||
"""
|
||||
NOTE: `start_time` and `end_time` are only used for poll connectors
|
||||
|
||||
@@ -52,31 +49,43 @@ def _get_connector_runner(
|
||||
are the complete list of existing documents of the connector. If the task
|
||||
of type LOAD_STATE, the list will be considered complete and otherwise incomplete.
|
||||
"""
|
||||
task = attempt.connector_credential_pair.connector.input_type
|
||||
task = attempt.connector.input_type
|
||||
|
||||
try:
|
||||
runnable_connector = instantiate_connector(
|
||||
attempt.connector_credential_pair.connector.source,
|
||||
attempt.connector.source,
|
||||
task,
|
||||
attempt.connector_credential_pair.connector.connector_specific_config,
|
||||
attempt.connector_credential_pair.credential,
|
||||
attempt.connector.connector_specific_config,
|
||||
attempt.credential,
|
||||
db_session,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Unable to instantiate connector due to {e}")
|
||||
# since we failed to even instantiate the connector, we pause the CCPair since
|
||||
# it will never succeed
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=attempt.connector_credential_pair.connector.id,
|
||||
credential_id=attempt.connector_credential_pair.credential.id,
|
||||
status=ConnectorCredentialPairStatus.PAUSED,
|
||||
)
|
||||
disable_connector(attempt.connector.id, db_session)
|
||||
raise e
|
||||
|
||||
return ConnectorRunner(
|
||||
connector=runnable_connector, time_range=(start_time, end_time)
|
||||
)
|
||||
if task == InputType.LOAD_STATE:
|
||||
assert isinstance(runnable_connector, LoadConnector)
|
||||
doc_batch_generator = runnable_connector.load_from_state()
|
||||
|
||||
elif task == InputType.POLL:
|
||||
assert isinstance(runnable_connector, PollConnector)
|
||||
if attempt.connector_id is None or attempt.credential_id is None:
|
||||
raise ValueError(
|
||||
f"Polling attempt {attempt.id} is missing connector_id or credential_id, "
|
||||
f"can't fetch time range."
|
||||
)
|
||||
|
||||
logger.info(f"Polling for updates between {start_time} and {end_time}")
|
||||
doc_batch_generator = runnable_connector.poll_source(
|
||||
start=start_time.timestamp(), end=end_time.timestamp()
|
||||
)
|
||||
|
||||
else:
|
||||
# Event types cannot be handled by a background type
|
||||
raise RuntimeError(f"Invalid task type: {task}")
|
||||
|
||||
return doc_batch_generator
|
||||
|
||||
|
||||
def _run_indexing(
|
||||
@@ -89,63 +98,48 @@ def _run_indexing(
|
||||
3. Updates Postgres to record the indexed documents + the outcome of this run
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
search_settings = index_attempt.search_settings
|
||||
index_name = search_settings.index_name
|
||||
db_embedding_model = index_attempt.embedding_model
|
||||
index_name = db_embedding_model.index_name
|
||||
|
||||
# Only update cc-pair status for primary index jobs
|
||||
# Secondary index syncs at the end when swapping
|
||||
is_primary = search_settings.status == IndexModelStatus.PRESENT
|
||||
is_primary = index_attempt.embedding_model.status == IndexModelStatus.PRESENT
|
||||
|
||||
# Indexing is only done into one index at a time
|
||||
document_index = get_default_document_index(
|
||||
primary_index_name=index_name, secondary_index_name=None
|
||||
)
|
||||
|
||||
embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
|
||||
search_settings=search_settings
|
||||
embedding_model = DefaultIndexingEmbedder(
|
||||
model_name=db_embedding_model.model_name,
|
||||
normalize=db_embedding_model.normalize,
|
||||
query_prefix=db_embedding_model.query_prefix,
|
||||
passage_prefix=db_embedding_model.passage_prefix,
|
||||
api_key=db_embedding_model.api_key,
|
||||
provider_type=db_embedding_model.provider_type,
|
||||
)
|
||||
|
||||
indexing_pipeline = build_indexing_pipeline(
|
||||
attempt_id=index_attempt.id,
|
||||
embedder=embedding_model,
|
||||
document_index=document_index,
|
||||
ignore_time_skip=index_attempt.from_beginning
|
||||
or (search_settings.status == IndexModelStatus.FUTURE),
|
||||
or (db_embedding_model.status == IndexModelStatus.FUTURE),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
db_cc_pair = index_attempt.connector_credential_pair
|
||||
db_connector = index_attempt.connector_credential_pair.connector
|
||||
db_credential = index_attempt.connector_credential_pair.credential
|
||||
earliest_index_time = (
|
||||
db_connector.indexing_start.timestamp() if db_connector.indexing_start else 0
|
||||
)
|
||||
|
||||
db_connector = index_attempt.connector
|
||||
db_credential = index_attempt.credential
|
||||
last_successful_index_time = (
|
||||
earliest_index_time
|
||||
0.0
|
||||
if index_attempt.from_beginning
|
||||
else get_last_successful_attempt_time(
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
earliest_index=earliest_index_time,
|
||||
search_settings=index_attempt.search_settings,
|
||||
embedding_model=index_attempt.embedding_model,
|
||||
db_session=db_session,
|
||||
)
|
||||
)
|
||||
|
||||
if INDEXING_TRACER_INTERVAL > 0:
|
||||
logger.debug(f"Memory tracer starting: interval={INDEXING_TRACER_INTERVAL}")
|
||||
tracer = DanswerTracer()
|
||||
tracer.start()
|
||||
tracer.snap()
|
||||
|
||||
index_attempt_md = IndexAttemptMetadata(
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
)
|
||||
|
||||
batch_num = 0
|
||||
net_doc_change = 0
|
||||
document_count = 0
|
||||
chunk_count = 0
|
||||
@@ -164,7 +158,7 @@ def _run_indexing(
|
||||
datetime(1970, 1, 1, tzinfo=timezone.utc),
|
||||
)
|
||||
|
||||
connector_runner = _get_connector_runner(
|
||||
doc_batch_generator = _get_document_generator(
|
||||
db_session=db_session,
|
||||
attempt=index_attempt,
|
||||
start_time=window_start,
|
||||
@@ -172,23 +166,15 @@ def _run_indexing(
|
||||
)
|
||||
|
||||
all_connector_doc_ids: set[str] = set()
|
||||
|
||||
tracer_counter = 0
|
||||
if INDEXING_TRACER_INTERVAL > 0:
|
||||
tracer.snap()
|
||||
for doc_batch in connector_runner.run():
|
||||
for doc_batch in doc_batch_generator:
|
||||
# Check if connector is disabled mid run and stop if so unless it's the secondary
|
||||
# index being built. We want to populate it even for paused connectors
|
||||
# Often paused connectors are sources that aren't updated frequently but the
|
||||
# contents still need to be initially pulled.
|
||||
db_session.refresh(db_connector)
|
||||
if (
|
||||
(
|
||||
db_cc_pair.status == ConnectorCredentialPairStatus.PAUSED
|
||||
and search_settings.status != IndexModelStatus.FUTURE
|
||||
)
|
||||
# if it's deleting, we don't care if this is a secondary index
|
||||
or db_cc_pair.status == ConnectorCredentialPairStatus.DELETING
|
||||
db_connector.disabled
|
||||
and db_embedding_model.status != IndexModelStatus.FUTURE
|
||||
):
|
||||
# let the `except` block handle this
|
||||
raise RuntimeError("Connector was disabled mid run")
|
||||
@@ -198,30 +184,17 @@ def _run_indexing(
|
||||
# Likely due to user manually disabling it or model swap
|
||||
raise RuntimeError("Index Attempt was canceled")
|
||||
|
||||
batch_description = []
|
||||
for doc in doc_batch:
|
||||
batch_description.append(doc.to_short_descriptor())
|
||||
|
||||
doc_size = 0
|
||||
for section in doc.sections:
|
||||
doc_size += len(section.text)
|
||||
|
||||
if doc_size > INDEXING_SIZE_WARNING_THRESHOLD:
|
||||
logger.warning(
|
||||
f"Document size: doc='{doc.to_short_descriptor()}' "
|
||||
f"size={doc_size} "
|
||||
f"threshold={INDEXING_SIZE_WARNING_THRESHOLD}"
|
||||
)
|
||||
|
||||
logger.debug(f"Indexing batch of documents: {batch_description}")
|
||||
|
||||
index_attempt_md.batch_num = batch_num + 1 # use 1-index for this
|
||||
new_docs, total_batch_chunks = indexing_pipeline(
|
||||
document_batch=doc_batch,
|
||||
index_attempt_metadata=index_attempt_md,
|
||||
logger.debug(
|
||||
f"Indexing batch of documents: {[doc.to_short_descriptor() for doc in doc_batch]}"
|
||||
)
|
||||
|
||||
batch_num += 1
|
||||
new_docs, total_batch_chunks = indexing_pipeline(
|
||||
documents=doc_batch,
|
||||
index_attempt_metadata=IndexAttemptMetadata(
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
),
|
||||
)
|
||||
net_doc_change += new_docs
|
||||
chunk_count += total_batch_chunks
|
||||
document_count += len(doc_batch)
|
||||
@@ -243,17 +216,6 @@ def _run_indexing(
|
||||
docs_removed_from_index=0,
|
||||
)
|
||||
|
||||
tracer_counter += 1
|
||||
if (
|
||||
INDEXING_TRACER_INTERVAL > 0
|
||||
and tracer_counter % INDEXING_TRACER_INTERVAL == 0
|
||||
):
|
||||
logger.debug(
|
||||
f"Running trace comparison for batch {tracer_counter}. interval={INDEXING_TRACER_INTERVAL}"
|
||||
)
|
||||
tracer.snap()
|
||||
tracer.log_previous_diff(INDEXING_TRACER_NUM_PRINT_ENTRIES)
|
||||
|
||||
run_end_dt = window_end
|
||||
if is_primary:
|
||||
update_connector_credential_pair(
|
||||
@@ -264,7 +226,7 @@ def _run_indexing(
|
||||
run_dt=run_end_dt,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
logger.info(
|
||||
f"Connector run ran into exception after elapsed time: {time.time() - start_time} seconds"
|
||||
)
|
||||
# Only mark the attempt as a complete failure if this is the first indexing window.
|
||||
@@ -276,7 +238,7 @@ def _run_indexing(
|
||||
# to give better clarity in the UI, as the next run will never happen.
|
||||
if (
|
||||
ind == 0
|
||||
or not db_cc_pair.status.is_active()
|
||||
or db_connector.disabled
|
||||
or index_attempt.status != IndexingStatus.IN_PROGRESS
|
||||
):
|
||||
mark_attempt_failed(
|
||||
@@ -288,66 +250,17 @@ def _run_indexing(
|
||||
if is_primary:
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
connector_id=index_attempt.connector.id,
|
||||
credential_id=index_attempt.credential.id,
|
||||
net_docs=net_doc_change,
|
||||
)
|
||||
|
||||
if INDEXING_TRACER_INTERVAL > 0:
|
||||
tracer.stop()
|
||||
raise e
|
||||
|
||||
# break => similar to success case. As mentioned above, if the next run fails for the same
|
||||
# reason it will then be marked as a failure
|
||||
break
|
||||
|
||||
if INDEXING_TRACER_INTERVAL > 0:
|
||||
logger.debug(
|
||||
f"Running trace comparison between start and end of indexing. {tracer_counter} batches processed."
|
||||
)
|
||||
tracer.snap()
|
||||
tracer.log_first_diff(INDEXING_TRACER_NUM_PRINT_ENTRIES)
|
||||
tracer.stop()
|
||||
logger.debug("Memory tracer stopped.")
|
||||
|
||||
if (
|
||||
index_attempt_md.num_exceptions > 0
|
||||
and index_attempt_md.num_exceptions >= batch_num
|
||||
):
|
||||
mark_attempt_failed(
|
||||
index_attempt,
|
||||
db_session,
|
||||
failure_reason="All batches exceptioned.",
|
||||
)
|
||||
if is_primary:
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=index_attempt.connector_credential_pair.connector.id,
|
||||
credential_id=index_attempt.connector_credential_pair.credential.id,
|
||||
)
|
||||
raise Exception(
|
||||
f"Connector failed - All batches exceptioned: batches={batch_num}"
|
||||
)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
if index_attempt_md.num_exceptions == 0:
|
||||
mark_attempt_succeeded(index_attempt, db_session)
|
||||
logger.info(
|
||||
f"Connector succeeded: "
|
||||
f"docs={document_count} chunks={chunk_count} elapsed={elapsed_time:.2f}s"
|
||||
)
|
||||
else:
|
||||
mark_attempt_partially_succeeded(index_attempt, db_session)
|
||||
logger.info(
|
||||
f"Connector completed with some errors: "
|
||||
f"exceptions={index_attempt_md.num_exceptions} "
|
||||
f"batches={batch_num} "
|
||||
f"docs={document_count} "
|
||||
f"chunks={chunk_count} "
|
||||
f"elapsed={elapsed_time:.2f}s"
|
||||
)
|
||||
|
||||
mark_attempt_succeeded(index_attempt, db_session)
|
||||
if is_primary:
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
@@ -356,6 +269,13 @@ def _run_indexing(
|
||||
run_dt=run_end_dt,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Indexed or refreshed {document_count} total documents for a total of {chunk_count} indexed chunks"
|
||||
)
|
||||
logger.info(
|
||||
f"Connector successfully finished, elapsed time: {time.time() - start_time} seconds"
|
||||
)
|
||||
|
||||
|
||||
def _prepare_index_attempt(db_session: Session, index_attempt_id: int) -> IndexAttempt:
|
||||
# make sure that the index attempt can't change in between checking the
|
||||
@@ -379,27 +299,24 @@ def _prepare_index_attempt(db_session: Session, index_attempt_id: int) -> IndexA
|
||||
)
|
||||
|
||||
# only commit once, to make sure this all happens in a single transaction
|
||||
mark_attempt_in_progress(attempt, db_session)
|
||||
mark_attempt_in_progress__no_commit(attempt)
|
||||
if attempt.embedding_model.status != IndexModelStatus.PRESENT:
|
||||
db_session.commit()
|
||||
|
||||
return attempt
|
||||
|
||||
|
||||
def run_indexing_entrypoint(
|
||||
index_attempt_id: int, connector_credential_pair_id: int, is_ee: bool = False
|
||||
) -> None:
|
||||
def run_indexing_entrypoint(index_attempt_id: int, is_ee: bool = False) -> None:
|
||||
"""Entrypoint for indexing run when using dask distributed.
|
||||
Wraps the actual logic in a `try` block so that we can catch any exceptions
|
||||
and mark the attempt as failed."""
|
||||
|
||||
try:
|
||||
if is_ee:
|
||||
global_version.set_ee()
|
||||
|
||||
# set the indexing attempt ID so that all log messages from this process
|
||||
# will have it added as a prefix
|
||||
IndexAttemptSingleton.set_cc_and_index_id(
|
||||
index_attempt_id, connector_credential_pair_id
|
||||
)
|
||||
IndexAttemptSingleton.set_index_attempt_id(index_attempt_id)
|
||||
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
# make sure that it is valid to run this indexing attempt + mark it
|
||||
@@ -407,19 +324,17 @@ def run_indexing_entrypoint(
|
||||
attempt = _prepare_index_attempt(db_session, index_attempt_id)
|
||||
|
||||
logger.info(
|
||||
f"Indexing starting: "
|
||||
f"connector='{attempt.connector_credential_pair.connector.name}' "
|
||||
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
|
||||
f"credentials='{attempt.connector_credential_pair.connector_id}'"
|
||||
f"Running indexing attempt for connector: '{attempt.connector.name}', "
|
||||
f"with config: '{attempt.connector.connector_specific_config}', and "
|
||||
f"with credentials: '{attempt.credential_id}'"
|
||||
)
|
||||
|
||||
_run_indexing(db_session, attempt)
|
||||
|
||||
logger.info(
|
||||
f"Indexing finished: "
|
||||
f"connector='{attempt.connector_credential_pair.connector.name}' "
|
||||
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
|
||||
f"credentials='{attempt.connector_credential_pair.connector_id}'"
|
||||
f"Completed indexing attempt for connector: '{attempt.connector.name}', "
|
||||
f"with config: '{attempt.connector.connector_specific_config}', and "
|
||||
f"with credentials: '{attempt.credential_id}'"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Indexing job with ID '{index_attempt_id}' failed due to {e}")
|
||||
|
||||
@@ -1,77 +0,0 @@
|
||||
import tracemalloc
|
||||
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
DANSWER_TRACEMALLOC_FRAMES = 10
|
||||
|
||||
|
||||
class DanswerTracer:
|
||||
def __init__(self) -> None:
|
||||
self.snapshot_first: tracemalloc.Snapshot | None = None
|
||||
self.snapshot_prev: tracemalloc.Snapshot | None = None
|
||||
self.snapshot: tracemalloc.Snapshot | None = None
|
||||
|
||||
def start(self) -> None:
|
||||
tracemalloc.start(DANSWER_TRACEMALLOC_FRAMES)
|
||||
|
||||
def stop(self) -> None:
|
||||
tracemalloc.stop()
|
||||
|
||||
def snap(self) -> None:
|
||||
snapshot = tracemalloc.take_snapshot()
|
||||
# Filter out irrelevant frames (e.g., from tracemalloc itself or importlib)
|
||||
snapshot = snapshot.filter_traces(
|
||||
(
|
||||
tracemalloc.Filter(False, tracemalloc.__file__), # Exclude tracemalloc
|
||||
tracemalloc.Filter(
|
||||
False, "<frozen importlib._bootstrap>"
|
||||
), # Exclude importlib
|
||||
tracemalloc.Filter(
|
||||
False, "<frozen importlib._bootstrap_external>"
|
||||
), # Exclude external importlib
|
||||
)
|
||||
)
|
||||
|
||||
if not self.snapshot_first:
|
||||
self.snapshot_first = snapshot
|
||||
|
||||
if self.snapshot:
|
||||
self.snapshot_prev = self.snapshot
|
||||
|
||||
self.snapshot = snapshot
|
||||
|
||||
def log_snapshot(self, numEntries: int) -> None:
|
||||
if not self.snapshot:
|
||||
return
|
||||
|
||||
stats = self.snapshot.statistics("traceback")
|
||||
for s in stats[:numEntries]:
|
||||
logger.debug(f"Tracer snap: {s}")
|
||||
for line in s.traceback:
|
||||
logger.debug(f"* {line}")
|
||||
|
||||
@staticmethod
|
||||
def log_diff(
|
||||
snap_current: tracemalloc.Snapshot,
|
||||
snap_previous: tracemalloc.Snapshot,
|
||||
numEntries: int,
|
||||
) -> None:
|
||||
stats = snap_current.compare_to(snap_previous, "traceback")
|
||||
for s in stats[:numEntries]:
|
||||
logger.debug(f"Tracer diff: {s}")
|
||||
for line in s.traceback.format():
|
||||
logger.debug(f"* {line}")
|
||||
|
||||
def log_previous_diff(self, numEntries: int) -> None:
|
||||
if not self.snapshot or not self.snapshot_prev:
|
||||
return
|
||||
|
||||
DanswerTracer.log_diff(self.snapshot, self.snapshot_prev, numEntries)
|
||||
|
||||
def log_first_diff(self, numEntries: int) -> None:
|
||||
if not self.snapshot or not self.snapshot_first:
|
||||
return
|
||||
|
||||
DanswerTracer.log_diff(self.snapshot, self.snapshot_first, numEntries)
|
||||
@@ -93,16 +93,9 @@ def build_apply_async_wrapper(build_name_fn: Callable[..., str]) -> Callable[[AA
|
||||
kwargs_for_build_name = kwargs or {}
|
||||
task_name = build_name_fn(*args_for_build_name, **kwargs_for_build_name)
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
# register_task must come before fn = apply_async or else the task
|
||||
# might run mark_task_start (and crash) before the task row exists
|
||||
db_task = register_task(task_name, db_session)
|
||||
|
||||
# mark the task as started
|
||||
task = fn(args, kwargs, *other_args, **other_kwargs)
|
||||
|
||||
# we update the celery task id for diagnostic purposes
|
||||
# but it isn't currently used by any code
|
||||
db_task.task_id = task.id
|
||||
db_session.commit()
|
||||
register_task(task.id, task_name, db_session)
|
||||
|
||||
return task
|
||||
|
||||
|
||||
@@ -16,30 +16,24 @@ from danswer.configs.app_configs import CLEANUP_INDEXING_JOBS_TIMEOUT
|
||||
from danswer.configs.app_configs import DASK_JOB_CLIENT_ENABLED
|
||||
from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
|
||||
from danswer.configs.app_configs import NUM_INDEXING_WORKERS
|
||||
from danswer.configs.app_configs import NUM_SECONDARY_INDEXING_WORKERS
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import POSTGRES_INDEXER_APP_NAME
|
||||
from danswer.db.connector import fetch_connectors
|
||||
from danswer.db.connector_credential_pair import fetch_connector_credential_pairs
|
||||
from danswer.db.embedding_model import get_current_db_embedding_model
|
||||
from danswer.db.embedding_model import get_secondary_db_embedding_model
|
||||
from danswer.db.engine import get_db_current_time
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.engine import init_sqlalchemy_engine
|
||||
from danswer.db.index_attempt import create_index_attempt
|
||||
from danswer.db.index_attempt import get_index_attempt
|
||||
from danswer.db.index_attempt import get_inprogress_index_attempts
|
||||
from danswer.db.index_attempt import get_last_attempt_for_cc_pair
|
||||
from danswer.db.index_attempt import get_last_attempt
|
||||
from danswer.db.index_attempt import get_not_started_index_attempts
|
||||
from danswer.db.index_attempt import mark_attempt_failed
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.models import Connector
|
||||
from danswer.db.models import EmbeddingModel
|
||||
from danswer.db.models import IndexAttempt
|
||||
from danswer.db.models import IndexingStatus
|
||||
from danswer.db.models import IndexModelStatus
|
||||
from danswer.db.models import SearchSettings
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.db.search_settings import get_secondary_search_settings
|
||||
from danswer.db.swap_index import check_index_swap
|
||||
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
|
||||
from danswer.search.search_nlp_models import warm_up_encoders
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import global_version
|
||||
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
@@ -59,68 +53,41 @@ _UNEXPECTED_STATE_FAILURE_REASON = (
|
||||
|
||||
|
||||
def _should_create_new_indexing(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
connector: Connector,
|
||||
last_index: IndexAttempt | None,
|
||||
search_settings_instance: SearchSettings,
|
||||
model: EmbeddingModel,
|
||||
secondary_index_building: bool,
|
||||
db_session: Session,
|
||||
) -> bool:
|
||||
connector = cc_pair.connector
|
||||
|
||||
# don't kick off indexing for `NOT_APPLICABLE` sources
|
||||
if connector.source == DocumentSource.NOT_APPLICABLE:
|
||||
return False
|
||||
|
||||
# User can still manually create single indexing attempts via the UI for the
|
||||
# currently in use index
|
||||
if DISABLE_INDEX_UPDATE_ON_SWAP:
|
||||
if (
|
||||
search_settings_instance.status == IndexModelStatus.PRESENT
|
||||
and secondary_index_building
|
||||
):
|
||||
if model.status == IndexModelStatus.PRESENT and secondary_index_building:
|
||||
return False
|
||||
|
||||
# When switching over models, always index at least once
|
||||
if search_settings_instance.status == IndexModelStatus.FUTURE:
|
||||
if last_index:
|
||||
# No new index if the last index attempt succeeded
|
||||
# Once is enough. The model will never be able to swap otherwise.
|
||||
if last_index.status == IndexingStatus.SUCCESS:
|
||||
return False
|
||||
|
||||
# No new index if the last index attempt is waiting to start
|
||||
if last_index.status == IndexingStatus.NOT_STARTED:
|
||||
return False
|
||||
|
||||
# No new index if the last index attempt is running
|
||||
if last_index.status == IndexingStatus.IN_PROGRESS:
|
||||
return False
|
||||
else:
|
||||
if connector.id == 0: # Ingestion API
|
||||
return False
|
||||
if model.status == IndexModelStatus.FUTURE and not last_index:
|
||||
if connector.id == 0: # Ingestion API
|
||||
return False
|
||||
return True
|
||||
|
||||
# If the connector is paused or is the ingestion API, don't index
|
||||
# NOTE: during an embedding model switch over, the following logic
|
||||
# is bypassed by the above check for a future model
|
||||
if not cc_pair.status.is_active() or connector.id == 0:
|
||||
# If the connector is disabled, don't index
|
||||
# NOTE: during an embedding model switch over, we ignore this
|
||||
# and index the disabled connectors as well (which is why this if
|
||||
# statement is below the first condition above)
|
||||
if connector.disabled:
|
||||
return False
|
||||
|
||||
if not last_index:
|
||||
return True
|
||||
|
||||
if connector.refresh_freq is None:
|
||||
return False
|
||||
if not last_index:
|
||||
return True
|
||||
|
||||
# Only one scheduled/ongoing job per connector at a time
|
||||
# this prevents cases where
|
||||
# (1) the "latest" index_attempt is scheduled so we show
|
||||
# that in the UI despite another index_attempt being in-progress
|
||||
# (2) multiple scheduled index_attempts at a time
|
||||
if (
|
||||
last_index.status == IndexingStatus.NOT_STARTED
|
||||
or last_index.status == IndexingStatus.IN_PROGRESS
|
||||
):
|
||||
# Only one scheduled job per connector at a time
|
||||
# Can schedule another one if the current one is already running however
|
||||
# Because the currently running one will not be until the latest time
|
||||
# Note, this last index is for the given embedding model
|
||||
if last_index.status == IndexingStatus.NOT_STARTED:
|
||||
return False
|
||||
|
||||
current_db_time = get_db_current_time(db_session)
|
||||
@@ -128,14 +95,24 @@ def _should_create_new_indexing(
|
||||
return time_since_index.total_seconds() >= connector.refresh_freq
|
||||
|
||||
|
||||
def _is_indexing_job_marked_as_finished(index_attempt: IndexAttempt | None) -> bool:
|
||||
if index_attempt is None:
|
||||
return False
|
||||
|
||||
return (
|
||||
index_attempt.status == IndexingStatus.FAILED
|
||||
or index_attempt.status == IndexingStatus.SUCCESS
|
||||
)
|
||||
|
||||
|
||||
def _mark_run_failed(
|
||||
db_session: Session, index_attempt: IndexAttempt, failure_reason: str
|
||||
) -> None:
|
||||
"""Marks the `index_attempt` row as failed + updates the `
|
||||
connector_credential_pair` to reflect that the run failed"""
|
||||
logger.warning(
|
||||
f"Marking in-progress attempt 'connector: {index_attempt.connector_credential_pair.connector_id}, "
|
||||
f"credential: {index_attempt.connector_credential_pair.credential_id}' as failed due to {failure_reason}"
|
||||
f"Marking in-progress attempt 'connector: {index_attempt.connector_id}, "
|
||||
f"credential: {index_attempt.credential_id}' as failed due to {failure_reason}"
|
||||
)
|
||||
mark_attempt_failed(
|
||||
index_attempt=index_attempt,
|
||||
@@ -154,7 +131,7 @@ def create_indexing_jobs(existing_jobs: dict[int, Future | SimpleJob]) -> None:
|
||||
3. There is not already an ongoing indexing attempt for this pair
|
||||
"""
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
ongoing: set[tuple[int | None, int]] = set()
|
||||
ongoing: set[tuple[int | None, int | None, int]] = set()
|
||||
for attempt_id in existing_jobs:
|
||||
attempt = get_index_attempt(
|
||||
db_session=db_session, index_attempt_id=attempt_id
|
||||
@@ -167,43 +144,42 @@ def create_indexing_jobs(existing_jobs: dict[int, Future | SimpleJob]) -> None:
|
||||
continue
|
||||
ongoing.add(
|
||||
(
|
||||
attempt.connector_credential_pair_id,
|
||||
attempt.search_settings_id,
|
||||
attempt.connector_id,
|
||||
attempt.credential_id,
|
||||
attempt.embedding_model_id,
|
||||
)
|
||||
)
|
||||
|
||||
# Get the primary search settings
|
||||
primary_search_settings = get_current_search_settings(db_session)
|
||||
search_settings = [primary_search_settings]
|
||||
embedding_models = [get_current_db_embedding_model(db_session)]
|
||||
secondary_embedding_model = get_secondary_db_embedding_model(db_session)
|
||||
if secondary_embedding_model is not None:
|
||||
embedding_models.append(secondary_embedding_model)
|
||||
|
||||
# Check for secondary search settings
|
||||
secondary_search_settings = get_secondary_search_settings(db_session)
|
||||
if secondary_search_settings is not None:
|
||||
# If secondary settings exist, add them to the list
|
||||
search_settings.append(secondary_search_settings)
|
||||
all_connectors = fetch_connectors(db_session)
|
||||
for connector in all_connectors:
|
||||
for association in connector.credentials:
|
||||
for model in embedding_models:
|
||||
credential = association.credential
|
||||
|
||||
all_connector_credential_pairs = fetch_connector_credential_pairs(db_session)
|
||||
for cc_pair in all_connector_credential_pairs:
|
||||
for search_settings_instance in search_settings:
|
||||
# Check if there is an ongoing indexing attempt for this connector credential pair
|
||||
if (cc_pair.id, search_settings_instance.id) in ongoing:
|
||||
continue
|
||||
# Check if there is an ongoing indexing attempt for this connector + credential pair
|
||||
if (connector.id, credential.id, model.id) in ongoing:
|
||||
continue
|
||||
|
||||
last_attempt = get_last_attempt_for_cc_pair(
|
||||
cc_pair.id, search_settings_instance.id, db_session
|
||||
)
|
||||
if not _should_create_new_indexing(
|
||||
cc_pair=cc_pair,
|
||||
last_index=last_attempt,
|
||||
search_settings_instance=search_settings_instance,
|
||||
secondary_index_building=len(search_settings) > 1,
|
||||
db_session=db_session,
|
||||
):
|
||||
continue
|
||||
last_attempt = get_last_attempt(
|
||||
connector.id, credential.id, model.id, db_session
|
||||
)
|
||||
if not _should_create_new_indexing(
|
||||
connector=connector,
|
||||
last_index=last_attempt,
|
||||
model=model,
|
||||
secondary_index_building=len(embedding_models) > 1,
|
||||
db_session=db_session,
|
||||
):
|
||||
continue
|
||||
|
||||
create_index_attempt(
|
||||
cc_pair.id, search_settings_instance.id, db_session
|
||||
)
|
||||
create_index_attempt(
|
||||
connector.id, credential.id, model.id, db_session
|
||||
)
|
||||
|
||||
|
||||
def cleanup_indexing_jobs(
|
||||
@@ -220,12 +196,10 @@ def cleanup_indexing_jobs(
|
||||
)
|
||||
|
||||
# do nothing for ongoing jobs that haven't been stopped
|
||||
if not job.done():
|
||||
if not index_attempt:
|
||||
continue
|
||||
|
||||
if not index_attempt.is_finished():
|
||||
continue
|
||||
if not job.done() and not _is_indexing_job_marked_as_finished(
|
||||
index_attempt
|
||||
):
|
||||
continue
|
||||
|
||||
if job.status == "error":
|
||||
logger.error(job.exception())
|
||||
@@ -297,28 +271,24 @@ def kickoff_indexing_jobs(
|
||||
# Don't include jobs waiting in the Dask queue that just haven't started running
|
||||
# Also (rarely) don't include for jobs that started but haven't updated the indexing tables yet
|
||||
with Session(engine) as db_session:
|
||||
# get_not_started_index_attempts orders its returned results from oldest to newest
|
||||
# we must process attempts in a FIFO manner to prevent connector starvation
|
||||
new_indexing_attempts = [
|
||||
(attempt, attempt.search_settings)
|
||||
(attempt, attempt.embedding_model)
|
||||
for attempt in get_not_started_index_attempts(db_session)
|
||||
if attempt.id not in existing_jobs
|
||||
]
|
||||
|
||||
logger.debug(f"Found {len(new_indexing_attempts)} new indexing task(s).")
|
||||
logger.info(f"Found {len(new_indexing_attempts)} new indexing tasks.")
|
||||
|
||||
if not new_indexing_attempts:
|
||||
return existing_jobs
|
||||
|
||||
indexing_attempt_count = 0
|
||||
|
||||
for attempt, search_settings in new_indexing_attempts:
|
||||
for attempt, embedding_model in new_indexing_attempts:
|
||||
use_secondary_index = (
|
||||
search_settings.status == IndexModelStatus.FUTURE
|
||||
if search_settings is not None
|
||||
embedding_model.status == IndexModelStatus.FUTURE
|
||||
if embedding_model is not None
|
||||
else False
|
||||
)
|
||||
if attempt.connector_credential_pair.connector is None:
|
||||
if attempt.connector is None:
|
||||
logger.warning(
|
||||
f"Skipping index attempt as Connector has been deleted: {attempt}"
|
||||
)
|
||||
@@ -327,7 +297,7 @@ def kickoff_indexing_jobs(
|
||||
attempt, db_session, failure_reason="Connector is null"
|
||||
)
|
||||
continue
|
||||
if attempt.connector_credential_pair.credential is None:
|
||||
if attempt.credential is None:
|
||||
logger.warning(
|
||||
f"Skipping index attempt as Credential has been deleted: {attempt}"
|
||||
)
|
||||
@@ -341,7 +311,6 @@ def kickoff_indexing_jobs(
|
||||
run = secondary_client.submit(
|
||||
run_indexing_entrypoint,
|
||||
attempt.id,
|
||||
attempt.connector_credential_pair_id,
|
||||
global_version.get_is_ee_version(),
|
||||
pure=False,
|
||||
)
|
||||
@@ -349,63 +318,40 @@ def kickoff_indexing_jobs(
|
||||
run = client.submit(
|
||||
run_indexing_entrypoint,
|
||||
attempt.id,
|
||||
attempt.connector_credential_pair_id,
|
||||
global_version.get_is_ee_version(),
|
||||
pure=False,
|
||||
)
|
||||
|
||||
if run:
|
||||
if indexing_attempt_count == 0:
|
||||
logger.info(
|
||||
f"Indexing dispatch starts: pending={len(new_indexing_attempts)}"
|
||||
)
|
||||
|
||||
indexing_attempt_count += 1
|
||||
secondary_str = " (secondary index)" if use_secondary_index else ""
|
||||
secondary_str = "(secondary index) " if use_secondary_index else ""
|
||||
logger.info(
|
||||
f"Indexing dispatched{secondary_str}: "
|
||||
f"attempt_id={attempt.id} "
|
||||
f"connector='{attempt.connector_credential_pair.connector.name}' "
|
||||
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
|
||||
f"credentials='{attempt.connector_credential_pair.credential_id}'"
|
||||
f"Kicked off {secondary_str}"
|
||||
f"indexing attempt for connector: '{attempt.connector.name}', "
|
||||
f"with config: '{attempt.connector.connector_specific_config}', and "
|
||||
f"with credentials: '{attempt.credential_id}'"
|
||||
)
|
||||
existing_jobs_copy[attempt.id] = run
|
||||
|
||||
if indexing_attempt_count > 0:
|
||||
logger.info(
|
||||
f"Indexing dispatch results: "
|
||||
f"initial_pending={len(new_indexing_attempts)} "
|
||||
f"started={indexing_attempt_count} "
|
||||
f"remaining={len(new_indexing_attempts) - indexing_attempt_count}"
|
||||
)
|
||||
|
||||
return existing_jobs_copy
|
||||
|
||||
|
||||
def update_loop(
|
||||
delay: int = 10,
|
||||
num_workers: int = NUM_INDEXING_WORKERS,
|
||||
num_secondary_workers: int = NUM_SECONDARY_INDEXING_WORKERS,
|
||||
) -> None:
|
||||
def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> None:
|
||||
engine = get_sqlalchemy_engine()
|
||||
with Session(engine) as db_session:
|
||||
check_index_swap(db_session=db_session)
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
db_embedding_model = get_current_db_embedding_model(db_session)
|
||||
|
||||
# So that the first time users aren't surprised by really slow speed of first
|
||||
# batch of documents indexed
|
||||
# So that the first time users aren't surprised by really slow speed of first
|
||||
# batch of documents indexed
|
||||
|
||||
if search_settings.provider_type is None:
|
||||
logger.notice("Running a first inference to warm up embedding model")
|
||||
embedding_model = EmbeddingModel.from_db_model(
|
||||
search_settings=search_settings,
|
||||
server_host=INDEXING_MODEL_SERVER_HOST,
|
||||
server_port=MODEL_SERVER_PORT,
|
||||
)
|
||||
|
||||
warm_up_bi_encoder(
|
||||
embedding_model=embedding_model,
|
||||
)
|
||||
if db_embedding_model.cloud_provider_id is None:
|
||||
logger.info("Running a first inference to warm up embedding model")
|
||||
warm_up_encoders(
|
||||
model_name=db_embedding_model.model_name,
|
||||
normalize=db_embedding_model.normalize,
|
||||
model_server_host=INDEXING_MODEL_SERVER_HOST,
|
||||
model_server_port=MODEL_SERVER_PORT,
|
||||
)
|
||||
|
||||
client_primary: Client | SimpleJobClient
|
||||
client_secondary: Client | SimpleJobClient
|
||||
@@ -420,7 +366,7 @@ def update_loop(
|
||||
silence_logs=logging.ERROR,
|
||||
)
|
||||
cluster_secondary = LocalCluster(
|
||||
n_workers=num_secondary_workers,
|
||||
n_workers=num_workers,
|
||||
threads_per_worker=1,
|
||||
silence_logs=logging.ERROR,
|
||||
)
|
||||
@@ -430,18 +376,18 @@ def update_loop(
|
||||
client_primary.register_worker_plugin(ResourceLogger())
|
||||
else:
|
||||
client_primary = SimpleJobClient(n_workers=num_workers)
|
||||
client_secondary = SimpleJobClient(n_workers=num_secondary_workers)
|
||||
client_secondary = SimpleJobClient(n_workers=num_workers)
|
||||
|
||||
existing_jobs: dict[int, Future | SimpleJob] = {}
|
||||
|
||||
while True:
|
||||
start = time.time()
|
||||
start_time_utc = datetime.utcfromtimestamp(start).strftime("%Y-%m-%d %H:%M:%S")
|
||||
logger.debug(f"Running update, current UTC time: {start_time_utc}")
|
||||
logger.info(f"Running update, current UTC time: {start_time_utc}")
|
||||
|
||||
if existing_jobs:
|
||||
# TODO: make this debug level once the "no jobs are being scheduled" issue is resolved
|
||||
logger.debug(
|
||||
logger.info(
|
||||
"Found existing indexing jobs: "
|
||||
f"{[(attempt_id, job.status) for attempt_id, job in existing_jobs.items()]}"
|
||||
)
|
||||
@@ -465,9 +411,8 @@ def update_loop(
|
||||
|
||||
def update__main() -> None:
|
||||
set_is_ee_based_on_env_variable()
|
||||
init_sqlalchemy_engine(POSTGRES_INDEXER_APP_NAME)
|
||||
|
||||
logger.notice("Starting indexing service")
|
||||
logger.info("Starting Indexing Loop")
|
||||
update_loop()
|
||||
|
||||
|
||||
|
||||
@@ -35,19 +35,14 @@ def llm_doc_from_inference_section(inference_section: InferenceSection) -> LlmDo
|
||||
def create_chat_chain(
|
||||
chat_session_id: int,
|
||||
db_session: Session,
|
||||
prefetch_tool_calls: bool = True,
|
||||
# Optional id at which we finish processing
|
||||
stop_at_message_id: int | None = None,
|
||||
) -> tuple[ChatMessage, list[ChatMessage]]:
|
||||
"""Build the linear chain of messages without including the root message"""
|
||||
mainline_messages: list[ChatMessage] = []
|
||||
|
||||
all_chat_messages = get_chat_messages_by_session(
|
||||
chat_session_id=chat_session_id,
|
||||
user_id=None,
|
||||
db_session=db_session,
|
||||
skip_permission_check=True,
|
||||
prefetch_tool_calls=prefetch_tool_calls,
|
||||
)
|
||||
id_to_msg = {msg.id: msg for msg in all_chat_messages}
|
||||
|
||||
@@ -63,12 +58,7 @@ def create_chat_chain(
|
||||
current_message: ChatMessage | None = root_message
|
||||
while current_message is not None:
|
||||
child_msg = current_message.latest_child_message
|
||||
|
||||
# Break if at the end of the chain
|
||||
# or have reached the `final_id` of the submitted message
|
||||
if not child_msg or (
|
||||
stop_at_message_id and current_message.id == stop_at_message_id
|
||||
):
|
||||
if not child_msg:
|
||||
break
|
||||
current_message = id_to_msg.get(child_msg)
|
||||
|
||||
|
||||
@@ -1,24 +0,0 @@
|
||||
input_prompts:
|
||||
- id: -5
|
||||
prompt: "Elaborate"
|
||||
content: "Elaborate on the above, give me a more in depth explanation."
|
||||
active: true
|
||||
is_public: true
|
||||
|
||||
- id: -4
|
||||
prompt: "Reword"
|
||||
content: "Help me rewrite the following politely and concisely for professional communication:\n"
|
||||
active: true
|
||||
is_public: true
|
||||
|
||||
- id: -3
|
||||
prompt: "Email"
|
||||
content: "Write a professional email for me including a subject line, signature, etc. Template the parts that need editing with [ ]. The email should cover the following points:\n"
|
||||
active: true
|
||||
is_public: true
|
||||
|
||||
- id: -2
|
||||
prompt: "Debug"
|
||||
content: "Provide step-by-step troubleshooting instructions for the following issue:\n"
|
||||
active: true
|
||||
is_public: true
|
||||
@@ -1,17 +1,13 @@
|
||||
import yaml
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.chat_configs import INPUT_PROMPT_YAML
|
||||
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
|
||||
from danswer.configs.chat_configs import PERSONAS_YAML
|
||||
from danswer.configs.chat_configs import PROMPTS_YAML
|
||||
from danswer.db.document_set import get_or_create_document_set_by_name
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.input_prompt import insert_input_prompt_if_not_exists
|
||||
from danswer.db.models import DocumentSet as DocumentSetDBModel
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.models import Prompt as PromptDBModel
|
||||
from danswer.db.models import Tool as ToolDBModel
|
||||
from danswer.db.persona import get_prompt_by_name
|
||||
from danswer.db.persona import upsert_persona
|
||||
from danswer.db.persona import upsert_prompt
|
||||
@@ -80,31 +76,9 @@ def load_personas_from_yaml(
|
||||
prompt_ids = [prompt.id for prompt in prompts if prompt is not None]
|
||||
|
||||
p_id = persona.get("id")
|
||||
tool_ids = []
|
||||
if persona.get("image_generation"):
|
||||
image_gen_tool = (
|
||||
db_session.query(ToolDBModel)
|
||||
.filter(ToolDBModel.name == "ImageGenerationTool")
|
||||
.first()
|
||||
)
|
||||
if image_gen_tool:
|
||||
tool_ids.append(image_gen_tool.id)
|
||||
|
||||
llm_model_provider_override = persona.get("llm_model_provider_override")
|
||||
llm_model_version_override = persona.get("llm_model_version_override")
|
||||
|
||||
# Set specific overrides for image generation persona
|
||||
if persona.get("image_generation"):
|
||||
llm_model_version_override = "gpt-4o"
|
||||
|
||||
existing_persona = (
|
||||
db_session.query(Persona)
|
||||
.filter(Persona.name == persona["name"])
|
||||
.first()
|
||||
)
|
||||
|
||||
upsert_persona(
|
||||
user=None,
|
||||
# Negative to not conflict with existing personas
|
||||
persona_id=(-1 * p_id) if p_id is not None else None,
|
||||
name=persona["name"],
|
||||
description=persona["description"],
|
||||
@@ -114,52 +88,20 @@ def load_personas_from_yaml(
|
||||
llm_relevance_filter=persona.get("llm_relevance_filter"),
|
||||
starter_messages=persona.get("starter_messages"),
|
||||
llm_filter_extraction=persona.get("llm_filter_extraction"),
|
||||
icon_shape=persona.get("icon_shape"),
|
||||
icon_color=persona.get("icon_color"),
|
||||
llm_model_provider_override=llm_model_provider_override,
|
||||
llm_model_version_override=llm_model_version_override,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
recency_bias=RecencyBiasSetting(persona["recency_bias"]),
|
||||
prompt_ids=prompt_ids,
|
||||
document_set_ids=doc_set_ids,
|
||||
tool_ids=tool_ids,
|
||||
default_persona=True,
|
||||
is_public=True,
|
||||
display_priority=existing_persona.display_priority
|
||||
if existing_persona is not None
|
||||
else persona.get("display_priority"),
|
||||
is_visible=existing_persona.is_visible
|
||||
if existing_persona is not None
|
||||
else persona.get("is_visible"),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
|
||||
def load_input_prompts_from_yaml(input_prompts_yaml: str = INPUT_PROMPT_YAML) -> None:
|
||||
with open(input_prompts_yaml, "r") as file:
|
||||
data = yaml.safe_load(file)
|
||||
|
||||
all_input_prompts = data.get("input_prompts", [])
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
for input_prompt in all_input_prompts:
|
||||
# If these prompts are deleted (which is a hard delete in the DB), on server startup
|
||||
# they will be recreated, but the user can always just deactivate them, just a light inconvenience
|
||||
insert_input_prompt_if_not_exists(
|
||||
user=None,
|
||||
input_prompt_id=input_prompt.get("id"),
|
||||
prompt=input_prompt["prompt"],
|
||||
content=input_prompt["content"],
|
||||
is_public=input_prompt["is_public"],
|
||||
active=input_prompt.get("active", True),
|
||||
db_session=db_session,
|
||||
commit=True,
|
||||
)
|
||||
|
||||
|
||||
def load_chat_yamls(
|
||||
prompt_yaml: str = PROMPTS_YAML,
|
||||
personas_yaml: str = PERSONAS_YAML,
|
||||
input_prompts_yaml: str = INPUT_PROMPT_YAML,
|
||||
) -> None:
|
||||
load_prompts_from_yaml(prompt_yaml)
|
||||
load_personas_from_yaml(personas_yaml)
|
||||
load_input_prompts_from_yaml(input_prompts_yaml)
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
@@ -10,7 +9,6 @@ from danswer.search.enums import QueryFlow
|
||||
from danswer.search.enums import SearchType
|
||||
from danswer.search.models import RetrievalDocs
|
||||
from danswer.search.models import SearchResponse
|
||||
from danswer.tools.custom.base_tool_types import ToolResultType
|
||||
|
||||
|
||||
class LlmDoc(BaseModel):
|
||||
@@ -36,53 +34,27 @@ class QADocsResponse(RetrievalDocs):
|
||||
applied_time_cutoff: datetime | None
|
||||
recency_bias_multiplier: float
|
||||
|
||||
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
|
||||
initial_dict = super().model_dump(mode="json", *args, **kwargs) # type: ignore
|
||||
def dict(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
|
||||
initial_dict = super().dict(*args, **kwargs) # type: ignore
|
||||
initial_dict["applied_time_cutoff"] = (
|
||||
self.applied_time_cutoff.isoformat() if self.applied_time_cutoff else None
|
||||
)
|
||||
|
||||
return initial_dict
|
||||
|
||||
|
||||
class StreamStopReason(Enum):
|
||||
CONTEXT_LENGTH = "context_length"
|
||||
CANCELLED = "cancelled"
|
||||
|
||||
|
||||
class StreamStopInfo(BaseModel):
|
||||
stop_reason: StreamStopReason
|
||||
|
||||
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
|
||||
data = super().model_dump(mode="json", *args, **kwargs) # type: ignore
|
||||
data["stop_reason"] = self.stop_reason.name
|
||||
return data
|
||||
|
||||
|
||||
class LLMRelevanceFilterResponse(BaseModel):
|
||||
llm_selected_doc_indices: list[int]
|
||||
relevant_chunk_indices: list[int]
|
||||
|
||||
|
||||
class FinalUsedContextDocsResponse(BaseModel):
|
||||
final_context_docs: list[LlmDoc]
|
||||
|
||||
|
||||
class RelevanceAnalysis(BaseModel):
|
||||
relevant: bool
|
||||
class RelevanceChunk(BaseModel):
|
||||
# TODO make this document level. Also slight misnomer here as this is actually
|
||||
# done at the section level currently rather than the chunk
|
||||
relevant: bool | None = None
|
||||
content: str | None = None
|
||||
|
||||
|
||||
class SectionRelevancePiece(RelevanceAnalysis):
|
||||
"""LLM analysis mapped to an Inference Section"""
|
||||
|
||||
document_id: str
|
||||
chunk_id: int # ID of the center chunk for a given inference section
|
||||
|
||||
|
||||
class DocumentRelevance(BaseModel):
|
||||
"""Contains all relevance information for a given search"""
|
||||
|
||||
relevance_summaries: dict[str, RelevanceAnalysis]
|
||||
class LLMRelevanceSummaryResponse(BaseModel):
|
||||
relevance_summaries: dict[str, RelevanceChunk]
|
||||
|
||||
|
||||
class DanswerAnswerPiece(BaseModel):
|
||||
@@ -97,24 +69,8 @@ class CitationInfo(BaseModel):
|
||||
document_id: str
|
||||
|
||||
|
||||
class AllCitations(BaseModel):
|
||||
citations: list[CitationInfo]
|
||||
|
||||
|
||||
# This is a mapping of the citation number to the document index within
|
||||
# the result search doc set
|
||||
class MessageSpecificCitations(BaseModel):
|
||||
citation_map: dict[int, int]
|
||||
|
||||
|
||||
class MessageResponseIDInfo(BaseModel):
|
||||
user_message_id: int | None
|
||||
reserved_assistant_message_id: int
|
||||
|
||||
|
||||
class StreamingError(BaseModel):
|
||||
error: str
|
||||
stack_trace: str | None = None
|
||||
|
||||
|
||||
class DanswerQuote(BaseModel):
|
||||
@@ -152,7 +108,7 @@ class QAResponse(SearchResponse, DanswerAnswer):
|
||||
predicted_flow: QueryFlow
|
||||
predicted_search: SearchType
|
||||
eval_res_valid: bool | None = None
|
||||
llm_selected_doc_indices: list[int] | None = None
|
||||
llm_chunks_indices: list[int] | None = None
|
||||
error_msg: str | None = None
|
||||
|
||||
|
||||
@@ -161,7 +117,7 @@ class ImageGenerationDisplay(BaseModel):
|
||||
|
||||
|
||||
class CustomToolResponse(BaseModel):
|
||||
response: ToolResultType
|
||||
response: dict
|
||||
tool_name: str
|
||||
|
||||
|
||||
@@ -173,7 +129,6 @@ AnswerQuestionPossibleReturn = (
|
||||
| ImageGenerationDisplay
|
||||
| CustomToolResponse
|
||||
| StreamingError
|
||||
| StreamStopInfo
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ personas:
|
||||
# this is for DanswerBot to use when tagged in a non-configured channel
|
||||
# Careful setting specific IDs, this won't autoincrement the next ID value for postgres
|
||||
- id: 0
|
||||
name: "Knowledge"
|
||||
name: "Danswer"
|
||||
description: >
|
||||
Assistant with access to documents from your Connected Sources.
|
||||
# Default Prompt objects attached to the persona, see prompts.yaml
|
||||
@@ -17,7 +17,7 @@ personas:
|
||||
num_chunks: 10
|
||||
# Enable/Disable usage of the LLM chunk filter feature whereby each chunk is passed to the LLM to determine
|
||||
# if the chunk is useful or not towards the latest user query
|
||||
# This feature can be overriden for all personas via DISABLE_LLM_DOC_RELEVANCE env variable
|
||||
# This feature can be overriden for all personas via DISABLE_LLM_CHUNK_FILTER env variable
|
||||
llm_relevance_filter: true
|
||||
# Enable/Disable usage of the LLM to extract query time filters including source type and time range filters
|
||||
llm_filter_extraction: true
|
||||
@@ -37,15 +37,12 @@ personas:
|
||||
# - "Engineer Onboarding"
|
||||
# - "Benefits"
|
||||
document_sets: []
|
||||
icon_shape: 23013
|
||||
icon_color: "#6FB1FF"
|
||||
display_priority: 1
|
||||
is_visible: true
|
||||
|
||||
|
||||
- id: 1
|
||||
name: "General"
|
||||
name: "GPT"
|
||||
description: >
|
||||
Assistant with no access to documents. Chat with just the Large Language Model.
|
||||
Assistant with no access to documents. Chat with just the Language Model.
|
||||
prompts:
|
||||
- "OnlyLLM"
|
||||
num_chunks: 0
|
||||
@@ -53,10 +50,7 @@ personas:
|
||||
llm_filter_extraction: true
|
||||
recency_bias: "auto"
|
||||
document_sets: []
|
||||
icon_shape: 50910
|
||||
icon_color: "#FF6F6F"
|
||||
display_priority: 0
|
||||
is_visible: true
|
||||
|
||||
|
||||
- id: 2
|
||||
name: "Paraphrase"
|
||||
@@ -69,25 +63,3 @@ personas:
|
||||
llm_filter_extraction: true
|
||||
recency_bias: "auto"
|
||||
document_sets: []
|
||||
icon_shape: 45519
|
||||
icon_color: "#6FFF8D"
|
||||
display_priority: 2
|
||||
is_visible: false
|
||||
|
||||
|
||||
- id: 3
|
||||
name: "Art"
|
||||
description: >
|
||||
Assistant for generating images based on descriptions.
|
||||
prompts:
|
||||
- "ImageGeneration"
|
||||
num_chunks: 0
|
||||
llm_relevance_filter: false
|
||||
llm_filter_extraction: false
|
||||
recency_bias: "no_decay"
|
||||
document_sets: []
|
||||
icon_shape: 234124
|
||||
icon_color: "#9B59B6"
|
||||
image_generation: true
|
||||
display_priority: 3
|
||||
is_visible: true
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import traceback
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from functools import partial
|
||||
@@ -7,15 +6,11 @@ from typing import cast
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.chat.chat_utils import create_chat_chain
|
||||
from danswer.chat.models import AllCitations
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import CustomToolResponse
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import FinalUsedContextDocsResponse
|
||||
from danswer.chat.models import ImageGenerationDisplay
|
||||
from danswer.chat.models import LLMRelevanceFilterResponse
|
||||
from danswer.chat.models import MessageResponseIDInfo
|
||||
from danswer.chat.models import MessageSpecificCitations
|
||||
from danswer.chat.models import QADocsResponse
|
||||
from danswer.chat.models import StreamingError
|
||||
from danswer.configs.chat_configs import BING_API_KEY
|
||||
@@ -32,16 +27,15 @@ from danswer.db.chat import get_chat_session_by_id
|
||||
from danswer.db.chat import get_db_search_doc_by_id
|
||||
from danswer.db.chat import get_doc_query_identifiers_from_model
|
||||
from danswer.db.chat import get_or_create_root_message
|
||||
from danswer.db.chat import reserve_message_id
|
||||
from danswer.db.chat import translate_db_message_to_chat_message_detail
|
||||
from danswer.db.chat import translate_db_search_doc_to_server_search_doc
|
||||
from danswer.db.embedding_model import get_current_db_embedding_model
|
||||
from danswer.db.engine import get_session_context_manager
|
||||
from danswer.db.llm import fetch_existing_llm_providers
|
||||
from danswer.db.models import SearchDoc as DbSearchDoc
|
||||
from danswer.db.models import ToolCall
|
||||
from danswer.db.models import User
|
||||
from danswer.db.persona import get_persona_by_id
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.file_store.models import ChatFileType
|
||||
from danswer.file_store.models import FileDescriptor
|
||||
@@ -57,9 +51,7 @@ from danswer.llm.exceptions import GenAIDisabledException
|
||||
from danswer.llm.factory import get_llms_for_persona
|
||||
from danswer.llm.factory import get_main_llm_from_tuple
|
||||
from danswer.llm.interfaces import LLMConfig
|
||||
from danswer.llm.utils import litellm_exception_to_error_msg
|
||||
from danswer.natural_language_processing.utils import get_tokenizer
|
||||
from danswer.search.enums import LLMEvaluationType
|
||||
from danswer.llm.utils import get_default_llm_tokenizer
|
||||
from danswer.search.enums import OptionalSearchSetting
|
||||
from danswer.search.enums import QueryFlow
|
||||
from danswer.search.enums import SearchType
|
||||
@@ -68,7 +60,6 @@ from danswer.search.retrieval.search_runner import inference_sections_from_ids
|
||||
from danswer.search.utils import chunks_or_sections_to_search_docs
|
||||
from danswer.search.utils import dedupe_documents
|
||||
from danswer.search.utils import drop_llm_indices
|
||||
from danswer.search.utils import relevant_sections_to_indices
|
||||
from danswer.server.query_and_chat.models import ChatMessageDetail
|
||||
from danswer.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from danswer.server.utils import get_json_line
|
||||
@@ -88,8 +79,6 @@ from danswer.tools.internet_search.internet_search_tool import (
|
||||
)
|
||||
from danswer.tools.internet_search.internet_search_tool import InternetSearchResponse
|
||||
from danswer.tools.internet_search.internet_search_tool import InternetSearchTool
|
||||
from danswer.tools.models import DynamicSchemaInfo
|
||||
from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS_ID
|
||||
from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID
|
||||
from danswer.tools.search.search_tool import SearchResponseSummary
|
||||
from danswer.tools.search.search_tool import SearchTool
|
||||
@@ -105,9 +94,9 @@ from danswer.utils.timing import log_generator_function_time
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _translate_citations(
|
||||
def translate_citations(
|
||||
citations_list: list[CitationInfo], db_docs: list[DbSearchDoc]
|
||||
) -> MessageSpecificCitations:
|
||||
) -> dict[int, int]:
|
||||
"""Always cites the first instance of the document_id, assumes the db_docs
|
||||
are sorted in the order displayed in the UI"""
|
||||
doc_id_to_saved_doc_id_map: dict[str, int] = {}
|
||||
@@ -122,7 +111,7 @@ def _translate_citations(
|
||||
citation.citation_num
|
||||
] = doc_id_to_saved_doc_id_map[citation.document_id]
|
||||
|
||||
return MessageSpecificCitations(citation_map=citation_to_saved_doc_id_map)
|
||||
return citation_to_saved_doc_id_map
|
||||
|
||||
|
||||
def _handle_search_tool_response_summary(
|
||||
@@ -189,7 +178,7 @@ def _handle_internet_search_tool_response_summary(
|
||||
rephrased_query=internet_search_response.revised_query,
|
||||
top_documents=response_docs,
|
||||
predicted_flow=QueryFlow.QUESTION_ANSWER,
|
||||
predicted_search=SearchType.SEMANTIC,
|
||||
predicted_search=SearchType.HYBRID,
|
||||
applied_source_filters=[],
|
||||
applied_time_cutoff=None,
|
||||
recency_bias_multiplier=1.0,
|
||||
@@ -198,61 +187,48 @@ def _handle_internet_search_tool_response_summary(
|
||||
)
|
||||
|
||||
|
||||
def _get_force_search_settings(
|
||||
new_msg_req: CreateChatMessageRequest, tools: list[Tool]
|
||||
) -> ForceUseTool:
|
||||
internet_search_available = any(
|
||||
isinstance(tool, InternetSearchTool) for tool in tools
|
||||
)
|
||||
search_tool_available = any(isinstance(tool, SearchTool) for tool in tools)
|
||||
|
||||
if not internet_search_available and not search_tool_available:
|
||||
# Does not matter much which tool is set here as force is false and neither tool is available
|
||||
return ForceUseTool(force_use=False, tool_name=SearchTool._NAME)
|
||||
|
||||
tool_name = SearchTool._NAME if search_tool_available else InternetSearchTool._NAME
|
||||
# Currently, the internet search tool does not support query override
|
||||
args = (
|
||||
{"query": new_msg_req.query_override}
|
||||
if new_msg_req.query_override and tool_name == SearchTool._NAME
|
||||
else None
|
||||
)
|
||||
|
||||
def _check_should_force_search(
|
||||
new_msg_req: CreateChatMessageRequest,
|
||||
) -> ForceUseTool | None:
|
||||
# If files are already provided, don't run the search tool
|
||||
if new_msg_req.file_descriptors:
|
||||
# If user has uploaded files they're using, don't run any of the search tools
|
||||
return ForceUseTool(force_use=False, tool_name=tool_name)
|
||||
return None
|
||||
|
||||
should_force_search = any(
|
||||
[
|
||||
if (
|
||||
new_msg_req.query_override
|
||||
or (
|
||||
new_msg_req.retrieval_options
|
||||
and new_msg_req.retrieval_options.run_search
|
||||
== OptionalSearchSetting.ALWAYS,
|
||||
new_msg_req.search_doc_ids,
|
||||
DISABLE_LLM_CHOOSE_SEARCH,
|
||||
]
|
||||
)
|
||||
and new_msg_req.retrieval_options.run_search == OptionalSearchSetting.ALWAYS
|
||||
)
|
||||
or new_msg_req.search_doc_ids
|
||||
or DISABLE_LLM_CHOOSE_SEARCH
|
||||
):
|
||||
args = (
|
||||
{"query": new_msg_req.query_override}
|
||||
if new_msg_req.query_override
|
||||
else None
|
||||
)
|
||||
# if we are using selected docs, just put something here so the Tool doesn't need
|
||||
# to build its own args via an LLM call
|
||||
if new_msg_req.search_doc_ids:
|
||||
args = {"query": new_msg_req.message}
|
||||
|
||||
if should_force_search:
|
||||
# If we are using selected docs, just put something here so the Tool doesn't need to build its own args via an LLM call
|
||||
args = {"query": new_msg_req.message} if new_msg_req.search_doc_ids else args
|
||||
return ForceUseTool(force_use=True, tool_name=tool_name, args=args)
|
||||
|
||||
return ForceUseTool(force_use=False, tool_name=tool_name, args=args)
|
||||
return ForceUseTool(
|
||||
tool_name=SearchTool._NAME,
|
||||
args=args,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
ChatPacket = (
|
||||
StreamingError
|
||||
| QADocsResponse
|
||||
| LLMRelevanceFilterResponse
|
||||
| FinalUsedContextDocsResponse
|
||||
| ChatMessageDetail
|
||||
| DanswerAnswerPiece
|
||||
| AllCitations
|
||||
| CitationInfo
|
||||
| ImageGenerationDisplay
|
||||
| CustomToolResponse
|
||||
| MessageSpecificCitations
|
||||
| MessageResponseIDInfo
|
||||
)
|
||||
ChatPacketStream = Iterator[ChatPacket]
|
||||
|
||||
@@ -268,21 +244,17 @@ def stream_chat_message_objects(
|
||||
max_document_percentage: float = CHAT_TARGET_CHUNK_PERCENTAGE,
|
||||
# if specified, uses the last user message and does not create a new user message based
|
||||
# on the `new_msg_req.message`. Currently, requires a state where the last message is a
|
||||
# user message (e.g. this can only be used for the chat-seeding flow).
|
||||
use_existing_user_message: bool = False,
|
||||
litellm_additional_headers: dict[str, str] | None = None,
|
||||
is_connected: Callable[[], bool] | None = None,
|
||||
) -> ChatPacketStream:
|
||||
"""Streams in order:
|
||||
1. [conditional] Retrieved documents if a search needs to be run
|
||||
2. [conditional] LLM selected chunk indices if LLM chunk filtering is turned on
|
||||
3. [always] A set of streamed LLM tokens or an error anywhere along the line if something fails
|
||||
4. [always] Details on the final AI response message that is created
|
||||
"""
|
||||
# Currently surrounding context is not supported for chat
|
||||
# Chat is already token heavy and harder for the model to process plus it would roll history over much faster
|
||||
new_msg_req.chunks_above = 0
|
||||
new_msg_req.chunks_below = 0
|
||||
|
||||
"""
|
||||
try:
|
||||
user_id = user.id if user is not None else None
|
||||
|
||||
@@ -302,10 +274,7 @@ def stream_chat_message_objects(
|
||||
# use alternate persona if alternative assistant id is passed in
|
||||
if alternate_assistant_id is not None:
|
||||
persona = get_persona_by_id(
|
||||
alternate_assistant_id,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
is_for_edit=False,
|
||||
alternate_assistant_id, user=user, db_session=db_session
|
||||
)
|
||||
else:
|
||||
persona = chat_session.persona
|
||||
@@ -328,20 +297,14 @@ def stream_chat_message_objects(
|
||||
except GenAIDisabledException:
|
||||
raise RuntimeError("LLM is disabled. Can't use chat flow without LLM.")
|
||||
|
||||
llm_provider = llm.config.model_provider
|
||||
llm_model_name = llm.config.model_name
|
||||
|
||||
llm_tokenizer = get_tokenizer(
|
||||
model_name=llm_model_name,
|
||||
provider_type=llm_provider,
|
||||
)
|
||||
llm_tokenizer = get_default_llm_tokenizer()
|
||||
llm_tokenizer_encode_func = cast(
|
||||
Callable[[str], list[int]], llm_tokenizer.encode
|
||||
)
|
||||
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
embedding_model = get_current_db_embedding_model(db_session)
|
||||
document_index = get_default_document_index(
|
||||
primary_index_name=search_settings.index_name, secondary_index_name=None
|
||||
primary_index_name=embedding_model.index_name, secondary_index_name=None
|
||||
)
|
||||
|
||||
# Every chat Session begins with an empty root message
|
||||
@@ -359,15 +322,7 @@ def stream_chat_message_objects(
|
||||
parent_message = root_message
|
||||
|
||||
user_message = None
|
||||
|
||||
if new_msg_req.regenerate:
|
||||
final_msg, history_msgs = create_chat_chain(
|
||||
stop_at_message_id=parent_id,
|
||||
chat_session_id=chat_session_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
elif not use_existing_user_message:
|
||||
if not use_existing_user_message:
|
||||
# Create new message at the right place in the tree and update the parent's child pointer
|
||||
# Don't commit yet until we verify the chat message chain
|
||||
user_message = create_new_chat_message(
|
||||
@@ -406,14 +361,6 @@ def stream_chat_message_objects(
|
||||
"when the last message is not a user message."
|
||||
)
|
||||
|
||||
# Disable Query Rephrasing for the first message
|
||||
# This leads to a better first response since the LLM rephrasing the question
|
||||
# leads to worst search quality
|
||||
if not history_msgs:
|
||||
new_msg_req.query_override = (
|
||||
new_msg_req.query_override or new_msg_req.message
|
||||
)
|
||||
|
||||
# load all files needed for this chat chain in memory
|
||||
files = load_all_chat_files(
|
||||
history_msgs, new_msg_req.file_descriptors, db_session
|
||||
@@ -473,23 +420,9 @@ def stream_chat_message_objects(
|
||||
else default_num_chunks
|
||||
),
|
||||
max_window_percentage=max_document_percentage,
|
||||
use_sections=new_msg_req.chunks_above > 0
|
||||
or new_msg_req.chunks_below > 0,
|
||||
)
|
||||
reserved_message_id = reserve_message_id(
|
||||
db_session=db_session,
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message=user_message.id
|
||||
if user_message is not None
|
||||
else parent_message.id,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
)
|
||||
yield MessageResponseIDInfo(
|
||||
user_message_id=user_message.id if user_message else None,
|
||||
reserved_assistant_message_id=reserved_message_id,
|
||||
)
|
||||
|
||||
overridden_model = (
|
||||
new_msg_req.llm_override.model_version if new_msg_req.llm_override else None
|
||||
)
|
||||
|
||||
# Cannot determine these without the LLM step or breaking out early
|
||||
partial_response = partial(
|
||||
@@ -497,7 +430,6 @@ def stream_chat_message_objects(
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message=final_msg,
|
||||
prompt_id=prompt_id,
|
||||
overridden_model=overridden_model,
|
||||
# message=,
|
||||
# rephrased_query=,
|
||||
# token_count=,
|
||||
@@ -544,9 +476,6 @@ def stream_chat_message_objects(
|
||||
chunks_above=new_msg_req.chunks_above,
|
||||
chunks_below=new_msg_req.chunks_below,
|
||||
full_doc=new_msg_req.full_doc,
|
||||
evaluation_type=LLMEvaluationType.BASIC
|
||||
if persona.llm_relevance_filter
|
||||
else LLMEvaluationType.SKIP,
|
||||
)
|
||||
tool_dict[db_tool_model.id] = [search_tool]
|
||||
elif tool_cls.__name__ == ImageGenerationTool.__name__:
|
||||
@@ -606,11 +535,7 @@ def stream_chat_message_objects(
|
||||
tool_dict[db_tool_model.id] = cast(
|
||||
list[Tool],
|
||||
build_custom_tools_from_openapi_schema(
|
||||
db_tool_model.openapi_schema,
|
||||
dynamic_schema_info=DynamicSchemaInfo(
|
||||
chat_session_id=chat_session_id,
|
||||
message_id=user_message.id if user_message else None,
|
||||
),
|
||||
db_tool_model.openapi_schema
|
||||
),
|
||||
)
|
||||
|
||||
@@ -619,16 +544,13 @@ def stream_chat_message_objects(
|
||||
tools.extend(tool_list)
|
||||
|
||||
# factor in tool definition size when pruning
|
||||
document_pruning_config.tool_num_tokens = compute_all_tool_tokens(
|
||||
tools, llm_tokenizer
|
||||
)
|
||||
document_pruning_config.tool_num_tokens = compute_all_tool_tokens(tools)
|
||||
document_pruning_config.using_tool_message = explicit_tool_calling_supported(
|
||||
llm_provider, llm_model_name
|
||||
llm.config.model_provider, llm.config.model_name
|
||||
)
|
||||
|
||||
# LLM prompt building, response capturing, etc.
|
||||
answer = Answer(
|
||||
is_connected=is_connected,
|
||||
question=final_msg.message,
|
||||
latest_query_files=latest_query_files,
|
||||
answer_style_config=AnswerStyleConfig(
|
||||
@@ -654,7 +576,11 @@ def stream_chat_message_objects(
|
||||
PreviousMessage.from_chat_message(msg, files) for msg in history_msgs
|
||||
],
|
||||
tools=tools,
|
||||
force_use_tool=_get_force_search_settings(new_msg_req, tools),
|
||||
force_use_tool=(
|
||||
_check_should_force_search(new_msg_req)
|
||||
if search_tool and len(tools) == 1
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
reference_db_search_docs = None
|
||||
@@ -662,7 +588,6 @@ def stream_chat_message_objects(
|
||||
ai_message_files = None # any files to associate with the AI message e.g. dall-e generated images
|
||||
dropped_indices = None
|
||||
tool_result = None
|
||||
|
||||
for packet in answer.processed_streamed_output:
|
||||
if isinstance(packet, ToolResponse):
|
||||
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
|
||||
@@ -681,31 +606,17 @@ def stream_chat_message_objects(
|
||||
)
|
||||
yield qa_docs_response
|
||||
elif packet.id == SECTION_RELEVANCE_LIST_ID:
|
||||
relevance_sections = packet.response
|
||||
chunk_indices = packet.response
|
||||
|
||||
if reference_db_search_docs is not None:
|
||||
llm_indices = relevant_sections_to_indices(
|
||||
relevance_sections=relevance_sections,
|
||||
items=[
|
||||
translate_db_search_doc_to_server_search_doc(doc)
|
||||
for doc in reference_db_search_docs
|
||||
],
|
||||
if reference_db_search_docs is not None and dropped_indices:
|
||||
chunk_indices = drop_llm_indices(
|
||||
llm_indices=chunk_indices,
|
||||
search_docs=reference_db_search_docs,
|
||||
dropped_indices=dropped_indices,
|
||||
)
|
||||
|
||||
if dropped_indices:
|
||||
llm_indices = drop_llm_indices(
|
||||
llm_indices=llm_indices,
|
||||
search_docs=reference_db_search_docs,
|
||||
dropped_indices=dropped_indices,
|
||||
)
|
||||
|
||||
yield LLMRelevanceFilterResponse(
|
||||
llm_selected_doc_indices=llm_indices
|
||||
)
|
||||
|
||||
elif packet.id == FINAL_CONTEXT_DOCUMENTS_ID:
|
||||
yield FinalUsedContextDocsResponse(
|
||||
final_context_docs=packet.response
|
||||
yield LLMRelevanceFilterResponse(
|
||||
relevant_chunk_indices=chunk_indices
|
||||
)
|
||||
elif packet.id == IMAGE_GENERATION_RESPONSE_ID:
|
||||
img_generation_response = cast(
|
||||
@@ -742,30 +653,31 @@ def stream_chat_message_objects(
|
||||
if isinstance(packet, ToolCallFinalResult):
|
||||
tool_result = packet
|
||||
yield cast(ChatPacket, packet)
|
||||
logger.debug("Reached end of stream")
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to process chat message")
|
||||
|
||||
# Don't leak the API key
|
||||
error_msg = str(e)
|
||||
logger.exception(f"Failed to process chat message: {error_msg}")
|
||||
if llm.config.api_key and llm.config.api_key.lower() in error_msg.lower():
|
||||
error_msg = (
|
||||
f"LLM failed to respond. Invalid API "
|
||||
f"key error from '{llm.config.model_provider}'."
|
||||
)
|
||||
|
||||
stack_trace = traceback.format_exc()
|
||||
client_error_msg = litellm_exception_to_error_msg(e, llm)
|
||||
if llm.config.api_key and len(llm.config.api_key) > 2:
|
||||
error_msg = error_msg.replace(llm.config.api_key, "[REDACTED_API_KEY]")
|
||||
stack_trace = stack_trace.replace(llm.config.api_key, "[REDACTED_API_KEY]")
|
||||
|
||||
yield StreamingError(error=client_error_msg, stack_trace=stack_trace)
|
||||
yield StreamingError(error=error_msg)
|
||||
# Cancel the transaction so that no messages are saved
|
||||
db_session.rollback()
|
||||
return
|
||||
|
||||
# Post-LLM answer processing
|
||||
try:
|
||||
message_specific_citations: MessageSpecificCitations | None = None
|
||||
db_citations = None
|
||||
if reference_db_search_docs:
|
||||
message_specific_citations = _translate_citations(
|
||||
db_citations = translate_citations(
|
||||
citations_list=answer.citations,
|
||||
db_docs=reference_db_search_docs,
|
||||
)
|
||||
yield AllCitations(citations=answer.citations)
|
||||
|
||||
# Saving Gen AI answer and responding with message info
|
||||
tool_name_to_tool_id: dict[str, int] = {}
|
||||
@@ -774,7 +686,6 @@ def stream_chat_message_objects(
|
||||
tool_name_to_tool_id[tool.name] = tool_id
|
||||
|
||||
gen_ai_response_message = partial_response(
|
||||
reserved_message_id=reserved_message_id,
|
||||
message=answer.llm_answer,
|
||||
rephrased_query=(
|
||||
qa_docs_response.rephrased_query if qa_docs_response else None
|
||||
@@ -782,9 +693,7 @@ def stream_chat_message_objects(
|
||||
reference_docs=reference_db_search_docs,
|
||||
files=ai_message_files,
|
||||
token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
|
||||
citations=message_specific_citations.citation_map
|
||||
if message_specific_citations
|
||||
else None,
|
||||
citations=db_citations,
|
||||
error=None,
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
@@ -797,8 +706,6 @@ def stream_chat_message_objects(
|
||||
if tool_result
|
||||
else [],
|
||||
)
|
||||
|
||||
logger.debug("Committing messages")
|
||||
db_session.commit() # actually save user / assistant message
|
||||
|
||||
msg_detail_response = translate_db_message_to_chat_message_detail(
|
||||
@@ -807,8 +714,7 @@ def stream_chat_message_objects(
|
||||
|
||||
yield msg_detail_response
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
logger.exception(error_msg)
|
||||
logger.exception(e)
|
||||
|
||||
# Frontend will erase whatever answer and show this instead
|
||||
yield StreamingError(error="Failed to parse LLM output")
|
||||
@@ -820,7 +726,6 @@ def stream_chat_message(
|
||||
user: User | None,
|
||||
use_existing_user_message: bool = False,
|
||||
litellm_additional_headers: dict[str, str] | None = None,
|
||||
is_connected: Callable[[], bool] | None = None,
|
||||
) -> Iterator[str]:
|
||||
with get_session_context_manager() as db_session:
|
||||
objects = stream_chat_message_objects(
|
||||
@@ -829,7 +734,6 @@ def stream_chat_message(
|
||||
db_session=db_session,
|
||||
use_existing_user_message=use_existing_user_message,
|
||||
litellm_additional_headers=litellm_additional_headers,
|
||||
is_connected=is_connected,
|
||||
)
|
||||
for obj in objects:
|
||||
yield get_json_line(obj.model_dump())
|
||||
yield get_json_line(obj.dict())
|
||||
|
||||
@@ -30,23 +30,7 @@ prompts:
|
||||
# Prompts the LLM to include citations in the for [1], [2] etc.
|
||||
# which get parsed to match the passed in sources
|
||||
include_citations: true
|
||||
|
||||
- name: "ImageGeneration"
|
||||
description: "Generates images based on user prompts!"
|
||||
system: >
|
||||
You are an advanced image generation system capable of creating diverse and detailed images.
|
||||
|
||||
You can interpret user prompts and generate high-quality, creative images that match their descriptions.
|
||||
|
||||
You always strive to create safe and appropriate content, avoiding any harmful or offensive imagery.
|
||||
task: >
|
||||
Generate an image based on the user's description.
|
||||
|
||||
Provide a detailed description of the generated image, including key elements, colors, and composition.
|
||||
|
||||
If the request is not possible or appropriate, explain why and suggest alternatives.
|
||||
datetime_aware: true
|
||||
include_citations: false
|
||||
|
||||
|
||||
- name: "OnlyLLM"
|
||||
description: "Chat directly with the LLM!"
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing_extensions import TypedDict # noreorder
|
||||
from typing import TypedDict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@@ -93,14 +93,6 @@ SMTP_USER = os.environ.get("SMTP_USER", "your-email@gmail.com")
|
||||
SMTP_PASS = os.environ.get("SMTP_PASS", "your-gmail-password")
|
||||
EMAIL_FROM = os.environ.get("EMAIL_FROM") or SMTP_USER
|
||||
|
||||
# If set, Danswer will listen to the `expires_at` returned by the identity
|
||||
# provider (e.g. Okta, Google, etc.) and force the user to re-authenticate
|
||||
# after this time has elapsed. Disabled since by default many auth providers
|
||||
# have very short expiry times (e.g. 1 hour) which provide a poor user experience
|
||||
TRACK_EXTERNAL_IDP_EXPIRY = (
|
||||
os.environ.get("TRACK_EXTERNAL_IDP_EXPIRY", "").lower() == "true"
|
||||
)
|
||||
|
||||
|
||||
#####
|
||||
# DB Configs
|
||||
@@ -126,7 +118,6 @@ try:
|
||||
except ValueError:
|
||||
INDEX_BATCH_SIZE = 16
|
||||
|
||||
|
||||
# Below are intended to match the env variables names used by the official postgres docker image
|
||||
# https://hub.docker.com/_/postgres
|
||||
POSTGRES_USER = os.environ.get("POSTGRES_USER") or "postgres"
|
||||
@@ -138,31 +129,6 @@ POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost"
|
||||
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432"
|
||||
POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"
|
||||
|
||||
# defaults to False
|
||||
POSTGRES_POOL_PRE_PING = os.environ.get("POSTGRES_POOL_PRE_PING", "").lower() == "true"
|
||||
|
||||
# recycle timeout in seconds
|
||||
POSTGRES_POOL_RECYCLE_DEFAULT = 60 * 20 # 20 minutes
|
||||
try:
|
||||
POSTGRES_POOL_RECYCLE = int(
|
||||
os.environ.get("POSTGRES_POOL_RECYCLE", POSTGRES_POOL_RECYCLE_DEFAULT)
|
||||
)
|
||||
except ValueError:
|
||||
POSTGRES_POOL_RECYCLE = POSTGRES_POOL_RECYCLE_DEFAULT
|
||||
|
||||
REDIS_SSL = os.getenv("REDIS_SSL", "").lower() == "true"
|
||||
REDIS_HOST = os.environ.get("REDIS_HOST") or "localhost"
|
||||
REDIS_PORT = int(os.environ.get("REDIS_PORT", 6379))
|
||||
REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD") or ""
|
||||
|
||||
# Used for general redis things
|
||||
REDIS_DB_NUMBER = int(os.environ.get("REDIS_DB_NUMBER", 0))
|
||||
|
||||
# Used by celery as broker and backend
|
||||
REDIS_DB_NUMBER_CELERY = int(os.environ.get("REDIS_DB_NUMBER_CELERY", 15))
|
||||
|
||||
REDIS_SSL_CERT_REQS = os.getenv("REDIS_SSL_CERT_REQS", "CERT_NONE")
|
||||
REDIS_SSL_CA_CERTS = os.getenv("REDIS_SSL_CA_CERTS", "")
|
||||
|
||||
#####
|
||||
# Connector Configs
|
||||
@@ -215,8 +181,8 @@ CONFLUENCE_CONNECTOR_LABELS_TO_SKIP = [
|
||||
]
|
||||
|
||||
# Avoid to get archived pages
|
||||
CONFLUENCE_CONNECTOR_INDEX_ARCHIVED_PAGES = (
|
||||
os.environ.get("CONFLUENCE_CONNECTOR_INDEX_ARCHIVED_PAGES", "").lower() == "true"
|
||||
CONFLUENCE_CONNECTOR_INDEX_ONLY_ACTIVE_PAGES = (
|
||||
os.environ.get("CONFLUENCE_CONNECTOR_INDEX_ONLY_ACTIVE_PAGES", "").lower() == "true"
|
||||
)
|
||||
|
||||
# Save pages labels as Danswer metadata tags
|
||||
@@ -225,16 +191,6 @@ CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING = (
|
||||
os.environ.get("CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING", "").lower() == "true"
|
||||
)
|
||||
|
||||
# Attachments exceeding this size will not be retrieved (in bytes)
|
||||
CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD = int(
|
||||
os.environ.get("CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD", 10 * 1024 * 1024)
|
||||
)
|
||||
# Attachments with more chars than this will not be indexed. This is to prevent extremely
|
||||
# large files from freezing indexing. 200,000 is ~100 google doc pages.
|
||||
CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD = int(
|
||||
os.environ.get("CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD", 200_000)
|
||||
)
|
||||
|
||||
JIRA_CONNECTOR_LABELS_TO_SKIP = [
|
||||
ignored_tag
|
||||
for ignored_tag in os.environ.get("JIRA_CONNECTOR_LABELS_TO_SKIP", "").split(",")
|
||||
@@ -256,11 +212,10 @@ EXPERIMENTAL_CHECKPOINTING_ENABLED = (
|
||||
os.environ.get("EXPERIMENTAL_CHECKPOINTING_ENABLED", "").lower() == "true"
|
||||
)
|
||||
|
||||
PRUNING_DISABLED = -1
|
||||
DEFAULT_PRUNING_FREQ = 60 * 60 * 24 # Once a day
|
||||
|
||||
ALLOW_SIMULTANEOUS_PRUNING = (
|
||||
os.environ.get("ALLOW_SIMULTANEOUS_PRUNING", "").lower() == "true"
|
||||
PREVENT_SIMULTANEOUS_PRUNING = (
|
||||
os.environ.get("PREVENT_SIMULTANEOUS_PRUNING", "").lower() == "true"
|
||||
)
|
||||
|
||||
# This is the maxiumum rate at which documents are queried for a pruning job. 0 disables the limitation.
|
||||
@@ -293,39 +248,18 @@ DISABLE_INDEX_UPDATE_ON_SWAP = (
|
||||
# fairly large amount of memory in order to increase substantially, since
|
||||
# each worker loads the embedding models into memory.
|
||||
NUM_INDEXING_WORKERS = int(os.environ.get("NUM_INDEXING_WORKERS") or 1)
|
||||
NUM_SECONDARY_INDEXING_WORKERS = int(
|
||||
os.environ.get("NUM_SECONDARY_INDEXING_WORKERS") or NUM_INDEXING_WORKERS
|
||||
)
|
||||
# More accurate results at the expense of indexing speed and index size (stores additional 4 MINI_CHUNK vectors)
|
||||
ENABLE_MULTIPASS_INDEXING = (
|
||||
os.environ.get("ENABLE_MULTIPASS_INDEXING", "").lower() == "true"
|
||||
)
|
||||
ENABLE_MINI_CHUNK = os.environ.get("ENABLE_MINI_CHUNK", "").lower() == "true"
|
||||
# Finer grained chunking for more detail retention
|
||||
# Slightly larger since the sentence aware split is a max cutoff so most minichunks will be under MINI_CHUNK_SIZE
|
||||
# tokens. But we need it to be at least as big as 1/4th chunk size to avoid having a tiny mini-chunk at the end
|
||||
MINI_CHUNK_SIZE = 150
|
||||
|
||||
# This is the number of regular chunks per large chunk
|
||||
LARGE_CHUNK_RATIO = 4
|
||||
|
||||
# Include the document level metadata in each chunk. If the metadata is too long, then it is thrown out
|
||||
# We don't want the metadata to overwhelm the actual contents of the chunk
|
||||
SKIP_METADATA_IN_CHUNK = os.environ.get("SKIP_METADATA_IN_CHUNK", "").lower() == "true"
|
||||
# Timeout to wait for job's last update before killing it, in hours
|
||||
CLEANUP_INDEXING_JOBS_TIMEOUT = int(os.environ.get("CLEANUP_INDEXING_JOBS_TIMEOUT", 3))
|
||||
|
||||
# The indexer will warn in the logs whenver a document exceeds this threshold (in bytes)
|
||||
INDEXING_SIZE_WARNING_THRESHOLD = int(
|
||||
os.environ.get("INDEXING_SIZE_WARNING_THRESHOLD", 100 * 1024 * 1024)
|
||||
)
|
||||
|
||||
# during indexing, will log verbose memory diff stats every x batches and at the end.
|
||||
# 0 disables this behavior and is the default.
|
||||
INDEXING_TRACER_INTERVAL = int(os.environ.get("INDEXING_TRACER_INTERVAL", 0))
|
||||
|
||||
# During an indexing attempt, specifies the number of batches which are allowed to
|
||||
# exception without aborting the attempt.
|
||||
INDEXING_EXCEPTION_LIMIT = int(os.environ.get("INDEXING_EXCEPTION_LIMIT", 0))
|
||||
|
||||
#####
|
||||
# Miscellaneous
|
||||
@@ -353,10 +287,6 @@ LOG_VESPA_TIMING_INFORMATION = (
|
||||
os.environ.get("LOG_VESPA_TIMING_INFORMATION", "").lower() == "true"
|
||||
)
|
||||
LOG_ENDPOINT_LATENCY = os.environ.get("LOG_ENDPOINT_LATENCY", "").lower() == "true"
|
||||
LOG_POSTGRES_LATENCY = os.environ.get("LOG_POSTGRES_LATENCY", "").lower() == "true"
|
||||
LOG_POSTGRES_CONN_COUNTS = (
|
||||
os.environ.get("LOG_POSTGRES_CONN_COUNTS", "").lower() == "true"
|
||||
)
|
||||
# Anonymous usage telemetry
|
||||
DISABLE_TELEMETRY = os.environ.get("DISABLE_TELEMETRY", "").lower() == "true"
|
||||
|
||||
|
||||
@@ -3,13 +3,12 @@ import os
|
||||
|
||||
PROMPTS_YAML = "./danswer/chat/prompts.yaml"
|
||||
PERSONAS_YAML = "./danswer/chat/personas.yaml"
|
||||
INPUT_PROMPT_YAML = "./danswer/chat/input_prompts.yaml"
|
||||
|
||||
NUM_RETURNED_HITS = 50
|
||||
# Used for LLM filtering and reranking
|
||||
# We want this to be approximately the number of results we want to show on the first page
|
||||
# It cannot be too large due to cost and latency implications
|
||||
NUM_POSTPROCESSED_RESULTS = 20
|
||||
NUM_RERANKED_RESULTS = 20
|
||||
|
||||
# May be less depending on model
|
||||
MAX_CHUNKS_FED_TO_CHAT = float(os.environ.get("MAX_CHUNKS_FED_TO_CHAT") or 10.0)
|
||||
@@ -31,9 +30,13 @@ FAVOR_RECENT_DECAY_MULTIPLIER = 2.0
|
||||
DISABLE_LLM_QUERY_ANSWERABILITY = QA_PROMPT_OVERRIDE == "weak"
|
||||
# For the highest matching base size chunk, how many chunks above and below do we pull in by default
|
||||
# Note this is not in any of the deployment configs yet
|
||||
# Currently only applies to search flow not chat
|
||||
CONTEXT_CHUNKS_ABOVE = int(os.environ.get("CONTEXT_CHUNKS_ABOVE") or 1)
|
||||
CONTEXT_CHUNKS_BELOW = int(os.environ.get("CONTEXT_CHUNKS_BELOW") or 1)
|
||||
CONTEXT_CHUNKS_ABOVE = int(os.environ.get("CONTEXT_CHUNKS_ABOVE") or 0)
|
||||
CONTEXT_CHUNKS_BELOW = int(os.environ.get("CONTEXT_CHUNKS_BELOW") or 0)
|
||||
# Whether the LLM should evaluate all of the document chunks passed in for usefulness
|
||||
# in relation to the user query
|
||||
DISABLE_LLM_CHUNK_FILTER = (
|
||||
os.environ.get("DISABLE_LLM_CHUNK_FILTER", "").lower() == "true"
|
||||
)
|
||||
# Whether the LLM should be used to decide if a search would help given the chat history
|
||||
DISABLE_LLM_CHOOSE_SEARCH = (
|
||||
os.environ.get("DISABLE_LLM_CHOOSE_SEARCH", "").lower() == "true"
|
||||
@@ -44,19 +47,22 @@ DISABLE_LLM_QUERY_REPHRASE = (
|
||||
# 1 edit per 20 characters, currently unused due to fuzzy match being too slow
|
||||
QUOTE_ALLOWED_ERROR_PERCENT = 0.05
|
||||
QA_TIMEOUT = int(os.environ.get("QA_TIMEOUT") or "60") # 60 seconds
|
||||
# Keyword Search Drop Stopwords
|
||||
# If user has changed the default model, would most likely be to use a multilingual
|
||||
# model, the stopwords are NLTK english stopwords so then we would want to not drop the keywords
|
||||
if os.environ.get("EDIT_KEYWORD_QUERY"):
|
||||
EDIT_KEYWORD_QUERY = os.environ.get("EDIT_KEYWORD_QUERY", "").lower() == "true"
|
||||
else:
|
||||
EDIT_KEYWORD_QUERY = not os.environ.get("DOCUMENT_ENCODER_MODEL")
|
||||
# Weighting factor between Vector and Keyword Search, 1 for completely vector search
|
||||
HYBRID_ALPHA = max(0, min(1, float(os.environ.get("HYBRID_ALPHA") or 0.5)))
|
||||
HYBRID_ALPHA_KEYWORD = max(
|
||||
0, min(1, float(os.environ.get("HYBRID_ALPHA_KEYWORD") or 0.4))
|
||||
)
|
||||
HYBRID_ALPHA = max(0, min(1, float(os.environ.get("HYBRID_ALPHA") or 0.62)))
|
||||
# Weighting factor between Title and Content of documents during search, 1 for completely
|
||||
# Title based. Default heavily favors Content because Title is also included at the top of
|
||||
# Content. This is to avoid cases where the Content is very relevant but it may not be clear
|
||||
# if the title is separated out. Title is most of a "boost" than a separate field.
|
||||
TITLE_CONTENT_RATIO = max(
|
||||
0, min(1, float(os.environ.get("TITLE_CONTENT_RATIO") or 0.10))
|
||||
0, min(1, float(os.environ.get("TITLE_CONTENT_RATIO") or 0.20))
|
||||
)
|
||||
|
||||
# A list of languages passed to the LLM to rephase the query
|
||||
# For example "English,French,Spanish", be sure to use the "," separator
|
||||
MULTILINGUAL_QUERY_EXPANSION = os.environ.get("MULTILINGUAL_QUERY_EXPANSION") or None
|
||||
@@ -69,29 +75,22 @@ LANGUAGE_CHAT_NAMING_HINT = (
|
||||
or "The name of the conversation must be in the same language as the user query."
|
||||
)
|
||||
|
||||
|
||||
# Agentic search takes significantly more tokens and therefore has much higher cost.
|
||||
# This configuration allows users to get a search-only experience with instant results
|
||||
# and no involvement from the LLM.
|
||||
# Additionally, some LLM providers have strict rate limits which may prohibit
|
||||
# sending many API requests at once (as is done in agentic search).
|
||||
# Whether the LLM should evaluate all of the document chunks passed in for usefulness
|
||||
# in relation to the user query
|
||||
DISABLE_LLM_DOC_RELEVANCE = (
|
||||
os.environ.get("DISABLE_LLM_DOC_RELEVANCE", "").lower() == "true"
|
||||
)
|
||||
DISABLE_AGENTIC_SEARCH = (
|
||||
os.environ.get("DISABLE_AGENTIC_SEARCH") or "false"
|
||||
).lower() == "true"
|
||||
|
||||
|
||||
# Stops streaming answers back to the UI if this pattern is seen:
|
||||
STOP_STREAM_PAT = os.environ.get("STOP_STREAM_PAT") or None
|
||||
|
||||
# Set this to "true" to hard delete chats
|
||||
# This will make chats unviewable by admins after a user deletes them
|
||||
# As opposed to soft deleting them, which just hides them from non-admin users
|
||||
HARD_DELETE_CHATS = os.environ.get("HARD_DELETE_CHATS", "").lower() == "true"
|
||||
# The backend logic for this being True isn't fully supported yet
|
||||
HARD_DELETE_CHATS = False
|
||||
|
||||
# Internet Search
|
||||
BING_API_KEY = os.environ.get("BING_API_KEY") or None
|
||||
|
||||
# Enable in-house model for detecting connector-based filtering in queries
|
||||
ENABLE_CONNECTOR_CLASSIFIER = os.environ.get("ENABLE_CONNECTOR_CLASSIFIER", False)
|
||||
|
||||
VESPA_SEARCHER_THREADS = int(os.environ.get("VESPA_SEARCHER_THREADS") or 2)
|
||||
|
||||
@@ -1,7 +1,26 @@
|
||||
from enum import auto
|
||||
from enum import Enum
|
||||
|
||||
DOCUMENT_ID = "document_id"
|
||||
CHUNK_ID = "chunk_id"
|
||||
BLURB = "blurb"
|
||||
CONTENT = "content"
|
||||
SOURCE_TYPE = "source_type"
|
||||
SOURCE_LINKS = "source_links"
|
||||
SOURCE_LINK = "link"
|
||||
SEMANTIC_IDENTIFIER = "semantic_identifier"
|
||||
TITLE = "title"
|
||||
SKIP_TITLE_EMBEDDING = "skip_title"
|
||||
SECTION_CONTINUATION = "section_continuation"
|
||||
EMBEDDINGS = "embeddings"
|
||||
TITLE_EMBEDDING = "title_embedding"
|
||||
ALLOWED_USERS = "allowed_users"
|
||||
ACCESS_CONTROL_LIST = "access_control_list"
|
||||
DOCUMENT_SETS = "document_sets"
|
||||
TIME_FILTER = "time_filter"
|
||||
METADATA = "metadata"
|
||||
METADATA_LIST = "metadata_list"
|
||||
METADATA_SUFFIX = "metadata_suffix"
|
||||
MATCH_HIGHLIGHTS = "match_highlights"
|
||||
# stored in the `metadata` of a chunk. Used to signify that this chunk should
|
||||
# not be used for QA. For example, Google Drive file types which can't be parsed
|
||||
# are still useful as a search result but not for QA.
|
||||
@@ -9,11 +28,23 @@ IGNORE_FOR_QA = "ignore_for_qa"
|
||||
# NOTE: deprecated, only used for porting key from old system
|
||||
GEN_AI_API_KEY_STORAGE_KEY = "genai_api_key"
|
||||
PUBLIC_DOC_PAT = "PUBLIC"
|
||||
PUBLIC_DOCUMENT_SET = "__PUBLIC"
|
||||
QUOTE = "quote"
|
||||
BOOST = "boost"
|
||||
DOC_UPDATED_AT = "doc_updated_at" # Indexed as seconds since epoch
|
||||
PRIMARY_OWNERS = "primary_owners"
|
||||
SECONDARY_OWNERS = "secondary_owners"
|
||||
RECENCY_BIAS = "recency_bias"
|
||||
HIDDEN = "hidden"
|
||||
SCORE = "score"
|
||||
ID_SEPARATOR = ":;:"
|
||||
DEFAULT_BOOST = 0
|
||||
SESSION_KEY = "session"
|
||||
QUERY_EVENT_ID = "query_event_id"
|
||||
LLM_CHUNKS = "llm_chunks"
|
||||
|
||||
# For chunking/processing chunks
|
||||
MAX_CHUNK_TITLE_LEN = 1000
|
||||
RETURN_SEPARATOR = "\n\r\n"
|
||||
SECTION_SEPARATOR = "\n\n"
|
||||
# For combining attributes, doesn't have to be unique/perfect to work
|
||||
@@ -29,40 +60,12 @@ DISABLED_GEN_AI_MSG = (
|
||||
"You can still use Danswer as a search engine."
|
||||
)
|
||||
|
||||
# Postgres connection constants for application_name
|
||||
POSTGRES_WEB_APP_NAME = "web"
|
||||
POSTGRES_INDEXER_APP_NAME = "indexer"
|
||||
POSTGRES_CELERY_APP_NAME = "celery"
|
||||
POSTGRES_CELERY_BEAT_APP_NAME = "celery_beat"
|
||||
POSTGRES_CELERY_WORKER_APP_NAME = "celery_worker"
|
||||
POSTGRES_PERMISSIONS_APP_NAME = "permissions"
|
||||
POSTGRES_UNKNOWN_APP_NAME = "unknown"
|
||||
|
||||
# API Keys
|
||||
DANSWER_API_KEY_PREFIX = "API_KEY__"
|
||||
DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN = "danswerapikey.ai"
|
||||
UNNAMED_KEY_PLACEHOLDER = "Unnamed"
|
||||
|
||||
# Key-Value store keys
|
||||
KV_REINDEX_KEY = "needs_reindexing"
|
||||
KV_SEARCH_SETTINGS = "search_settings"
|
||||
KV_USER_STORE_KEY = "INVITED_USERS"
|
||||
KV_NO_AUTH_USER_PREFERENCES_KEY = "no_auth_user_preferences"
|
||||
KV_CRED_KEY = "credential_id_{}"
|
||||
KV_GMAIL_CRED_KEY = "gmail_app_credential"
|
||||
KV_GMAIL_SERVICE_ACCOUNT_KEY = "gmail_service_account_key"
|
||||
KV_GOOGLE_DRIVE_CRED_KEY = "google_drive_app_credential"
|
||||
KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY = "google_drive_service_account_key"
|
||||
KV_SLACK_BOT_TOKENS_CONFIG_KEY = "slack_bot_tokens_config_key"
|
||||
KV_GEN_AI_KEY_CHECK_TIME = "genai_api_key_last_check_time"
|
||||
KV_SETTINGS_KEY = "danswer_settings"
|
||||
KV_CUSTOMER_UUID_KEY = "customer_uuid"
|
||||
KV_INSTANCE_DOMAIN_KEY = "instance_domain"
|
||||
KV_ENTERPRISE_SETTINGS_KEY = "danswer_enterprise_settings"
|
||||
KV_CUSTOM_ANALYTICS_SCRIPT_KEY = "__custom_analytics_script__"
|
||||
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT = 60
|
||||
|
||||
|
||||
class DocumentSource(str, Enum):
|
||||
# Special case, document passed in via Danswer APIs without specifying a source type
|
||||
@@ -106,10 +109,6 @@ class DocumentSource(str, Enum):
|
||||
NOT_APPLICABLE = "not_applicable"
|
||||
|
||||
|
||||
class NotificationType(str, Enum):
|
||||
REINDEX = "reindex"
|
||||
|
||||
|
||||
class BlobType(str, Enum):
|
||||
R2 = "r2"
|
||||
S3 = "s3"
|
||||
@@ -165,27 +164,3 @@ class FileOrigin(str, Enum):
|
||||
CONNECTOR = "connector"
|
||||
GENERATED_REPORT = "generated_report"
|
||||
OTHER = "other"
|
||||
|
||||
|
||||
class PostgresAdvisoryLocks(Enum):
|
||||
KOMBU_MESSAGE_CLEANUP_LOCK_ID = auto()
|
||||
|
||||
|
||||
class DanswerCeleryQueues:
|
||||
VESPA_DOCSET_SYNC_GENERATOR = "vespa_docset_sync_generator"
|
||||
VESPA_USERGROUP_SYNC_GENERATOR = "vespa_usergroup_sync_generator"
|
||||
VESPA_METADATA_SYNC = "vespa_metadata_sync"
|
||||
CONNECTOR_DELETION = "connector_deletion"
|
||||
|
||||
|
||||
class DanswerRedisLocks:
|
||||
CHECK_VESPA_SYNC_BEAT_LOCK = "da_lock:check_vespa_sync_beat"
|
||||
MONITOR_VESPA_SYNC_BEAT_LOCK = "da_lock:monitor_vespa_sync_beat"
|
||||
|
||||
|
||||
class DanswerCeleryPriority(int, Enum):
|
||||
HIGHEST = 0
|
||||
HIGH = auto()
|
||||
MEDIUM = auto()
|
||||
LOW = auto()
|
||||
LOWEST = auto()
|
||||
|
||||
@@ -73,15 +73,3 @@ DANSWER_BOT_FEEDBACK_REMINDER = int(
|
||||
DANSWER_BOT_REPHRASE_MESSAGE = (
|
||||
os.environ.get("DANSWER_BOT_REPHRASE_MESSAGE", "").lower() == "true"
|
||||
)
|
||||
|
||||
# DANSWER_BOT_RESPONSE_LIMIT_PER_TIME_PERIOD is the number of
|
||||
# responses DanswerBot can send in a given time period.
|
||||
# Set to 0 to disable the limit.
|
||||
DANSWER_BOT_RESPONSE_LIMIT_PER_TIME_PERIOD = int(
|
||||
os.environ.get("DANSWER_BOT_RESPONSE_LIMIT_PER_TIME_PERIOD", "5000")
|
||||
)
|
||||
# DANSWER_BOT_RESPONSE_LIMIT_TIME_PERIOD_SECONDS is the number
|
||||
# of seconds until the response limit is reset.
|
||||
DANSWER_BOT_RESPONSE_LIMIT_TIME_PERIOD_SECONDS = int(
|
||||
os.environ.get("DANSWER_BOT_RESPONSE_LIMIT_TIME_PERIOD_SECONDS", "86400")
|
||||
)
|
||||
|
||||
@@ -12,15 +12,13 @@ import os
|
||||
# The useable models configured as below must be SentenceTransformer compatible
|
||||
# NOTE: DO NOT CHANGE SET THESE UNLESS YOU KNOW WHAT YOU ARE DOING
|
||||
# IDEALLY, YOU SHOULD CHANGE EMBEDDING MODELS VIA THE UI
|
||||
DEFAULT_DOCUMENT_ENCODER_MODEL = "nomic-ai/nomic-embed-text-v1"
|
||||
DEFAULT_DOCUMENT_ENCODER_MODEL = "intfloat/e5-base-v2"
|
||||
DOCUMENT_ENCODER_MODEL = (
|
||||
os.environ.get("DOCUMENT_ENCODER_MODEL") or DEFAULT_DOCUMENT_ENCODER_MODEL
|
||||
)
|
||||
# If the below is changed, Vespa deployment must also be changed
|
||||
DOC_EMBEDDING_DIM = int(os.environ.get("DOC_EMBEDDING_DIM") or 768)
|
||||
# Model should be chosen with 512 context size, ideally don't change this
|
||||
# If multipass_indexing is enabled, the max context size would be set to
|
||||
# DOC_EMBEDDING_CONTEXT_SIZE * LARGE_CHUNK_RATIO
|
||||
DOC_EMBEDDING_CONTEXT_SIZE = 512
|
||||
NORMALIZE_EMBEDDINGS = (
|
||||
os.environ.get("NORMALIZE_EMBEDDINGS") or "true"
|
||||
@@ -36,42 +34,53 @@ OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS = False
|
||||
SIM_SCORE_RANGE_LOW = float(os.environ.get("SIM_SCORE_RANGE_LOW") or 0.0)
|
||||
SIM_SCORE_RANGE_HIGH = float(os.environ.get("SIM_SCORE_RANGE_HIGH") or 1.0)
|
||||
# Certain models like e5, BGE, etc use a prefix for asymmetric retrievals (query generally shorter than docs)
|
||||
ASYM_QUERY_PREFIX = os.environ.get("ASYM_QUERY_PREFIX", "search_query: ")
|
||||
ASYM_PASSAGE_PREFIX = os.environ.get("ASYM_PASSAGE_PREFIX", "search_document: ")
|
||||
ASYM_QUERY_PREFIX = os.environ.get("ASYM_QUERY_PREFIX", "query: ")
|
||||
ASYM_PASSAGE_PREFIX = os.environ.get("ASYM_PASSAGE_PREFIX", "passage: ")
|
||||
# Purely an optimization, memory limitation consideration
|
||||
|
||||
# User's set embedding batch size overrides the default encoding batch sizes
|
||||
EMBEDDING_BATCH_SIZE = int(os.environ.get("EMBEDDING_BATCH_SIZE") or 0) or None
|
||||
|
||||
BATCH_SIZE_ENCODE_CHUNKS = EMBEDDING_BATCH_SIZE or 8
|
||||
# don't send over too many chunks at once, as sending too many could cause timeouts
|
||||
BATCH_SIZE_ENCODE_CHUNKS_FOR_API_EMBEDDING_SERVICES = EMBEDDING_BATCH_SIZE or 512
|
||||
BATCH_SIZE_ENCODE_CHUNKS = 8
|
||||
# For score display purposes, only way is to know the expected ranges
|
||||
CROSS_ENCODER_RANGE_MAX = 1
|
||||
CROSS_ENCODER_RANGE_MIN = 0
|
||||
|
||||
# Unused currently, can't be used with the current default encoder model due to its output range
|
||||
SEARCH_DISTANCE_CUTOFF = 0
|
||||
|
||||
|
||||
#####
|
||||
# Generative AI Model Configs
|
||||
#####
|
||||
|
||||
# NOTE: the 3 below should only be used for dev.
|
||||
GEN_AI_API_KEY = os.environ.get("GEN_AI_API_KEY")
|
||||
# If changing GEN_AI_MODEL_PROVIDER or GEN_AI_MODEL_VERSION from the default,
|
||||
# be sure to use one that is LiteLLM compatible:
|
||||
# https://litellm.vercel.app/docs/providers/azure#completion---using-env-variables
|
||||
# The provider is the prefix before / in the model argument
|
||||
|
||||
# Additionally Danswer supports GPT4All and custom request library based models
|
||||
# Set GEN_AI_MODEL_PROVIDER to "custom" to use the custom requests approach
|
||||
# Set GEN_AI_MODEL_PROVIDER to "gpt4all" to use gpt4all models running locally
|
||||
GEN_AI_MODEL_PROVIDER = os.environ.get("GEN_AI_MODEL_PROVIDER") or "openai"
|
||||
# If using Azure, it's the engine name, for example: Danswer
|
||||
GEN_AI_MODEL_VERSION = os.environ.get("GEN_AI_MODEL_VERSION")
|
||||
|
||||
# For secondary flows like extracting filters or deciding if a chunk is useful, we don't need
|
||||
# as powerful of a model as say GPT-4 so we can use an alternative that is faster and cheaper
|
||||
FAST_GEN_AI_MODEL_VERSION = os.environ.get("FAST_GEN_AI_MODEL_VERSION")
|
||||
|
||||
# Override the auto-detection of LLM max context length
|
||||
GEN_AI_MAX_TOKENS = int(os.environ.get("GEN_AI_MAX_TOKENS") or 0) or None
|
||||
|
||||
# Set this to be enough for an answer + quotes. Also used for Chat
|
||||
# This is the minimum token context we will leave for the LLM to generate an answer
|
||||
GEN_AI_NUM_RESERVED_OUTPUT_TOKENS = int(
|
||||
os.environ.get("GEN_AI_NUM_RESERVED_OUTPUT_TOKENS") or 1024
|
||||
# If the Generative AI model requires an API key for access, otherwise can leave blank
|
||||
GEN_AI_API_KEY = (
|
||||
os.environ.get("GEN_AI_API_KEY", os.environ.get("OPENAI_API_KEY")) or None
|
||||
)
|
||||
|
||||
# Typically, GenAI models nowadays are at least 4K tokens
|
||||
GEN_AI_MODEL_FALLBACK_MAX_TOKENS = 4096
|
||||
|
||||
# API Base, such as (for Azure): https://danswer.openai.azure.com/
|
||||
GEN_AI_API_ENDPOINT = os.environ.get("GEN_AI_API_ENDPOINT") or None
|
||||
# API Version, such as (for Azure): 2023-09-15-preview
|
||||
GEN_AI_API_VERSION = os.environ.get("GEN_AI_API_VERSION") or None
|
||||
# LiteLLM custom_llm_provider
|
||||
GEN_AI_LLM_PROVIDER_TYPE = os.environ.get("GEN_AI_LLM_PROVIDER_TYPE") or None
|
||||
# Override the auto-detection of LLM max context length
|
||||
GEN_AI_MAX_TOKENS = int(os.environ.get("GEN_AI_MAX_TOKENS") or 0) or None
|
||||
# Set this to be enough for an answer + quotes. Also used for Chat
|
||||
GEN_AI_MAX_OUTPUT_TOKENS = int(os.environ.get("GEN_AI_MAX_OUTPUT_TOKENS") or 1024)
|
||||
# Number of tokens from chat history to include at maximum
|
||||
# 3000 should be enough context regardless of use, no need to include as much as possible
|
||||
# as this drives up the cost unnecessarily
|
||||
|
||||
@@ -59,8 +59,6 @@ if __name__ == "__main__":
|
||||
latest_docs = test_connector.poll_source(one_day_ago, current)
|
||||
```
|
||||
|
||||
> Note: Be sure to set PYTHONPATH to danswer/backend before running the above main.
|
||||
|
||||
|
||||
### Additional Required Changes:
|
||||
#### Backend Changes
|
||||
@@ -70,16 +68,17 @@ if __name__ == "__main__":
|
||||
[here](https://github.com/danswer-ai/danswer/blob/main/backend/danswer/connectors/factory.py#L33)
|
||||
|
||||
#### Frontend Changes
|
||||
- Add the new Connector definition to the `SOURCE_METADATA_MAP` [here](https://github.com/danswer-ai/danswer/blob/main/web/src/lib/sources.ts#L59).
|
||||
- Add the definition for the new Form to the `connectorConfigs` object [here](https://github.com/danswer-ai/danswer/blob/main/web/src/lib/connectors/connectors.ts#L79).
|
||||
- Create the new connector directory and admin page under `danswer/web/src/app/admin/connectors/`
|
||||
- Create the new icon, type, source, and filter changes
|
||||
(refer to existing [PR](https://github.com/danswer-ai/danswer/pull/139))
|
||||
|
||||
#### Docs Changes
|
||||
Create the new connector page (with guiding images!) with how to get the connector credentials and how to set up the
|
||||
connector in Danswer. Then create a Pull Request in https://github.com/danswer-ai/danswer-docs.
|
||||
connector in Danswer. Then create a Pull Request in https://github.com/danswer-ai/danswer-docs
|
||||
|
||||
|
||||
### Before opening PR
|
||||
1. Be sure to fully test changes end to end with setting up the connector and updating the index with new docs from the
|
||||
new connector. To make it easier to review, please attach a video showing the successful creation of the connector via the UI (starting from the `Add Connector` page).
|
||||
2. Add a folder + tests under `backend/tests/daily/connectors` director. For an example, checkout the [test for Confluence](https://github.com/danswer-ai/danswer/blob/main/backend/tests/daily/connectors/confluence/test_confluence_basic.py). In the PR description, include a guide on how to setup the new source to pass the test. Before merging, we will re-create the environment and make sure the test(s) pass.
|
||||
3. Be sure to run the linting/formatting, refer to the formatting and linting section in
|
||||
new connector.
|
||||
2. Be sure to run the linting/formatting, refer to the formatting and linting section in
|
||||
[CONTRIBUTING.md](https://github.com/danswer-ai/danswer/blob/main/CONTRIBUTING.md#formatting-and-linting)
|
||||
|
||||
@@ -56,7 +56,7 @@ class BlobStorageConnector(LoadConnector, PollConnector):
|
||||
Raises ValueError for unsupported bucket types.
|
||||
"""
|
||||
|
||||
logger.debug(
|
||||
logger.info(
|
||||
f"Loading credentials for {self.bucket_name} or type {self.bucket_type}"
|
||||
)
|
||||
|
||||
@@ -169,7 +169,7 @@ class BlobStorageConnector(LoadConnector, PollConnector):
|
||||
end: datetime,
|
||||
) -> GenerateDocumentsOutput:
|
||||
if self.s3_client is None:
|
||||
raise ConnectorMissingCredentialError("Blob storage")
|
||||
raise ConnectorMissingCredentialError("Blog storage")
|
||||
|
||||
paginator = self.s3_client.get_paginator("list_objects_v2")
|
||||
pages = paginator.paginate(Bucket=self.bucket_name, Prefix=self.prefix)
|
||||
@@ -220,7 +220,7 @@ class BlobStorageConnector(LoadConnector, PollConnector):
|
||||
yield batch
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
logger.debug("Loading blob objects")
|
||||
logger.info("Loading blob objects")
|
||||
return self._yield_blob_objects(
|
||||
start=datetime(1970, 1, 1, tzinfo=timezone.utc),
|
||||
end=datetime.now(timezone.utc),
|
||||
@@ -230,7 +230,7 @@ class BlobStorageConnector(LoadConnector, PollConnector):
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
if self.s3_client is None:
|
||||
raise ConnectorMissingCredentialError("Blob storage")
|
||||
raise ConnectorMissingCredentialError("Blog storage")
|
||||
|
||||
start_datetime = datetime.fromtimestamp(start, tz=timezone.utc)
|
||||
end_datetime = datetime.fromtimestamp(end, tz=timezone.utc)
|
||||
|
||||
@@ -7,16 +7,13 @@ from datetime import timezone
|
||||
from functools import lru_cache
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import bs4
|
||||
from atlassian import Confluence # type:ignore
|
||||
from requests import HTTPError
|
||||
|
||||
from danswer.configs.app_configs import (
|
||||
CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD,
|
||||
)
|
||||
from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
|
||||
from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_INDEX_ARCHIVED_PAGES
|
||||
from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_INDEX_ONLY_ACTIVE_PAGES
|
||||
from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_LABELS_TO_SKIP
|
||||
from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING
|
||||
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
|
||||
@@ -44,12 +41,77 @@ logger = setup_logger()
|
||||
# 2. Segment into Sections for more accurate linking, can split by headers but make sure no text/ordering is lost
|
||||
|
||||
|
||||
NO_PERMISSIONS_TO_VIEW_ATTACHMENTS_ERROR_STR = (
|
||||
"User not permitted to view attachments on content"
|
||||
)
|
||||
NO_PARENT_OR_NO_PERMISSIONS_ERROR_STR = (
|
||||
"No parent or not permitted to view content with id"
|
||||
)
|
||||
def _extract_confluence_keys_from_cloud_url(wiki_url: str) -> tuple[str, str, str]:
|
||||
"""Sample
|
||||
URL w/ page: https://danswer.atlassian.net/wiki/spaces/1234abcd/pages/5678efgh/overview
|
||||
URL w/o page: https://danswer.atlassian.net/wiki/spaces/ASAM/overview
|
||||
|
||||
wiki_base is https://danswer.atlassian.net/wiki
|
||||
space is 1234abcd
|
||||
page_id is 5678efgh
|
||||
"""
|
||||
parsed_url = urlparse(wiki_url)
|
||||
wiki_base = (
|
||||
parsed_url.scheme
|
||||
+ "://"
|
||||
+ parsed_url.netloc
|
||||
+ parsed_url.path.split("/spaces")[0]
|
||||
)
|
||||
|
||||
path_parts = parsed_url.path.split("/")
|
||||
space = path_parts[3]
|
||||
|
||||
page_id = path_parts[5] if len(path_parts) > 5 else ""
|
||||
return wiki_base, space, page_id
|
||||
|
||||
|
||||
def _extract_confluence_keys_from_datacenter_url(wiki_url: str) -> tuple[str, str, str]:
|
||||
"""Sample
|
||||
URL w/ page https://danswer.ai/confluence/display/1234abcd/pages/5678efgh/overview
|
||||
URL w/o page https://danswer.ai/confluence/display/1234abcd/overview
|
||||
wiki_base is https://danswer.ai/confluence
|
||||
space is 1234abcd
|
||||
page_id is 5678efgh
|
||||
"""
|
||||
# /display/ is always right before the space and at the end of the base print()
|
||||
DISPLAY = "/display/"
|
||||
PAGE = "/pages/"
|
||||
|
||||
parsed_url = urlparse(wiki_url)
|
||||
wiki_base = (
|
||||
parsed_url.scheme
|
||||
+ "://"
|
||||
+ parsed_url.netloc
|
||||
+ parsed_url.path.split(DISPLAY)[0]
|
||||
)
|
||||
space = DISPLAY.join(parsed_url.path.split(DISPLAY)[1:]).split("/")[0]
|
||||
page_id = ""
|
||||
if (content := parsed_url.path.split(PAGE)) and len(content) > 1:
|
||||
page_id = content[1]
|
||||
return wiki_base, space, page_id
|
||||
|
||||
|
||||
def extract_confluence_keys_from_url(wiki_url: str) -> tuple[str, str, str, bool]:
|
||||
is_confluence_cloud = (
|
||||
".atlassian.net/wiki/spaces/" in wiki_url
|
||||
or ".jira.com/wiki/spaces/" in wiki_url
|
||||
)
|
||||
|
||||
try:
|
||||
if is_confluence_cloud:
|
||||
wiki_base, space, page_id = _extract_confluence_keys_from_cloud_url(
|
||||
wiki_url
|
||||
)
|
||||
else:
|
||||
wiki_base, space, page_id = _extract_confluence_keys_from_datacenter_url(
|
||||
wiki_url
|
||||
)
|
||||
except Exception as e:
|
||||
error_msg = f"Not a valid Confluence Wiki Link, unable to extract wiki base, space, and page id. Exception: {e}"
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
return wiki_base, space, page_id, is_confluence_cloud
|
||||
|
||||
|
||||
@lru_cache()
|
||||
@@ -137,56 +199,34 @@ def _comment_dfs(
|
||||
comments_str += "\nComment:\n" + parse_html_page(
|
||||
comment_html, confluence_client
|
||||
)
|
||||
try:
|
||||
child_comment_pages = get_page_child_by_type(
|
||||
comment_page["id"],
|
||||
type="comment",
|
||||
start=None,
|
||||
limit=None,
|
||||
expand="body.storage.value",
|
||||
)
|
||||
comments_str = _comment_dfs(
|
||||
comments_str, child_comment_pages, confluence_client
|
||||
)
|
||||
except HTTPError as e:
|
||||
# not the cleanest, but I'm not aware of a nicer way to check the error
|
||||
if NO_PARENT_OR_NO_PERMISSIONS_ERROR_STR not in str(e):
|
||||
raise
|
||||
|
||||
child_comment_pages = get_page_child_by_type(
|
||||
comment_page["id"],
|
||||
type="comment",
|
||||
start=None,
|
||||
limit=None,
|
||||
expand="body.storage.value",
|
||||
)
|
||||
comments_str = _comment_dfs(
|
||||
comments_str, child_comment_pages, confluence_client
|
||||
)
|
||||
return comments_str
|
||||
|
||||
|
||||
def _datetime_from_string(datetime_string: str) -> datetime:
|
||||
datetime_object = datetime.fromisoformat(datetime_string)
|
||||
|
||||
if datetime_object.tzinfo is None:
|
||||
# If no timezone info, assume it is UTC
|
||||
datetime_object = datetime_object.replace(tzinfo=timezone.utc)
|
||||
else:
|
||||
# If not in UTC, translate it
|
||||
datetime_object = datetime_object.astimezone(timezone.utc)
|
||||
|
||||
return datetime_object
|
||||
|
||||
|
||||
class RecursiveIndexer:
|
||||
def __init__(
|
||||
self,
|
||||
batch_size: int,
|
||||
confluence_client: Confluence,
|
||||
index_recursively: bool,
|
||||
index_origin: bool,
|
||||
origin_page_id: str,
|
||||
) -> None:
|
||||
self.batch_size = 1
|
||||
# batch_size
|
||||
self.confluence_client = confluence_client
|
||||
self.index_recursively = index_recursively
|
||||
self.index_origin = index_origin
|
||||
self.origin_page_id = origin_page_id
|
||||
self.pages = self.recurse_children_pages(0, self.origin_page_id)
|
||||
|
||||
def get_origin_page(self) -> list[dict[str, Any]]:
|
||||
return [self._fetch_origin_page()]
|
||||
|
||||
def get_pages(self, ind: int, size: int) -> list[dict]:
|
||||
if ind * size > len(self.pages):
|
||||
return []
|
||||
@@ -242,11 +282,12 @@ class RecursiveIndexer:
|
||||
current_level_pages = next_level_pages
|
||||
next_level_pages = []
|
||||
|
||||
try:
|
||||
origin_page = self._fetch_origin_page()
|
||||
pages.append(origin_page)
|
||||
except Exception as e:
|
||||
logger.warning(f"Appending origin page with id {page_id} failed: {e}")
|
||||
if self.index_origin:
|
||||
try:
|
||||
origin_page = self._fetch_origin_page()
|
||||
pages.append(origin_page)
|
||||
except Exception as e:
|
||||
logger.warning(f"Appending origin page with id {page_id} failed: {e}")
|
||||
|
||||
return pages
|
||||
|
||||
@@ -298,11 +339,8 @@ class RecursiveIndexer:
|
||||
class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
def __init__(
|
||||
self,
|
||||
wiki_base: str,
|
||||
space: str,
|
||||
is_cloud: bool,
|
||||
page_id: str = "",
|
||||
index_recursively: bool = True,
|
||||
wiki_page_url: str,
|
||||
index_origin: bool = True,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE,
|
||||
# if a page has one of the labels specified in this list, we will just
|
||||
@@ -314,16 +352,16 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
self.continue_on_failure = continue_on_failure
|
||||
self.labels_to_skip = set(labels_to_skip)
|
||||
self.recursive_indexer: RecursiveIndexer | None = None
|
||||
self.index_recursively = index_recursively
|
||||
|
||||
# Remove trailing slash from wiki_base if present
|
||||
self.wiki_base = wiki_base.rstrip("/")
|
||||
self.space = space
|
||||
self.page_id = page_id
|
||||
|
||||
self.is_cloud = is_cloud
|
||||
self.index_origin = index_origin
|
||||
(
|
||||
self.wiki_base,
|
||||
self.space,
|
||||
self.page_id,
|
||||
self.is_cloud,
|
||||
) = extract_confluence_keys_from_url(wiki_page_url)
|
||||
|
||||
self.space_level_scan = False
|
||||
|
||||
self.confluence_client: Confluence | None = None
|
||||
|
||||
if self.page_id is None or self.page_id == "":
|
||||
@@ -331,7 +369,7 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
|
||||
logger.info(
|
||||
f"wiki_base: {self.wiki_base}, space: {self.space}, page_id: {self.page_id},"
|
||||
+ f" space_level_scan: {self.space_level_scan}, index_recursively: {self.index_recursively}"
|
||||
+ f" space_level_scan: {self.space_level_scan}, origin: {self.index_origin}"
|
||||
)
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
@@ -343,6 +381,7 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
username=username if self.is_cloud else None,
|
||||
password=access_token if self.is_cloud else None,
|
||||
token=access_token if not self.is_cloud else None,
|
||||
cloud=self.is_cloud,
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -361,7 +400,9 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
start=start_ind,
|
||||
limit=batch_size,
|
||||
status=(
|
||||
None if CONFLUENCE_CONNECTOR_INDEX_ARCHIVED_PAGES else "current"
|
||||
"current"
|
||||
if CONFLUENCE_CONNECTOR_INDEX_ONLY_ACTIVE_PAGES
|
||||
else None
|
||||
),
|
||||
expand="body.storage.value,version",
|
||||
)
|
||||
@@ -382,9 +423,9 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
start=start_ind + i,
|
||||
limit=1,
|
||||
status=(
|
||||
None
|
||||
if CONFLUENCE_CONNECTOR_INDEX_ARCHIVED_PAGES
|
||||
else "current"
|
||||
"current"
|
||||
if CONFLUENCE_CONNECTOR_INDEX_ONLY_ACTIVE_PAGES
|
||||
else None
|
||||
),
|
||||
expand="body.storage.value,version",
|
||||
)
|
||||
@@ -412,13 +453,10 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
origin_page_id=self.page_id,
|
||||
batch_size=self.batch_size,
|
||||
confluence_client=self.confluence_client,
|
||||
index_recursively=self.index_recursively,
|
||||
index_origin=self.index_origin,
|
||||
)
|
||||
|
||||
if self.index_recursively:
|
||||
return self.recursive_indexer.get_pages(start_ind, batch_size)
|
||||
else:
|
||||
return self.recursive_indexer.get_origin_page()
|
||||
return self.recursive_indexer.get_pages(start_ind, batch_size)
|
||||
|
||||
pages: list[dict[str, Any]] = []
|
||||
|
||||
@@ -491,249 +529,134 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
logger.exception("Ran into exception when fetching labels from Confluence")
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def _attachment_to_download_link(
|
||||
cls, confluence_client: Confluence, attachment: dict[str, Any]
|
||||
) -> str:
|
||||
return confluence_client.url + attachment["_links"]["download"]
|
||||
|
||||
@classmethod
|
||||
def _attachment_to_content(
|
||||
cls,
|
||||
confluence_client: Confluence,
|
||||
attachment: dict[str, Any],
|
||||
) -> str | None:
|
||||
"""If it returns None, assume that we should skip this attachment."""
|
||||
if attachment["metadata"]["mediaType"] in [
|
||||
"image/jpeg",
|
||||
"image/png",
|
||||
"image/gif",
|
||||
"image/svg+xml",
|
||||
"video/mp4",
|
||||
"video/quicktime",
|
||||
]:
|
||||
return None
|
||||
|
||||
download_link = cls._attachment_to_download_link(confluence_client, attachment)
|
||||
|
||||
attachment_size = attachment["extensions"]["fileSize"]
|
||||
if attachment_size > CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD:
|
||||
logger.warning(
|
||||
f"Skipping {download_link} due to size. "
|
||||
f"size={attachment_size} "
|
||||
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD}"
|
||||
)
|
||||
return None
|
||||
|
||||
response = confluence_client._session.get(download_link)
|
||||
if response.status_code != 200:
|
||||
logger.warning(
|
||||
f"Failed to fetch {download_link} with invalid status code {response.status_code}"
|
||||
)
|
||||
return None
|
||||
|
||||
extracted_text = extract_file_text(
|
||||
attachment["title"], io.BytesIO(response.content), False
|
||||
)
|
||||
if len(extracted_text) > CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD:
|
||||
logger.warning(
|
||||
f"Skipping {download_link} due to char count. "
|
||||
f"char count={len(extracted_text)} "
|
||||
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD}"
|
||||
)
|
||||
return None
|
||||
|
||||
return extracted_text
|
||||
|
||||
def _fetch_attachments(
|
||||
self, confluence_client: Confluence, page_id: str, files_in_used: list[str]
|
||||
) -> tuple[str, list[dict[str, Any]]]:
|
||||
unused_attachments: list = []
|
||||
|
||||
) -> str:
|
||||
get_attachments_from_content = make_confluence_call_handle_rate_limit(
|
||||
confluence_client.get_attachments_from_content
|
||||
)
|
||||
files_attachment_content: list = []
|
||||
|
||||
try:
|
||||
expand = "history.lastUpdated,metadata.labels"
|
||||
attachments_container = get_attachments_from_content(
|
||||
page_id, start=0, limit=500, expand=expand
|
||||
page_id, start=0, limit=500
|
||||
)
|
||||
for attachment in attachments_container["results"]:
|
||||
if attachment["title"] not in files_in_used:
|
||||
unused_attachments.append(attachment)
|
||||
if attachment["metadata"]["mediaType"] in [
|
||||
"image/jpeg",
|
||||
"image/png",
|
||||
"image/gif",
|
||||
"image/svg+xml",
|
||||
"video/mp4",
|
||||
"video/quicktime",
|
||||
]:
|
||||
continue
|
||||
|
||||
attachment_content = self._attachment_to_content(
|
||||
confluence_client, attachment
|
||||
)
|
||||
if attachment_content:
|
||||
files_attachment_content.append(attachment_content)
|
||||
if attachment["title"] not in files_in_used:
|
||||
continue
|
||||
|
||||
download_link = confluence_client.url + attachment["_links"]["download"]
|
||||
response = confluence_client._session.get(download_link)
|
||||
|
||||
if response.status_code == 200:
|
||||
extract = extract_file_text(
|
||||
attachment["title"], io.BytesIO(response.content), False
|
||||
)
|
||||
files_attachment_content.append(extract)
|
||||
|
||||
except Exception as e:
|
||||
if isinstance(
|
||||
e, HTTPError
|
||||
) and NO_PERMISSIONS_TO_VIEW_ATTACHMENTS_ERROR_STR in str(e):
|
||||
logger.warning(
|
||||
f"User does not have access to attachments on page '{page_id}'"
|
||||
)
|
||||
return "", []
|
||||
|
||||
if not self.continue_on_failure:
|
||||
raise e
|
||||
logger.exception(
|
||||
f"Ran into exception when fetching attachments from Confluence: {e}"
|
||||
)
|
||||
|
||||
return "\n".join(files_attachment_content), unused_attachments
|
||||
return "\n".join(files_attachment_content)
|
||||
|
||||
def _get_doc_batch(
|
||||
self, start_ind: int, time_filter: Callable[[datetime], bool] | None = None
|
||||
) -> tuple[list[Document], list[dict[str, Any]], int]:
|
||||
) -> tuple[list[Document], int]:
|
||||
doc_batch: list[Document] = []
|
||||
unused_attachments: list[dict[str, Any]] = []
|
||||
|
||||
if self.confluence_client is None:
|
||||
raise ConnectorMissingCredentialError("Confluence")
|
||||
batch = self._fetch_pages(self.confluence_client, start_ind)
|
||||
|
||||
for page in batch:
|
||||
last_modified = _datetime_from_string(page["version"]["when"])
|
||||
last_modified_str = page["version"]["when"]
|
||||
author = cast(str | None, page["version"].get("by", {}).get("email"))
|
||||
last_modified = datetime.fromisoformat(last_modified_str)
|
||||
|
||||
if time_filter and not time_filter(last_modified):
|
||||
continue
|
||||
if last_modified.tzinfo is None:
|
||||
# If no timezone info, assume it is UTC
|
||||
last_modified = last_modified.replace(tzinfo=timezone.utc)
|
||||
else:
|
||||
# If not in UTC, translate it
|
||||
last_modified = last_modified.astimezone(timezone.utc)
|
||||
|
||||
page_id = page["id"]
|
||||
if time_filter is None or time_filter(last_modified):
|
||||
page_id = page["id"]
|
||||
|
||||
if self.labels_to_skip or not CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING:
|
||||
page_labels = self._fetch_labels(self.confluence_client, page_id)
|
||||
if self.labels_to_skip or not CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING:
|
||||
page_labels = self._fetch_labels(self.confluence_client, page_id)
|
||||
|
||||
# check disallowed labels
|
||||
if self.labels_to_skip:
|
||||
label_intersection = self.labels_to_skip.intersection(page_labels)
|
||||
if label_intersection:
|
||||
logger.info(
|
||||
f"Page with ID '{page_id}' has a label which has been "
|
||||
f"designated as disallowed: {label_intersection}. Skipping."
|
||||
)
|
||||
# check disallowed labels
|
||||
if self.labels_to_skip:
|
||||
label_intersection = self.labels_to_skip.intersection(page_labels)
|
||||
if label_intersection:
|
||||
logger.info(
|
||||
f"Page with ID '{page_id}' has a label which has been "
|
||||
f"designated as disallowed: {label_intersection}. Skipping."
|
||||
)
|
||||
|
||||
continue
|
||||
|
||||
page_html = (
|
||||
page["body"]
|
||||
.get("storage", page["body"].get("view", {}))
|
||||
.get("value")
|
||||
)
|
||||
page_url = self.wiki_base + page["_links"]["webui"]
|
||||
if not page_html:
|
||||
logger.debug("Page is empty, skipping: %s", page_url)
|
||||
continue
|
||||
page_text = parse_html_page(page_html, self.confluence_client)
|
||||
|
||||
page_html = (
|
||||
page["body"].get("storage", page["body"].get("view", {})).get("value")
|
||||
)
|
||||
page_url = self.wiki_base + page["_links"]["webui"]
|
||||
if not page_html:
|
||||
logger.debug("Page is empty, skipping: %s", page_url)
|
||||
continue
|
||||
page_text = parse_html_page(page_html, self.confluence_client)
|
||||
|
||||
files_in_used = get_used_attachments(page_html, self.confluence_client)
|
||||
attachment_text, unused_page_attachments = self._fetch_attachments(
|
||||
self.confluence_client, page_id, files_in_used
|
||||
)
|
||||
unused_attachments.extend(unused_page_attachments)
|
||||
|
||||
page_text += attachment_text
|
||||
comments_text = self._fetch_comments(self.confluence_client, page_id)
|
||||
page_text += comments_text
|
||||
doc_metadata: dict[str, str | list[str]] = {"Wiki Space Name": self.space}
|
||||
if not CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING and page_labels:
|
||||
doc_metadata["labels"] = page_labels
|
||||
|
||||
doc_batch.append(
|
||||
Document(
|
||||
id=page_url,
|
||||
sections=[Section(link=page_url, text=page_text)],
|
||||
source=DocumentSource.CONFLUENCE,
|
||||
semantic_identifier=page["title"],
|
||||
doc_updated_at=last_modified,
|
||||
primary_owners=(
|
||||
[BasicExpertInfo(email=author)] if author else None
|
||||
),
|
||||
metadata=doc_metadata,
|
||||
files_in_used = get_used_attachments(page_html, self.confluence_client)
|
||||
attachment_text = self._fetch_attachments(
|
||||
self.confluence_client, page_id, files_in_used
|
||||
)
|
||||
)
|
||||
return (
|
||||
doc_batch,
|
||||
unused_attachments,
|
||||
len(batch),
|
||||
)
|
||||
page_text += attachment_text
|
||||
comments_text = self._fetch_comments(self.confluence_client, page_id)
|
||||
page_text += comments_text
|
||||
doc_metadata: dict[str, str | list[str]] = {
|
||||
"Wiki Space Name": self.space
|
||||
}
|
||||
if not CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING and page_labels:
|
||||
doc_metadata["labels"] = page_labels
|
||||
|
||||
def _get_attachment_batch(
|
||||
self,
|
||||
start_ind: int,
|
||||
attachments: list[dict[str, Any]],
|
||||
time_filter: Callable[[datetime], bool] | None = None,
|
||||
) -> tuple[list[Document], int]:
|
||||
doc_batch: list[Document] = []
|
||||
|
||||
if self.confluence_client is None:
|
||||
raise ConnectorMissingCredentialError("Confluence")
|
||||
|
||||
end_ind = min(start_ind + self.batch_size, len(attachments))
|
||||
|
||||
for attachment in attachments[start_ind:end_ind]:
|
||||
last_updated = _datetime_from_string(
|
||||
attachment["history"]["lastUpdated"]["when"]
|
||||
)
|
||||
|
||||
if time_filter and not time_filter(last_updated):
|
||||
continue
|
||||
|
||||
attachment_url = self._attachment_to_download_link(
|
||||
self.confluence_client, attachment
|
||||
)
|
||||
attachment_content = self._attachment_to_content(
|
||||
self.confluence_client, attachment
|
||||
)
|
||||
if attachment_content is None:
|
||||
continue
|
||||
|
||||
creator_email = attachment["history"]["createdBy"].get("email")
|
||||
|
||||
comment = attachment["metadata"].get("comment", "")
|
||||
doc_metadata: dict[str, str | list[str]] = {"comment": comment}
|
||||
|
||||
attachment_labels: list[str] = []
|
||||
if not CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING:
|
||||
for label in attachment["metadata"]["labels"]["results"]:
|
||||
attachment_labels.append(label["name"])
|
||||
|
||||
doc_metadata["labels"] = attachment_labels
|
||||
|
||||
doc_batch.append(
|
||||
Document(
|
||||
id=attachment_url,
|
||||
sections=[Section(link=attachment_url, text=attachment_content)],
|
||||
source=DocumentSource.CONFLUENCE,
|
||||
semantic_identifier=attachment["title"],
|
||||
doc_updated_at=last_updated,
|
||||
primary_owners=(
|
||||
[BasicExpertInfo(email=creator_email)]
|
||||
if creator_email
|
||||
else None
|
||||
),
|
||||
metadata=doc_metadata,
|
||||
doc_batch.append(
|
||||
Document(
|
||||
id=page_url,
|
||||
sections=[Section(link=page_url, text=page_text)],
|
||||
source=DocumentSource.CONFLUENCE,
|
||||
semantic_identifier=page["title"],
|
||||
doc_updated_at=last_modified,
|
||||
primary_owners=(
|
||||
[BasicExpertInfo(email=author)] if author else None
|
||||
),
|
||||
metadata=doc_metadata,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
return doc_batch, end_ind - start_ind
|
||||
return doc_batch, len(batch)
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
unused_attachments = []
|
||||
|
||||
if self.confluence_client is None:
|
||||
raise ConnectorMissingCredentialError("Confluence")
|
||||
|
||||
start_ind = 0
|
||||
while True:
|
||||
doc_batch, unused_attachments_batch, num_pages = self._get_doc_batch(
|
||||
start_ind
|
||||
)
|
||||
unused_attachments.extend(unused_attachments_batch)
|
||||
doc_batch, num_pages = self._get_doc_batch(start_ind)
|
||||
start_ind += num_pages
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
@@ -741,23 +664,9 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
if num_pages < self.batch_size:
|
||||
break
|
||||
|
||||
start_ind = 0
|
||||
while True:
|
||||
attachment_batch, num_attachments = self._get_attachment_batch(
|
||||
start_ind, unused_attachments
|
||||
)
|
||||
start_ind += num_attachments
|
||||
if attachment_batch:
|
||||
yield attachment_batch
|
||||
|
||||
if num_attachments < self.batch_size:
|
||||
break
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
unused_attachments = []
|
||||
|
||||
if self.confluence_client is None:
|
||||
raise ConnectorMissingCredentialError("Confluence")
|
||||
|
||||
@@ -766,11 +675,9 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
|
||||
start_ind = 0
|
||||
while True:
|
||||
doc_batch, unused_attachments_batch, num_pages = self._get_doc_batch(
|
||||
doc_batch, num_pages = self._get_doc_batch(
|
||||
start_ind, time_filter=lambda t: start_time <= t <= end_time
|
||||
)
|
||||
unused_attachments.extend(unused_attachments_batch)
|
||||
|
||||
start_ind += num_pages
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
@@ -778,29 +685,9 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
if num_pages < self.batch_size:
|
||||
break
|
||||
|
||||
start_ind = 0
|
||||
while True:
|
||||
attachment_batch, num_attachments = self._get_attachment_batch(
|
||||
start_ind,
|
||||
unused_attachments,
|
||||
time_filter=lambda t: start_time <= t <= end_time,
|
||||
)
|
||||
start_ind += num_attachments
|
||||
if attachment_batch:
|
||||
yield attachment_batch
|
||||
|
||||
if num_attachments < self.batch_size:
|
||||
break
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
connector = ConfluenceConnector(
|
||||
wiki_base=os.environ["CONFLUENCE_TEST_SPACE_URL"],
|
||||
space=os.environ["CONFLUENCE_TEST_SPACE"],
|
||||
is_cloud=os.environ.get("CONFLUENCE_IS_CLOUD", "true").lower() == "true",
|
||||
page_id=os.environ.get("CONFLUENCE_TEST_PAGE_ID", ""),
|
||||
index_recursively=True,
|
||||
)
|
||||
connector = ConfluenceConnector(os.environ["CONFLUENCE_TEST_SPACE_URL"])
|
||||
connector.load_credentials(
|
||||
{
|
||||
"confluence_username": os.environ["CONFLUENCE_USER_NAME"],
|
||||
|
||||
@@ -23,33 +23,25 @@ class ConfluenceRateLimitError(Exception):
|
||||
|
||||
def make_confluence_call_handle_rate_limit(confluence_call: F) -> F:
|
||||
def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
|
||||
max_retries = 5
|
||||
starting_delay = 5
|
||||
backoff = 2
|
||||
max_delay = 600
|
||||
|
||||
for attempt in range(max_retries):
|
||||
for attempt in range(10):
|
||||
try:
|
||||
return confluence_call(*args, **kwargs)
|
||||
except HTTPError as e:
|
||||
# Check if the response or headers are None to avoid potential AttributeError
|
||||
if e.response is None or e.response.headers is None:
|
||||
logger.warning("HTTPError with `None` as response or as headers")
|
||||
raise e
|
||||
|
||||
retry_after_header = e.response.headers.get("Retry-After")
|
||||
if (
|
||||
e.response.status_code == 429
|
||||
or RATE_LIMIT_MESSAGE_LOWERCASE in e.response.text.lower()
|
||||
):
|
||||
retry_after = None
|
||||
if retry_after_header is not None:
|
||||
try:
|
||||
retry_after = int(retry_after_header)
|
||||
except ValueError:
|
||||
pass
|
||||
try:
|
||||
retry_after = int(e.response.headers.get("Retry-After"))
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
if retry_after is not None:
|
||||
if retry_after:
|
||||
logger.warning(
|
||||
f"Rate limit hit. Retrying after {retry_after} seconds..."
|
||||
)
|
||||
@@ -63,14 +55,5 @@ def make_confluence_call_handle_rate_limit(confluence_call: F) -> F:
|
||||
else:
|
||||
# re-raise, let caller handle
|
||||
raise
|
||||
except AttributeError as e:
|
||||
# Some error within the Confluence library, unclear why it fails.
|
||||
# Users reported it to be intermittent, so just retry
|
||||
logger.warning(f"Confluence Internal Error, retrying... {e}")
|
||||
delay = min(starting_delay * (backoff**attempt), max_delay)
|
||||
time.sleep(delay)
|
||||
|
||||
if attempt == max_retries - 1:
|
||||
raise e
|
||||
|
||||
return cast(F, wrapped_call)
|
||||
|
||||
@@ -1,70 +0,0 @@
|
||||
import sys
|
||||
from datetime import datetime
|
||||
|
||||
from danswer.connectors.interfaces import BaseConnector
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
TimeRange = tuple[datetime, datetime]
|
||||
|
||||
|
||||
class ConnectorRunner:
|
||||
def __init__(
|
||||
self,
|
||||
connector: BaseConnector,
|
||||
time_range: TimeRange | None = None,
|
||||
fail_loudly: bool = False,
|
||||
):
|
||||
self.connector = connector
|
||||
|
||||
if isinstance(self.connector, PollConnector):
|
||||
if time_range is None:
|
||||
raise ValueError("time_range is required for PollConnector")
|
||||
|
||||
self.doc_batch_generator = self.connector.poll_source(
|
||||
time_range[0].timestamp(), time_range[1].timestamp()
|
||||
)
|
||||
|
||||
elif isinstance(self.connector, LoadConnector):
|
||||
if time_range and fail_loudly:
|
||||
raise ValueError(
|
||||
"time_range specified, but passed in connector is not a PollConnector"
|
||||
)
|
||||
|
||||
self.doc_batch_generator = self.connector.load_from_state()
|
||||
|
||||
else:
|
||||
raise ValueError(f"Invalid connector. type: {type(self.connector)}")
|
||||
|
||||
def run(self) -> GenerateDocumentsOutput:
|
||||
"""Adds additional exception logging to the connector."""
|
||||
try:
|
||||
yield from self.doc_batch_generator
|
||||
except Exception:
|
||||
exc_type, _, exc_traceback = sys.exc_info()
|
||||
|
||||
# Traverse the traceback to find the last frame where the exception was raised
|
||||
tb = exc_traceback
|
||||
if tb is None:
|
||||
logger.error("No traceback found for exception")
|
||||
raise
|
||||
|
||||
while tb.tb_next:
|
||||
tb = tb.tb_next # Move to the next frame in the traceback
|
||||
|
||||
# Get the local variables from the frame where the exception occurred
|
||||
local_vars = tb.tb_frame.f_locals
|
||||
local_vars_str = "\n".join(
|
||||
f"{key}: {value}" for key, value in local_vars.items()
|
||||
)
|
||||
logger.error(
|
||||
f"Error in connector. type: {exc_type};\n"
|
||||
f"local_vars below -> \n{local_vars_str}"
|
||||
)
|
||||
raise
|
||||
@@ -56,7 +56,7 @@ class _RateLimitDecorator:
|
||||
sleep_cnt = 0
|
||||
while len(self.call_history) == self.max_calls:
|
||||
sleep_time = self.sleep_time * (self.sleep_backoff**sleep_cnt)
|
||||
logger.notice(
|
||||
logger.info(
|
||||
f"Rate limit exceeded for function {func.__name__}. "
|
||||
f"Waiting {sleep_time} seconds before retrying."
|
||||
)
|
||||
|
||||
@@ -45,15 +45,10 @@ def extract_jira_project(url: str) -> tuple[str, str]:
|
||||
return jira_base, jira_project
|
||||
|
||||
|
||||
def extract_text_from_adf(adf: dict | None) -> str:
|
||||
"""Extracts plain text from Atlassian Document Format:
|
||||
https://developer.atlassian.com/cloud/jira/platform/apis/document/structure/
|
||||
|
||||
WARNING: This function is incomplete and will e.g. skip lists!
|
||||
"""
|
||||
def extract_text_from_content(content: dict) -> str:
|
||||
texts = []
|
||||
if adf is not None and "content" in adf:
|
||||
for block in adf["content"]:
|
||||
if "content" in content:
|
||||
for block in content["content"]:
|
||||
if "content" in block:
|
||||
for item in block["content"]:
|
||||
if item["type"] == "text":
|
||||
@@ -61,31 +56,24 @@ def extract_text_from_adf(adf: dict | None) -> str:
|
||||
return " ".join(texts)
|
||||
|
||||
|
||||
def best_effort_get_field_from_issue(jira_issue: Issue, field: str) -> Any:
|
||||
if hasattr(jira_issue.fields, field):
|
||||
return getattr(jira_issue.fields, field)
|
||||
|
||||
try:
|
||||
return jira_issue.raw["fields"][field]
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _get_comment_strs(
|
||||
jira: Issue, comment_email_blacklist: tuple[str, ...] = ()
|
||||
) -> list[str]:
|
||||
comment_strs = []
|
||||
for comment in jira.fields.comment.comments:
|
||||
try:
|
||||
body_text = (
|
||||
comment.body
|
||||
if JIRA_API_VERSION == "2"
|
||||
else extract_text_from_adf(comment.raw["body"])
|
||||
)
|
||||
if hasattr(comment, "body"):
|
||||
body_text = extract_text_from_content(comment.raw["body"])
|
||||
elif hasattr(comment, "raw"):
|
||||
body = comment.raw.get("body", "No body content available")
|
||||
body_text = (
|
||||
extract_text_from_content(body) if isinstance(body, dict) else body
|
||||
)
|
||||
else:
|
||||
body_text = "No body attribute found"
|
||||
|
||||
if (
|
||||
hasattr(comment, "author")
|
||||
and hasattr(comment.author, "emailAddress")
|
||||
and comment.author.emailAddress in comment_email_blacklist
|
||||
):
|
||||
continue # Skip adding comment if author's email is in blacklist
|
||||
@@ -128,14 +116,9 @@ def fetch_jira_issues_batch(
|
||||
)
|
||||
continue
|
||||
|
||||
description = (
|
||||
jira.fields.description
|
||||
if JIRA_API_VERSION == "2"
|
||||
else extract_text_from_adf(jira.raw["fields"]["description"])
|
||||
)
|
||||
comments = _get_comment_strs(jira, comment_email_blacklist)
|
||||
semantic_rep = f"{description}\n" + "\n".join(
|
||||
[f"Comment: {comment}" for comment in comments if comment]
|
||||
semantic_rep = f"{jira.fields.description}\n" + "\n".join(
|
||||
[f"Comment: {comment}" for comment in comments]
|
||||
)
|
||||
|
||||
page_url = f"{jira_client.client_info()}/browse/{jira.key}"
|
||||
@@ -164,18 +147,14 @@ def fetch_jira_issues_batch(
|
||||
pass
|
||||
|
||||
metadata_dict = {}
|
||||
priority = best_effort_get_field_from_issue(jira, "priority")
|
||||
if priority:
|
||||
metadata_dict["priority"] = priority.name
|
||||
status = best_effort_get_field_from_issue(jira, "status")
|
||||
if status:
|
||||
metadata_dict["status"] = status.name
|
||||
resolution = best_effort_get_field_from_issue(jira, "resolution")
|
||||
if resolution:
|
||||
metadata_dict["resolution"] = resolution.name
|
||||
labels = best_effort_get_field_from_issue(jira, "labels")
|
||||
if labels:
|
||||
metadata_dict["label"] = labels
|
||||
if jira.fields.priority:
|
||||
metadata_dict["priority"] = jira.fields.priority.name
|
||||
if jira.fields.status:
|
||||
metadata_dict["status"] = jira.fields.status.name
|
||||
if jira.fields.resolution:
|
||||
metadata_dict["resolution"] = jira.fields.resolution.name
|
||||
if jira.fields.labels:
|
||||
metadata_dict["label"] = jira.fields.labels
|
||||
|
||||
doc_batch.append(
|
||||
Document(
|
||||
|
||||
@@ -64,7 +64,7 @@ class DiscourseConnector(PollConnector):
|
||||
self.permissions: DiscoursePerms | None = None
|
||||
self.active_categories: set | None = None
|
||||
|
||||
@rate_limit_builder(max_calls=50, period=60)
|
||||
@rate_limit_builder(max_calls=100, period=60)
|
||||
def _make_request(self, endpoint: str, params: dict | None = None) -> Response:
|
||||
if not self.permissions:
|
||||
raise ConnectorMissingCredentialError("Discourse")
|
||||
|
||||
@@ -23,7 +23,7 @@ from danswer.file_processing.extract_file_text import extract_file_text
|
||||
from danswer.file_processing.extract_file_text import get_file_ext
|
||||
from danswer.file_processing.extract_file_text import is_text_file_extension
|
||||
from danswer.file_processing.extract_file_text import load_files_from_zip
|
||||
from danswer.file_processing.extract_file_text import read_pdf_file
|
||||
from danswer.file_processing.extract_file_text import pdf_to_text
|
||||
from danswer.file_processing.extract_file_text import read_text_file
|
||||
from danswer.file_store.file_store import get_default_file_store
|
||||
from danswer.utils.logger import setup_logger
|
||||
@@ -75,7 +75,7 @@ def _process_file(
|
||||
|
||||
# Using the PDF reader function directly to pass in password cleanly
|
||||
elif extension == ".pdf":
|
||||
file_content_raw, file_metadata = read_pdf_file(file=file, pdf_pass=pdf_pass)
|
||||
file_content_raw = pdf_to_text(file=file, pdf_pass=pdf_pass)
|
||||
|
||||
else:
|
||||
file_content_raw = extract_file_text(
|
||||
|
||||
@@ -38,7 +38,7 @@ def _sleep_after_rate_limit_exception(github_client: Github) -> None:
|
||||
tzinfo=timezone.utc
|
||||
) - datetime.now(tz=timezone.utc)
|
||||
sleep_time += timedelta(minutes=1) # add an extra minute just to be safe
|
||||
logger.notice(f"Ran into Github rate-limit. Sleeping {sleep_time.seconds} seconds.")
|
||||
logger.info(f"Ran into Github rate-limit. Sleeping {sleep_time.seconds} seconds.")
|
||||
time.sleep(sleep_time.seconds)
|
||||
|
||||
|
||||
|
||||
@@ -11,17 +11,16 @@ from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.app_configs import WEB_DOMAIN
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import KV_CRED_KEY
|
||||
from danswer.configs.constants import KV_GMAIL_CRED_KEY
|
||||
from danswer.configs.constants import KV_GMAIL_SERVICE_ACCOUNT_KEY
|
||||
from danswer.connectors.gmail.constants import CRED_KEY
|
||||
from danswer.connectors.gmail.constants import (
|
||||
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY,
|
||||
)
|
||||
from danswer.connectors.gmail.constants import DB_CREDENTIALS_DICT_TOKEN_KEY
|
||||
from danswer.connectors.gmail.constants import GMAIL_CRED_KEY
|
||||
from danswer.connectors.gmail.constants import (
|
||||
GMAIL_DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
|
||||
)
|
||||
from danswer.connectors.gmail.constants import GMAIL_SERVICE_ACCOUNT_KEY
|
||||
from danswer.connectors.gmail.constants import SCOPES
|
||||
from danswer.db.credentials import update_credential_json
|
||||
from danswer.db.models import User
|
||||
@@ -50,7 +49,7 @@ def get_gmail_creds_for_authorized_user(
|
||||
try:
|
||||
creds.refresh(Request())
|
||||
if creds.valid:
|
||||
logger.notice("Refreshed Gmail tokens.")
|
||||
logger.info("Refreshed Gmail tokens.")
|
||||
return creds
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to refresh gmail access token due to: {e}")
|
||||
@@ -72,7 +71,7 @@ def get_gmail_creds_for_service_account(
|
||||
|
||||
|
||||
def verify_csrf(credential_id: int, state: str) -> None:
|
||||
csrf = get_dynamic_config_store().load(KV_CRED_KEY.format(str(credential_id)))
|
||||
csrf = get_dynamic_config_store().load(CRED_KEY.format(str(credential_id)))
|
||||
if csrf != state:
|
||||
raise PermissionError(
|
||||
"State from Gmail Connector callback does not match expected"
|
||||
@@ -80,7 +79,7 @@ def verify_csrf(credential_id: int, state: str) -> None:
|
||||
|
||||
|
||||
def get_gmail_auth_url(credential_id: int) -> str:
|
||||
creds_str = str(get_dynamic_config_store().load(KV_GMAIL_CRED_KEY))
|
||||
creds_str = str(get_dynamic_config_store().load(GMAIL_CRED_KEY))
|
||||
credential_json = json.loads(creds_str)
|
||||
flow = InstalledAppFlow.from_client_config(
|
||||
credential_json,
|
||||
@@ -92,14 +91,12 @@ def get_gmail_auth_url(credential_id: int) -> str:
|
||||
parsed_url = cast(ParseResult, urlparse(auth_url))
|
||||
params = parse_qs(parsed_url.query)
|
||||
|
||||
get_dynamic_config_store().store(
|
||||
KV_CRED_KEY.format(credential_id), params.get("state", [None])[0], encrypt=True
|
||||
) # type: ignore
|
||||
get_dynamic_config_store().store(CRED_KEY.format(credential_id), params.get("state", [None])[0], encrypt=True) # type: ignore
|
||||
return str(auth_url)
|
||||
|
||||
|
||||
def get_auth_url(credential_id: int) -> str:
|
||||
creds_str = str(get_dynamic_config_store().load(KV_GMAIL_CRED_KEY))
|
||||
creds_str = str(get_dynamic_config_store().load(GMAIL_CRED_KEY))
|
||||
credential_json = json.loads(creds_str)
|
||||
flow = InstalledAppFlow.from_client_config(
|
||||
credential_json,
|
||||
@@ -111,9 +108,7 @@ def get_auth_url(credential_id: int) -> str:
|
||||
parsed_url = cast(ParseResult, urlparse(auth_url))
|
||||
params = parse_qs(parsed_url.query)
|
||||
|
||||
get_dynamic_config_store().store(
|
||||
KV_CRED_KEY.format(credential_id), params.get("state", [None])[0], encrypt=True
|
||||
) # type: ignore
|
||||
get_dynamic_config_store().store(CRED_KEY.format(credential_id), params.get("state", [None])[0], encrypt=True) # type: ignore
|
||||
return str(auth_url)
|
||||
|
||||
|
||||
@@ -125,7 +120,7 @@ def update_gmail_credential_access_tokens(
|
||||
) -> OAuthCredentials | None:
|
||||
app_credentials = get_google_app_gmail_cred()
|
||||
flow = InstalledAppFlow.from_client_config(
|
||||
app_credentials.model_dump(),
|
||||
app_credentials.dict(),
|
||||
scopes=SCOPES,
|
||||
redirect_uri=_build_frontend_gmail_redirect(),
|
||||
)
|
||||
@@ -151,29 +146,28 @@ def build_service_account_creds(
|
||||
credential_dict[DB_CREDENTIALS_DICT_DELEGATED_USER_KEY] = delegated_user_email
|
||||
|
||||
return CredentialBase(
|
||||
source=DocumentSource.GMAIL,
|
||||
credential_json=credential_dict,
|
||||
admin_public=True,
|
||||
)
|
||||
|
||||
|
||||
def get_google_app_gmail_cred() -> GoogleAppCredentials:
|
||||
creds_str = str(get_dynamic_config_store().load(KV_GMAIL_CRED_KEY))
|
||||
creds_str = str(get_dynamic_config_store().load(GMAIL_CRED_KEY))
|
||||
return GoogleAppCredentials(**json.loads(creds_str))
|
||||
|
||||
|
||||
def upsert_google_app_gmail_cred(app_credentials: GoogleAppCredentials) -> None:
|
||||
get_dynamic_config_store().store(
|
||||
KV_GMAIL_CRED_KEY, app_credentials.json(), encrypt=True
|
||||
GMAIL_CRED_KEY, app_credentials.json(), encrypt=True
|
||||
)
|
||||
|
||||
|
||||
def delete_google_app_gmail_cred() -> None:
|
||||
get_dynamic_config_store().delete(KV_GMAIL_CRED_KEY)
|
||||
get_dynamic_config_store().delete(GMAIL_CRED_KEY)
|
||||
|
||||
|
||||
def get_gmail_service_account_key() -> GoogleServiceAccountKey:
|
||||
creds_str = str(get_dynamic_config_store().load(KV_GMAIL_SERVICE_ACCOUNT_KEY))
|
||||
creds_str = str(get_dynamic_config_store().load(GMAIL_SERVICE_ACCOUNT_KEY))
|
||||
return GoogleServiceAccountKey(**json.loads(creds_str))
|
||||
|
||||
|
||||
@@ -181,19 +175,19 @@ def upsert_gmail_service_account_key(
|
||||
service_account_key: GoogleServiceAccountKey,
|
||||
) -> None:
|
||||
get_dynamic_config_store().store(
|
||||
KV_GMAIL_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True
|
||||
GMAIL_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True
|
||||
)
|
||||
|
||||
|
||||
def upsert_service_account_key(service_account_key: GoogleServiceAccountKey) -> None:
|
||||
get_dynamic_config_store().store(
|
||||
KV_GMAIL_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True
|
||||
GMAIL_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True
|
||||
)
|
||||
|
||||
|
||||
def delete_gmail_service_account_key() -> None:
|
||||
get_dynamic_config_store().delete(KV_GMAIL_SERVICE_ACCOUNT_KEY)
|
||||
get_dynamic_config_store().delete(GMAIL_SERVICE_ACCOUNT_KEY)
|
||||
|
||||
|
||||
def delete_service_account_key() -> None:
|
||||
get_dynamic_config_store().delete(KV_GMAIL_SERVICE_ACCOUNT_KEY)
|
||||
get_dynamic_config_store().delete(GMAIL_SERVICE_ACCOUNT_KEY)
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
DB_CREDENTIALS_DICT_TOKEN_KEY = "gmail_tokens"
|
||||
GMAIL_DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "gmail_service_account_key"
|
||||
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY = "gmail_delegated_user"
|
||||
CRED_KEY = "credential_id_{}"
|
||||
GMAIL_CRED_KEY = "gmail_app_credential"
|
||||
GMAIL_SERVICE_ACCOUNT_KEY = "gmail_service_account_key"
|
||||
SCOPES = ["https://www.googleapis.com/auth/gmail.readonly"]
|
||||
|
||||
@@ -81,10 +81,10 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
|
||||
for workspace in workspace_list:
|
||||
if workspace:
|
||||
logger.info(f"Updating Gong workspace: {workspace}")
|
||||
logger.info(f"Updating workspace: {workspace}")
|
||||
workspace_id = workspace_map.get(workspace)
|
||||
if not workspace_id:
|
||||
logger.error(f"Invalid Gong workspace: {workspace}")
|
||||
logger.error(f"Invalid workspace: {workspace}")
|
||||
if not self.continue_on_fail:
|
||||
raise ValueError(f"Invalid workspace: {workspace}")
|
||||
continue
|
||||
|
||||
@@ -41,8 +41,8 @@ from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.file_processing.extract_file_text import docx_to_text
|
||||
from danswer.file_processing.extract_file_text import pdf_to_text
|
||||
from danswer.file_processing.extract_file_text import pptx_to_text
|
||||
from danswer.file_processing.extract_file_text import read_pdf_file
|
||||
from danswer.utils.batching import batch_generator
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
@@ -62,8 +62,6 @@ class GDriveMimeType(str, Enum):
|
||||
POWERPOINT = (
|
||||
"application/vnd.openxmlformats-officedocument.presentationml.presentation"
|
||||
)
|
||||
PLAIN_TEXT = "text/plain"
|
||||
MARKDOWN = "text/markdown"
|
||||
|
||||
|
||||
GoogleDriveFileType = dict[str, Any]
|
||||
@@ -269,7 +267,7 @@ def get_all_files_batched(
|
||||
yield from batch_generator(
|
||||
items=found_files,
|
||||
batch_size=batch_size,
|
||||
pre_batch_yield=lambda batch_files: logger.debug(
|
||||
pre_batch_yield=lambda batch_files: logger.info(
|
||||
f"Parseable Documents in batch: {[file['name'] for file in batch_files]}"
|
||||
),
|
||||
)
|
||||
@@ -308,42 +306,36 @@ def get_all_files_batched(
|
||||
|
||||
def extract_text(file: dict[str, str], service: discovery.Resource) -> str:
|
||||
mime_type = file["mimeType"]
|
||||
|
||||
if mime_type not in set(item.value for item in GDriveMimeType):
|
||||
# Unsupported file types can still have a title, finding this way is still useful
|
||||
return UNSUPPORTED_FILE_TYPE_CONTENT
|
||||
|
||||
if mime_type in [
|
||||
GDriveMimeType.DOC.value,
|
||||
GDriveMimeType.PPT.value,
|
||||
GDriveMimeType.SPREADSHEET.value,
|
||||
]:
|
||||
export_mime_type = (
|
||||
"text/plain"
|
||||
if mime_type != GDriveMimeType.SPREADSHEET.value
|
||||
else "text/csv"
|
||||
)
|
||||
if mime_type == GDriveMimeType.DOC.value:
|
||||
return (
|
||||
service.files()
|
||||
.export(fileId=file["id"], mimeType=export_mime_type)
|
||||
.export(fileId=file["id"], mimeType="text/plain")
|
||||
.execute()
|
||||
.decode("utf-8")
|
||||
)
|
||||
elif mime_type == GDriveMimeType.SPREADSHEET.value:
|
||||
return (
|
||||
service.files()
|
||||
.export(fileId=file["id"], mimeType="text/csv")
|
||||
.execute()
|
||||
.decode("utf-8")
|
||||
)
|
||||
elif mime_type in [
|
||||
GDriveMimeType.PLAIN_TEXT.value,
|
||||
GDriveMimeType.MARKDOWN.value,
|
||||
]:
|
||||
return service.files().get_media(fileId=file["id"]).execute().decode("utf-8")
|
||||
elif mime_type == GDriveMimeType.WORD_DOC.value:
|
||||
response = service.files().get_media(fileId=file["id"]).execute()
|
||||
return docx_to_text(file=io.BytesIO(response))
|
||||
elif mime_type == GDriveMimeType.PDF.value:
|
||||
response = service.files().get_media(fileId=file["id"]).execute()
|
||||
text, _ = read_pdf_file(file=io.BytesIO(response))
|
||||
return text
|
||||
return pdf_to_text(file=io.BytesIO(response))
|
||||
elif mime_type == GDriveMimeType.POWERPOINT.value:
|
||||
response = service.files().get_media(fileId=file["id"]).execute()
|
||||
return pptx_to_text(file=io.BytesIO(response))
|
||||
elif mime_type == GDriveMimeType.PPT.value:
|
||||
response = service.files().get_media(fileId=file["id"]).execute()
|
||||
return pptx_to_text(file=io.BytesIO(response))
|
||||
|
||||
return UNSUPPORTED_FILE_TYPE_CONTENT
|
||||
|
||||
|
||||
@@ -11,10 +11,7 @@ from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.app_configs import WEB_DOMAIN
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import KV_CRED_KEY
|
||||
from danswer.configs.constants import KV_GOOGLE_DRIVE_CRED_KEY
|
||||
from danswer.configs.constants import KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY
|
||||
from danswer.connectors.google_drive.constants import CRED_KEY
|
||||
from danswer.connectors.google_drive.constants import (
|
||||
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY,
|
||||
)
|
||||
@@ -22,6 +19,8 @@ from danswer.connectors.google_drive.constants import (
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
|
||||
)
|
||||
from danswer.connectors.google_drive.constants import DB_CREDENTIALS_DICT_TOKEN_KEY
|
||||
from danswer.connectors.google_drive.constants import GOOGLE_DRIVE_CRED_KEY
|
||||
from danswer.connectors.google_drive.constants import GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY
|
||||
from danswer.connectors.google_drive.constants import SCOPES
|
||||
from danswer.db.credentials import update_credential_json
|
||||
from danswer.db.models import User
|
||||
@@ -50,7 +49,7 @@ def get_google_drive_creds_for_authorized_user(
|
||||
try:
|
||||
creds.refresh(Request())
|
||||
if creds.valid:
|
||||
logger.notice("Refreshed Google Drive tokens.")
|
||||
logger.info("Refreshed Google Drive tokens.")
|
||||
return creds
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to refresh google drive access token due to: {e}")
|
||||
@@ -72,7 +71,7 @@ def get_google_drive_creds_for_service_account(
|
||||
|
||||
|
||||
def verify_csrf(credential_id: int, state: str) -> None:
|
||||
csrf = get_dynamic_config_store().load(KV_CRED_KEY.format(str(credential_id)))
|
||||
csrf = get_dynamic_config_store().load(CRED_KEY.format(str(credential_id)))
|
||||
if csrf != state:
|
||||
raise PermissionError(
|
||||
"State from Google Drive Connector callback does not match expected"
|
||||
@@ -80,7 +79,7 @@ def verify_csrf(credential_id: int, state: str) -> None:
|
||||
|
||||
|
||||
def get_auth_url(credential_id: int) -> str:
|
||||
creds_str = str(get_dynamic_config_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
|
||||
creds_str = str(get_dynamic_config_store().load(GOOGLE_DRIVE_CRED_KEY))
|
||||
credential_json = json.loads(creds_str)
|
||||
flow = InstalledAppFlow.from_client_config(
|
||||
credential_json,
|
||||
@@ -92,9 +91,7 @@ def get_auth_url(credential_id: int) -> str:
|
||||
parsed_url = cast(ParseResult, urlparse(auth_url))
|
||||
params = parse_qs(parsed_url.query)
|
||||
|
||||
get_dynamic_config_store().store(
|
||||
KV_CRED_KEY.format(credential_id), params.get("state", [None])[0], encrypt=True
|
||||
) # type: ignore
|
||||
get_dynamic_config_store().store(CRED_KEY.format(credential_id), params.get("state", [None])[0], encrypt=True) # type: ignore
|
||||
return str(auth_url)
|
||||
|
||||
|
||||
@@ -106,7 +103,7 @@ def update_credential_access_tokens(
|
||||
) -> OAuthCredentials | None:
|
||||
app_credentials = get_google_app_cred()
|
||||
flow = InstalledAppFlow.from_client_config(
|
||||
app_credentials.model_dump(),
|
||||
app_credentials.dict(),
|
||||
scopes=SCOPES,
|
||||
redirect_uri=_build_frontend_google_drive_redirect(),
|
||||
)
|
||||
@@ -121,7 +118,6 @@ def update_credential_access_tokens(
|
||||
|
||||
|
||||
def build_service_account_creds(
|
||||
source: DocumentSource,
|
||||
delegated_user_email: str | None = None,
|
||||
) -> CredentialBase:
|
||||
service_account_key = get_service_account_key()
|
||||
@@ -135,37 +131,34 @@ def build_service_account_creds(
|
||||
return CredentialBase(
|
||||
credential_json=credential_dict,
|
||||
admin_public=True,
|
||||
source=DocumentSource.GOOGLE_DRIVE,
|
||||
)
|
||||
|
||||
|
||||
def get_google_app_cred() -> GoogleAppCredentials:
|
||||
creds_str = str(get_dynamic_config_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
|
||||
creds_str = str(get_dynamic_config_store().load(GOOGLE_DRIVE_CRED_KEY))
|
||||
return GoogleAppCredentials(**json.loads(creds_str))
|
||||
|
||||
|
||||
def upsert_google_app_cred(app_credentials: GoogleAppCredentials) -> None:
|
||||
get_dynamic_config_store().store(
|
||||
KV_GOOGLE_DRIVE_CRED_KEY, app_credentials.json(), encrypt=True
|
||||
GOOGLE_DRIVE_CRED_KEY, app_credentials.json(), encrypt=True
|
||||
)
|
||||
|
||||
|
||||
def delete_google_app_cred() -> None:
|
||||
get_dynamic_config_store().delete(KV_GOOGLE_DRIVE_CRED_KEY)
|
||||
get_dynamic_config_store().delete(GOOGLE_DRIVE_CRED_KEY)
|
||||
|
||||
|
||||
def get_service_account_key() -> GoogleServiceAccountKey:
|
||||
creds_str = str(
|
||||
get_dynamic_config_store().load(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY)
|
||||
)
|
||||
creds_str = str(get_dynamic_config_store().load(GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY))
|
||||
return GoogleServiceAccountKey(**json.loads(creds_str))
|
||||
|
||||
|
||||
def upsert_service_account_key(service_account_key: GoogleServiceAccountKey) -> None:
|
||||
get_dynamic_config_store().store(
|
||||
KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True
|
||||
GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True
|
||||
)
|
||||
|
||||
|
||||
def delete_service_account_key() -> None:
|
||||
get_dynamic_config_store().delete(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY)
|
||||
get_dynamic_config_store().delete(GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user