mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-19 00:35:46 +00:00
Compare commits
1 Commits
github-act
...
web-docker
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5da0751ac6 |
76
.github/actions/custom-build-and-push/action.yml
vendored
76
.github/actions/custom-build-and-push/action.yml
vendored
@@ -1,76 +0,0 @@
|
||||
name: 'Build and Push Docker Image with Retry'
|
||||
description: 'Attempts to build and push a Docker image, with a retry on failure'
|
||||
inputs:
|
||||
context:
|
||||
description: 'Build context'
|
||||
required: true
|
||||
file:
|
||||
description: 'Dockerfile location'
|
||||
required: true
|
||||
platforms:
|
||||
description: 'Target platforms'
|
||||
required: true
|
||||
pull:
|
||||
description: 'Always attempt to pull a newer version of the image'
|
||||
required: false
|
||||
default: 'true'
|
||||
push:
|
||||
description: 'Push the image to registry'
|
||||
required: false
|
||||
default: 'true'
|
||||
load:
|
||||
description: 'Load the image into Docker daemon'
|
||||
required: false
|
||||
default: 'true'
|
||||
tags:
|
||||
description: 'Image tags'
|
||||
required: true
|
||||
cache-from:
|
||||
description: 'Cache sources'
|
||||
required: false
|
||||
cache-to:
|
||||
description: 'Cache destinations'
|
||||
required: false
|
||||
retry-wait-time:
|
||||
description: 'Time to wait before retry in seconds'
|
||||
required: false
|
||||
default: '5'
|
||||
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Build and push Docker image (First Attempt)
|
||||
id: buildx1
|
||||
uses: docker/build-push-action@v5
|
||||
continue-on-error: true
|
||||
with:
|
||||
context: ${{ inputs.context }}
|
||||
file: ${{ inputs.file }}
|
||||
platforms: ${{ inputs.platforms }}
|
||||
pull: ${{ inputs.pull }}
|
||||
push: ${{ inputs.push }}
|
||||
load: ${{ inputs.load }}
|
||||
tags: ${{ inputs.tags }}
|
||||
cache-from: ${{ inputs.cache-from }}
|
||||
cache-to: ${{ inputs.cache-to }}
|
||||
|
||||
- name: Wait to retry
|
||||
if: steps.buildx1.outcome != 'success'
|
||||
run: |
|
||||
echo "First attempt failed. Waiting ${{ inputs.retry-wait-time }} seconds before retry..."
|
||||
sleep ${{ inputs.retry-wait-time }}
|
||||
shell: bash
|
||||
|
||||
- name: Build and push Docker image (Retry Attempt)
|
||||
if: steps.buildx1.outcome != 'success'
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: ${{ inputs.context }}
|
||||
file: ${{ inputs.file }}
|
||||
platforms: ${{ inputs.platforms }}
|
||||
pull: ${{ inputs.pull }}
|
||||
push: ${{ inputs.push }}
|
||||
load: ${{ inputs.load }}
|
||||
tags: ${{ inputs.tags }}
|
||||
cache-from: ${{ inputs.cache-from }}
|
||||
cache-to: ${{ inputs.cache-to }}
|
||||
25
.github/pull_request_template.md
vendored
25
.github/pull_request_template.md
vendored
@@ -1,25 +0,0 @@
|
||||
## Description
|
||||
[Provide a brief description of the changes in this PR]
|
||||
|
||||
|
||||
## How Has This Been Tested?
|
||||
[Describe the tests you ran to verify your changes]
|
||||
|
||||
|
||||
## Accepted Risk
|
||||
[Any know risks or failure modes to point out to reviewers]
|
||||
|
||||
|
||||
## Related Issue(s)
|
||||
[If applicable, link to the issue(s) this PR addresses]
|
||||
|
||||
|
||||
## Checklist:
|
||||
- [ ] All of the automated tests pass
|
||||
- [ ] All PR comments are addressed and marked resolved
|
||||
- [ ] If there are migrations, they have been rebased to latest main
|
||||
- [ ] If there are new dependencies, they are added to the requirements
|
||||
- [ ] If there are new environment variables, they are added to all of the deployment methods
|
||||
- [ ] If there are new APIs that don't require auth, they are added to PUBLIC_ENDPOINT_SPECS
|
||||
- [ ] Docker images build and basic functionalities work
|
||||
- [ ] Author has done a final read through of the PR right before merge
|
||||
23
.github/workflows/check-backend-changes.yml
vendored
23
.github/workflows/check-backend-changes.yml
vendored
@@ -1,23 +0,0 @@
|
||||
name: Check Backend Changes
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
outputs:
|
||||
run-tests:
|
||||
description: "Whether to run tests based on backend changes"
|
||||
value: ${{ jobs.check-run-needed.outputs.run-tests }}
|
||||
|
||||
jobs:
|
||||
check-run-needed:
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
run-tests: ${{ steps.check.outputs.run-tests }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- id: check
|
||||
run: |
|
||||
if git diff --name-only ${{ github.event.before }} ${{ github.sha }} | grep -q '^backend/'; then
|
||||
echo "run-tests=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "run-tests=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
@@ -5,38 +5,33 @@ on:
|
||||
tags:
|
||||
- '*'
|
||||
|
||||
env:
|
||||
REGISTRY_IMAGE: danswer/danswer-backend
|
||||
|
||||
jobs:
|
||||
build-and-push:
|
||||
# TODO: make this a matrix build like the web containers
|
||||
runs-on:
|
||||
group: amd64-image-builders
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@v1
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@v1
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Backend Image Docker Build and Push
|
||||
uses: docker/build-push-action@v5
|
||||
uses: docker/build-push-action@v2
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
platforms: linux/amd64,linux/arm64
|
||||
push: true
|
||||
tags: |
|
||||
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
${{ env.REGISTRY_IMAGE }}:latest
|
||||
danswer/danswer-backend:${{ github.ref_name }}
|
||||
danswer/danswer-backend:latest
|
||||
build-args: |
|
||||
DANSWER_VERSION=${{ github.ref_name }}
|
||||
|
||||
@@ -44,6 +39,6 @@ jobs:
|
||||
uses: aquasecurity/trivy-action@master
|
||||
with:
|
||||
# To run locally: trivy image --severity HIGH,CRITICAL danswer/danswer-backend
|
||||
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
image-ref: docker.io/danswer/danswer-backend:${{ github.ref_name }}
|
||||
severity: 'CRITICAL,HIGH'
|
||||
trivyignores: ./backend/.trivyignore
|
||||
|
||||
@@ -7,24 +7,23 @@ on:
|
||||
|
||||
jobs:
|
||||
build-and-push:
|
||||
runs-on:
|
||||
group: amd64-image-builders
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@v1
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@v1
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Model Server Image Docker Build and Push
|
||||
uses: docker/build-push-action@v5
|
||||
uses: docker/build-push-action@v2
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile.model_server
|
||||
|
||||
@@ -5,115 +5,38 @@ on:
|
||||
tags:
|
||||
- '*'
|
||||
|
||||
env:
|
||||
REGISTRY_IMAGE: danswer/danswer-web-server
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on:
|
||||
group: ${{ matrix.platform == 'linux/amd64' && 'amd64-image-builders' || 'arm64-image-builders' }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
platform:
|
||||
- linux/amd64
|
||||
- linux/arm64
|
||||
|
||||
steps:
|
||||
- name: Prepare
|
||||
run: |
|
||||
platform=${{ matrix.platform }}
|
||||
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ env.REGISTRY_IMAGE }}
|
||||
tags: |
|
||||
type=raw,value=${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
type=raw,value=${{ env.REGISTRY_IMAGE }}:latest
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push by digest
|
||||
id: build
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: ./web
|
||||
file: ./web/Dockerfile
|
||||
platforms: ${{ matrix.platform }}
|
||||
push: true
|
||||
build-args: |
|
||||
DANSWER_VERSION=${{ github.ref_name }}
|
||||
# needed due to weird interactions with the builds for different platforms
|
||||
no-cache: true
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
|
||||
- name: Export digest
|
||||
run: |
|
||||
mkdir -p /tmp/digests
|
||||
digest="${{ steps.build.outputs.digest }}"
|
||||
touch "/tmp/digests/${digest#sha256:}"
|
||||
|
||||
- name: Upload digest
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: digests-${{ env.PLATFORM_PAIR }}
|
||||
path: /tmp/digests/*
|
||||
if-no-files-found: error
|
||||
retention-days: 1
|
||||
|
||||
merge:
|
||||
build-and-push:
|
||||
runs-on: ubuntu-latest
|
||||
needs:
|
||||
- build
|
||||
steps:
|
||||
- name: Download digests
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
path: /tmp/digests
|
||||
pattern: digests-*
|
||||
merge-multiple: true
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ env.REGISTRY_IMAGE }}
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Create manifest list and push
|
||||
working-directory: /tmp/digests
|
||||
run: |
|
||||
docker buildx imagetools create $(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \
|
||||
$(printf '${{ env.REGISTRY_IMAGE }}@sha256:%s ' *)
|
||||
|
||||
- name: Inspect image
|
||||
run: |
|
||||
docker buildx imagetools inspect ${{ env.REGISTRY_IMAGE }}:${{ steps.meta.outputs.version }}
|
||||
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@master
|
||||
with:
|
||||
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
severity: 'CRITICAL,HIGH'
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v1
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v1
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Web Image Docker Build and Push
|
||||
uses: docker/build-push-action@v2
|
||||
with:
|
||||
context: ./web
|
||||
file: ./web/Dockerfile
|
||||
platforms: linux/amd64,linux/arm64
|
||||
push: true
|
||||
tags: |
|
||||
danswer/danswer-web-server:${{ github.ref_name }}
|
||||
danswer/danswer-web-server:latest
|
||||
build-args: |
|
||||
DANSWER_VERSION=${{ github.ref_name }}
|
||||
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@master
|
||||
with:
|
||||
image-ref: docker.io/danswer/danswer-web-server:${{ github.ref_name }}
|
||||
severity: 'CRITICAL,HIGH'
|
||||
|
||||
@@ -1,67 +0,0 @@
|
||||
# This workflow is intentionally disabled while we're still working on it
|
||||
# It's close to ready, but a race condition needs to be fixed with
|
||||
# API server and Vespa startup, and it needs to have a way to build/test against
|
||||
# local containers
|
||||
|
||||
name: Helm - Lint and Test Charts
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
|
||||
jobs:
|
||||
lint-test:
|
||||
runs-on: Amd64
|
||||
|
||||
# fetch-depth 0 is required for helm/chart-testing-action
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Helm
|
||||
uses: azure/setup-helm@v4.2.0
|
||||
with:
|
||||
version: v3.14.4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.11'
|
||||
cache: 'pip'
|
||||
cache-dependency-path: |
|
||||
backend/requirements/default.txt
|
||||
backend/requirements/dev.txt
|
||||
backend/requirements/model_server.txt
|
||||
- run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -r backend/requirements/default.txt
|
||||
pip install -r backend/requirements/dev.txt
|
||||
pip install -r backend/requirements/model_server.txt
|
||||
|
||||
- name: Set up chart-testing
|
||||
uses: helm/chart-testing-action@v2.6.1
|
||||
|
||||
- name: Run chart-testing (list-changed)
|
||||
id: list-changed
|
||||
run: |
|
||||
changed=$(ct list-changed --target-branch ${{ github.event.repository.default_branch }})
|
||||
if [[ -n "$changed" ]]; then
|
||||
echo "changed=true" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
|
||||
- name: Run chart-testing (lint)
|
||||
# if: steps.list-changed.outputs.changed == 'true'
|
||||
run: ct lint --all --config ct.yaml --target-branch ${{ github.event.repository.default_branch }}
|
||||
|
||||
- name: Create kind cluster
|
||||
# if: steps.list-changed.outputs.changed == 'true'
|
||||
uses: helm/kind-action@v1.10.0
|
||||
|
||||
- name: Run chart-testing (install)
|
||||
# if: steps.list-changed.outputs.changed == 'true'
|
||||
run: ct install --all --config ct.yaml
|
||||
# run: ct install --target-branch ${{ github.event.repository.default_branch }}
|
||||
|
||||
13
.github/workflows/pr-python-checks.yml
vendored
13
.github/workflows/pr-python-checks.yml
vendored
@@ -1,17 +1,11 @@
|
||||
name: Python Checks
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
|
||||
jobs:
|
||||
check-changes:
|
||||
uses: ./.github/workflows/check-backend-changes.yml
|
||||
|
||||
mypy-check:
|
||||
needs: check-changes
|
||||
if: needs.check-changes.outputs.run-tests == 'true'
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
@@ -52,10 +46,3 @@ jobs:
|
||||
run: |
|
||||
cd backend
|
||||
black --check .
|
||||
|
||||
skip-tests:
|
||||
needs: check-changes
|
||||
if: needs.check-changes.outputs.run-tests == 'false'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- run: echo "No changes in backend, skipping this test."
|
||||
|
||||
69
.github/workflows/pr-python-connector-tests.yml
vendored
69
.github/workflows/pr-python-connector-tests.yml
vendored
@@ -1,69 +0,0 @@
|
||||
name: Connector Tests
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches: [main]
|
||||
schedule:
|
||||
# This cron expression runs the job daily at 16:00 UTC (9am PT)
|
||||
- cron: "0 16 * * *"
|
||||
|
||||
env:
|
||||
# Confluence
|
||||
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
|
||||
CONFLUENCE_TEST_SPACE: ${{ secrets.CONFLUENCE_TEST_SPACE }}
|
||||
CONFLUENCE_IS_CLOUD: ${{ secrets.CONFLUENCE_IS_CLOUD }}
|
||||
CONFLUENCE_TEST_PAGE_ID: ${{ secrets.CONFLUENCE_TEST_PAGE_ID }}
|
||||
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
|
||||
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
|
||||
|
||||
jobs:
|
||||
check-changes:
|
||||
uses: ./.github/workflows/check-backend-changes.yml
|
||||
|
||||
connectors-check:
|
||||
needs: check-changes
|
||||
if: needs.check-changes.outputs.run-tests == 'true'
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
env:
|
||||
PYTHONPATH: ./backend
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.11"
|
||||
cache: "pip"
|
||||
cache-dependency-path: |
|
||||
backend/requirements/default.txt
|
||||
backend/requirements/dev.txt
|
||||
|
||||
- name: Install Dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -r backend/requirements/default.txt
|
||||
pip install -r backend/requirements/dev.txt
|
||||
|
||||
- name: Run Tests
|
||||
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
|
||||
run: py.test -o junit_family=xunit2 -xv --ff backend/tests/daily/connectors
|
||||
|
||||
- name: Alert on Failure
|
||||
if: failure() && github.event_name == 'schedule'
|
||||
env:
|
||||
SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK }}
|
||||
run: |
|
||||
curl -X POST \
|
||||
-H 'Content-type: application/json' \
|
||||
--data '{"text":"Scheduled Connector Tests failed! Check the run at: https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}"}' \
|
||||
$SLACK_WEBHOOK
|
||||
|
||||
skip-tests:
|
||||
needs: check-changes
|
||||
if: needs.check-changes.outputs.run-tests == 'false'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- run: echo "No changes in backend, skipping this test."
|
||||
14
.github/workflows/pr-python-tests.yml
vendored
14
.github/workflows/pr-python-tests.yml
vendored
@@ -1,17 +1,11 @@
|
||||
name: Python Unit Tests
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
|
||||
jobs:
|
||||
check-changes:
|
||||
uses: ./.github/workflows/check-backend-changes.yml
|
||||
|
||||
backend-check:
|
||||
needs: check-changes
|
||||
if: needs.check-changes.outputs.run-tests == 'true'
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
env:
|
||||
@@ -39,11 +33,3 @@ jobs:
|
||||
- name: Run Tests
|
||||
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
|
||||
run: py.test -o junit_family=xunit2 -xv --ff backend/tests/unit
|
||||
|
||||
|
||||
skip-tests:
|
||||
needs: check-changes
|
||||
if: needs.check-changes.outputs.run-tests == 'false'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- run: echo "No changes in backend, skipping this test."
|
||||
|
||||
19
.github/workflows/pr-quality-checks.yml
vendored
19
.github/workflows/pr-quality-checks.yml
vendored
@@ -4,19 +4,18 @@ concurrency:
|
||||
cancel-in-progress: true
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
pull_request: null
|
||||
|
||||
jobs:
|
||||
quality-checks:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- uses: pre-commit/action@v3.0.0
|
||||
with:
|
||||
extra_args: ${{ github.event_name == 'pull_request' && format('--from-ref {0} --to-ref {1}', github.event.pull_request.base.sha, github.event.pull_request.head.sha) || '' }}
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.11'
|
||||
- uses: pre-commit/action@v3.0.0
|
||||
with:
|
||||
extra_args: --from-ref ${{ github.event.pull_request.base.sha }} --to-ref ${{ github.event.pull_request.head.sha }}
|
||||
|
||||
172
.github/workflows/run-it.yml
vendored
172
.github/workflows/run-it.yml
vendored
@@ -1,172 +0,0 @@
|
||||
name: Run Integration Tests
|
||||
concurrency:
|
||||
group: Run-Integration-Tests-${{ github.head_ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
|
||||
env:
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
|
||||
jobs:
|
||||
check-changes:
|
||||
uses: ./.github/workflows/check-backend-changes.yml
|
||||
|
||||
integration-tests:
|
||||
needs: check-changes
|
||||
if: needs.check-changes.outputs.run-tests == 'true'
|
||||
runs-on:
|
||||
group: 'arm64-image-builders'
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
# NOTE: we don't need to build the Web Docker image since it's not used
|
||||
# during the IT for now. We have a separate action to verify it builds
|
||||
# succesfully
|
||||
- name: Pull Web Docker image
|
||||
run: |
|
||||
docker pull danswer/danswer-web-server:latest
|
||||
docker tag danswer/danswer-web-server:latest danswer/danswer-web-server:it
|
||||
|
||||
- name: Build Backend Docker image
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
platforms: linux/arm64
|
||||
tags: danswer/danswer-backend:it
|
||||
cache-from: type=registry,ref=danswer/danswer-backend:it
|
||||
cache-to: |
|
||||
type=registry,ref=danswer/danswer-backend:it,mode=max
|
||||
type=inline
|
||||
|
||||
- name: Build Model Server Docker image
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile.model_server
|
||||
platforms: linux/arm64
|
||||
tags: danswer/danswer-model-server:it
|
||||
cache-from: type=registry,ref=danswer/danswer-model-server:it
|
||||
cache-to: |
|
||||
type=registry,ref=danswer/danswer-model-server:it,mode=max
|
||||
type=inline
|
||||
|
||||
- name: Build integration test Docker image
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/tests/integration/Dockerfile
|
||||
platforms: linux/arm64
|
||||
tags: danswer/integration-test-runner:it
|
||||
cache-from: type=registry,ref=danswer/integration-test-runner:it
|
||||
cache-to: |
|
||||
type=registry,ref=danswer/integration-test-runner:it,mode=max
|
||||
type=inline
|
||||
|
||||
- name: Start Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
|
||||
AUTH_TYPE=basic \
|
||||
REQUIRE_EMAIL_VERIFICATION=false \
|
||||
DISABLE_TELEMETRY=true \
|
||||
IMAGE_TAG=it \
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack up -d
|
||||
id: start_docker
|
||||
|
||||
- name: Wait for service to be ready
|
||||
run: |
|
||||
echo "Starting wait-for-service script..."
|
||||
|
||||
start_time=$(date +%s)
|
||||
timeout=300 # 5 minutes in seconds
|
||||
|
||||
while true; do
|
||||
current_time=$(date +%s)
|
||||
elapsed_time=$((current_time - start_time))
|
||||
|
||||
if [ $elapsed_time -ge $timeout ]; then
|
||||
echo "Timeout reached. Service did not become ready in 5 minutes."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Use curl with error handling to ignore specific exit code 56
|
||||
response=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:8080/health || echo "curl_error")
|
||||
|
||||
if [ "$response" = "200" ]; then
|
||||
echo "Service is ready!"
|
||||
break
|
||||
elif [ "$response" = "curl_error" ]; then
|
||||
echo "Curl encountered an error, possibly exit code 56. Continuing to retry..."
|
||||
else
|
||||
echo "Service not ready yet (HTTP status $response). Retrying in 5 seconds..."
|
||||
fi
|
||||
|
||||
sleep 5
|
||||
done
|
||||
echo "Finished waiting for service."
|
||||
|
||||
- name: Run integration tests
|
||||
run: |
|
||||
echo "Running integration tests..."
|
||||
docker run --rm --network danswer-stack_default \
|
||||
-e POSTGRES_HOST=relational_db \
|
||||
-e POSTGRES_USER=postgres \
|
||||
-e POSTGRES_PASSWORD=password \
|
||||
-e POSTGRES_DB=postgres \
|
||||
-e VESPA_HOST=index \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
danswer/integration-test-runner:it
|
||||
continue-on-error: true
|
||||
id: run_tests
|
||||
|
||||
- name: Check test results
|
||||
run: |
|
||||
if [ ${{ steps.run_tests.outcome }} == 'failure' ]; then
|
||||
echo "Integration tests failed. Exiting with error."
|
||||
exit 1
|
||||
else
|
||||
echo "All integration tests passed successfully."
|
||||
fi
|
||||
|
||||
- name: Save Docker logs
|
||||
if: success() || failure()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack logs > docker-compose.log
|
||||
mv docker-compose.log ${{ github.workspace }}/docker-compose.log
|
||||
|
||||
- name: Upload logs
|
||||
if: success() || failure()
|
||||
uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: docker-logs
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
|
||||
- name: Stop Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
|
||||
|
||||
skip-tests:
|
||||
needs: check-changes
|
||||
if: needs.check-changes.outputs.run-tests == 'false'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- run: echo "No changes in backend, skipping this test."
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -4,6 +4,4 @@
|
||||
.mypy_cache
|
||||
.idea
|
||||
/deployment/data/nginx/app.conf
|
||||
.vscode/
|
||||
*.sw?
|
||||
/backend/tests/regression/answer_quality/search_test_config.yaml
|
||||
.vscode/launch.json
|
||||
|
||||
51
.vscode/env_template.txt
vendored
51
.vscode/env_template.txt
vendored
@@ -1,51 +0,0 @@
|
||||
# Copy this file to .env in the .vscode folder
|
||||
# Fill in the <REPLACE THIS> values as needed, it is recommended to set the GEN_AI_API_KEY value to avoid having to set up an LLM in the UI
|
||||
# Also check out danswer/backend/scripts/restart_containers.sh for a script to restart the containers which Danswer relies on outside of VSCode/Cursor processes
|
||||
|
||||
# For local dev, often user Authentication is not needed
|
||||
AUTH_TYPE=disabled
|
||||
|
||||
|
||||
# Always keep these on for Dev
|
||||
# Logs all model prompts to stdout
|
||||
LOG_DANSWER_MODEL_INTERACTIONS=True
|
||||
# More verbose logging
|
||||
LOG_LEVEL=debug
|
||||
|
||||
|
||||
# This passes top N results to LLM an additional time for reranking prior to answer generation
|
||||
# This step is quite heavy on token usage so we disable it for dev generally
|
||||
DISABLE_LLM_DOC_RELEVANCE=False
|
||||
|
||||
|
||||
# Useful if you want to toggle auth on/off (google_oauth/OIDC specifically)
|
||||
OAUTH_CLIENT_ID=<REPLACE THIS>
|
||||
OAUTH_CLIENT_SECRET=<REPLACE THIS>
|
||||
# Generally not useful for dev, we don't generally want to set up an SMTP server for dev
|
||||
REQUIRE_EMAIL_VERIFICATION=False
|
||||
|
||||
|
||||
# Set these so if you wipe the DB, you don't end up having to go through the UI every time
|
||||
GEN_AI_API_KEY=<REPLACE THIS>
|
||||
# If answer quality isn't important for dev, use gpt-4o-mini since it's cheaper
|
||||
GEN_AI_MODEL_VERSION=gpt-4o
|
||||
FAST_GEN_AI_MODEL_VERSION=gpt-4o
|
||||
|
||||
# For Danswer Slack Bot, overrides the UI values so no need to set this up via UI every time
|
||||
# Only needed if using DanswerBot
|
||||
#DANSWER_BOT_SLACK_APP_TOKEN=<REPLACE THIS>
|
||||
#DANSWER_BOT_SLACK_BOT_TOKEN=<REPLACE THIS>
|
||||
|
||||
|
||||
# Python stuff
|
||||
PYTHONPATH=../backend
|
||||
PYTHONUNBUFFERED=1
|
||||
|
||||
|
||||
# Internet Search
|
||||
BING_API_KEY=<REPLACE THIS>
|
||||
|
||||
|
||||
# Enable the full set of Danswer Enterprise Edition features
|
||||
# NOTE: DO NOT ENABLE THIS UNLESS YOU HAVE A PAID ENTERPRISE LICENSE (or if you are using this for local testing/development)
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=False
|
||||
78
.vscode/launch.template.jsonc
vendored
78
.vscode/launch.template.jsonc
vendored
@@ -1,23 +1,15 @@
|
||||
/* Copy this file into '.vscode/launch.json' or merge its contents into your existing configurations. */
|
||||
/*
|
||||
|
||||
Copy this file into '.vscode/launch.json' or merge its
|
||||
contents into your existing configurations.
|
||||
|
||||
*/
|
||||
|
||||
{
|
||||
// Use IntelliSense to learn about possible attributes.
|
||||
// Hover to view descriptions of existing attributes.
|
||||
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
||||
"version": "0.2.0",
|
||||
"compounds": [
|
||||
{
|
||||
"name": "Run All Danswer Services",
|
||||
"configurations": [
|
||||
"Web Server",
|
||||
"Model Server",
|
||||
"API Server",
|
||||
"Indexing",
|
||||
"Background Jobs",
|
||||
"Slack Bot"
|
||||
]
|
||||
}
|
||||
],
|
||||
"configurations": [
|
||||
{
|
||||
"name": "Web Server",
|
||||
@@ -25,7 +17,6 @@
|
||||
"request": "launch",
|
||||
"cwd": "${workspaceRoot}/web",
|
||||
"runtimeExecutable": "npm",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"runtimeArgs": [
|
||||
"run", "dev"
|
||||
],
|
||||
@@ -33,12 +24,10 @@
|
||||
},
|
||||
{
|
||||
"name": "Model Server",
|
||||
"consoleName": "Model Server",
|
||||
"type": "debugpy",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"module": "uvicorn",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1"
|
||||
@@ -52,14 +41,12 @@
|
||||
},
|
||||
{
|
||||
"name": "API Server",
|
||||
"consoleName": "API Server",
|
||||
"type": "debugpy",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"module": "uvicorn",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
|
||||
"LOG_ALL_MODEL_INTERACTIONS": "True",
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1"
|
||||
},
|
||||
@@ -72,14 +59,12 @@
|
||||
},
|
||||
{
|
||||
"name": "Indexing",
|
||||
"consoleName": "Indexing",
|
||||
"type": "debugpy",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "danswer/background/update.py",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"ENABLE_MULTIPASS_INDEXING": "false",
|
||||
"ENABLE_MINI_CHUNK": "false",
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
@@ -88,14 +73,11 @@
|
||||
// Celery and all async jobs, usually would include indexing as well but this is handled separately above for dev
|
||||
{
|
||||
"name": "Background Jobs",
|
||||
"consoleName": "Background Jobs",
|
||||
"type": "debugpy",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "scripts/dev_run_background_jobs.py",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
@@ -108,46 +90,16 @@
|
||||
// DANSWER_BOT_SLACK_APP_TOKEN & DANSWER_BOT_SLACK_BOT_TOKEN need to be set in .env file located in the root of the project
|
||||
{
|
||||
"name": "Slack Bot",
|
||||
"consoleName": "Slack Bot",
|
||||
"type": "debugpy",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "danswer/danswerbot/slack/listener.py",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"envFile": "${workspaceFolder}/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Pytest",
|
||||
"consoleName": "Pytest",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "pytest",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-v"
|
||||
// Specify a sepcific module/test to run or provide nothing to run all tests
|
||||
//"tests/unit/danswer/llm/answering/test_prune_and_merge.py"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Clear and Restart External Volumes and Containers",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"runtimeExecutable": "bash",
|
||||
"runtimeArgs": ["${workspaceFolder}/backend/scripts/restart_containers.sh"],
|
||||
"cwd": "${workspaceFolder}",
|
||||
"console": "integratedTerminal",
|
||||
"stopOnEntry": true
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -48,24 +48,20 @@ We would love to see you there!
|
||||
|
||||
|
||||
## Get Started 🚀
|
||||
Danswer being a fully functional app, relies on some external software, specifically:
|
||||
Danswer being a fully functional app, relies on some external pieces of software, specifically:
|
||||
- [Postgres](https://www.postgresql.org/) (Relational DB)
|
||||
- [Vespa](https://vespa.ai/) (Vector DB/Search Engine)
|
||||
- [Redis](https://redis.io/) (Cache)
|
||||
- [Nginx](https://nginx.org/) (Not needed for development flows generally)
|
||||
|
||||
|
||||
> **Note:**
|
||||
> This guide provides instructions to set up the Danswer specific services outside of Docker because it's easier for
|
||||
> development purposes. However, you can also use the containers and update with local changes by providing the
|
||||
> `--build` flag.
|
||||
This guide provides instructions to set up the Danswer specific services outside of Docker because it's easier for
|
||||
development purposes but also feel free to just use the containers and update with local changes by providing the
|
||||
`--build` flag.
|
||||
|
||||
|
||||
### Local Set Up
|
||||
Be sure to use Python version 3.11.
|
||||
It is recommended to use Python version 3.11
|
||||
|
||||
If using a lower version, modifications will have to be made to the code.
|
||||
If using a higher version, sometimes some libraries will not be available (i.e. we had problems with Tensorflow in the past with higher versions of python).
|
||||
If using a higher version, the version of Tensorflow we use may not be available for your platform.
|
||||
|
||||
|
||||
#### Installing Requirements
|
||||
@@ -76,11 +72,6 @@ For convenience here's a command for it:
|
||||
python -m venv .venv
|
||||
source .venv/bin/activate
|
||||
```
|
||||
|
||||
> **Note:**
|
||||
> This virtual environment MUST NOT be set up WITHIN the danswer directory if you plan on using mypy within certain IDEs.
|
||||
> For simplicity, we recommend setting up the virtual environment outside of the danswer directory.
|
||||
|
||||
_For Windows, activate the virtual environment using Command Prompt:_
|
||||
```bash
|
||||
.venv\Scripts\activate
|
||||
@@ -94,22 +85,19 @@ Install the required python dependencies:
|
||||
```bash
|
||||
pip install -r danswer/backend/requirements/default.txt
|
||||
pip install -r danswer/backend/requirements/dev.txt
|
||||
pip install -r danswer/backend/requirements/ee.txt
|
||||
pip install -r danswer/backend/requirements/model_server.txt
|
||||
```
|
||||
|
||||
|
||||
Install [Node.js and npm](https://docs.npmjs.com/downloading-and-installing-node-js-and-npm) for the frontend.
|
||||
Once the above is done, navigate to `danswer/web` run:
|
||||
```bash
|
||||
npm i
|
||||
```
|
||||
|
||||
Install Playwright (headless browser required by the Web Connector)
|
||||
Install Playwright (required by the Web Connector)
|
||||
|
||||
> **Note:**
|
||||
> If you have just run the pip install, open a new terminal and source the python virtual-env again.
|
||||
> This will pull the updated PATH to include playwright
|
||||
> Note: If you have just done the pip install, open a new terminal and source the python virtual-env again.
|
||||
This will update the path to include playwright
|
||||
|
||||
Then install Playwright by running:
|
||||
```bash
|
||||
@@ -118,14 +106,11 @@ playwright install
|
||||
|
||||
|
||||
#### Dependent Docker Containers
|
||||
You will need Docker installed to run these containers.
|
||||
|
||||
First navigate to `danswer/deployment/docker_compose`, then start up Postgres/Vespa/Redis with:
|
||||
First navigate to `danswer/deployment/docker_compose`, then start up Vespa and Postgres with:
|
||||
```bash
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack up -d index relational_db cache
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack up -d index relational_db
|
||||
```
|
||||
(index refers to Vespa, relational_db refers to Postgres, and cache refers to Redis)
|
||||
|
||||
(index refers to Vespa and relational_db refers to Postgres)
|
||||
|
||||
#### Running Danswer
|
||||
To start the frontend, navigate to `danswer/web` and run:
|
||||
@@ -138,10 +123,11 @@ Navigate to `danswer/backend` and run:
|
||||
```bash
|
||||
uvicorn model_server.main:app --reload --port 9000
|
||||
```
|
||||
|
||||
_For Windows (for compatibility with both PowerShell and Command Prompt):_
|
||||
```bash
|
||||
powershell -Command "uvicorn model_server.main:app --reload --port 9000"
|
||||
powershell -Command "
|
||||
uvicorn model_server.main:app --reload --port 9000
|
||||
"
|
||||
```
|
||||
|
||||
The first time running Danswer, you will need to run the DB migrations for Postgres.
|
||||
@@ -164,7 +150,6 @@ To run the backend API server, navigate back to `danswer/backend` and run:
|
||||
```bash
|
||||
AUTH_TYPE=disabled uvicorn danswer.main:app --reload --port 8080
|
||||
```
|
||||
|
||||
_For Windows (for compatibility with both PowerShell and Command Prompt):_
|
||||
```bash
|
||||
powershell -Command "
|
||||
@@ -173,28 +158,20 @@ powershell -Command "
|
||||
"
|
||||
```
|
||||
|
||||
> **Note:**
|
||||
> If you need finer logging, add the additional environment variable `LOG_LEVEL=DEBUG` to the relevant services.
|
||||
|
||||
Note: if you need finer logging, add the additional environment variable `LOG_LEVEL=DEBUG` to the relevant services.
|
||||
|
||||
### Formatting and Linting
|
||||
#### Backend
|
||||
For the backend, you'll need to setup pre-commit hooks (black / reorder-python-imports).
|
||||
First, install pre-commit (if you don't have it already) following the instructions
|
||||
[here](https://pre-commit.com/#installation).
|
||||
|
||||
With the virtual environment active, install the pre-commit library with:
|
||||
```bash
|
||||
pip install pre-commit
|
||||
```
|
||||
|
||||
Then, from the `danswer/backend` directory, run:
|
||||
```bash
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
Additionally, we use `mypy` for static type checking.
|
||||
Danswer is fully type-annotated, and we want to keep it that way!
|
||||
Danswer is fully type-annotated, and we would like to keep it that way!
|
||||
To run the mypy checks manually, run `python -m mypy .` from the `danswer/backend` directory.
|
||||
|
||||
|
||||
@@ -205,7 +182,6 @@ Please double check that prettier passes before creating a pull request.
|
||||
|
||||
|
||||
### Release Process
|
||||
Danswer loosely follows the SemVer versioning standard.
|
||||
Major changes are released with a "minor" version bump. Currently we use patch release versions to indicate small feature changes.
|
||||
Danswer follows the semver versioning standard.
|
||||
A set of Docker containers will be pushed automatically to DockerHub with every tag.
|
||||
You can see the containers [here](https://hub.docker.com/search?q=danswer%2F).
|
||||
|
||||
@@ -1,31 +0,0 @@
|
||||
## Some additional notes for Mac Users
|
||||
The base instructions to set up the development environment are located in [CONTRIBUTING.md](https://github.com/danswer-ai/danswer/blob/main/CONTRIBUTING.md).
|
||||
|
||||
### Setting up Python
|
||||
Ensure [Homebrew](https://brew.sh/) is already set up.
|
||||
|
||||
Then install python 3.11.
|
||||
```bash
|
||||
brew install python@3.11
|
||||
```
|
||||
|
||||
Add python 3.11 to your path: add the following line to ~/.zshrc
|
||||
```
|
||||
export PATH="$(brew --prefix)/opt/python@3.11/libexec/bin:$PATH"
|
||||
```
|
||||
|
||||
> **Note:**
|
||||
> You will need to open a new terminal for the path change above to take effect.
|
||||
|
||||
|
||||
### Setting up Docker
|
||||
On macOS, you will need to install [Docker Desktop](https://www.docker.com/products/docker-desktop/) and
|
||||
ensure it is running before continuing with the docker commands.
|
||||
|
||||
|
||||
### Formatting and Linting
|
||||
MacOS will likely require you to remove some quarantine attributes on some of the hooks for them to execute properly.
|
||||
After installing pre-commit, run the following command:
|
||||
```bash
|
||||
sudo xattr -r -d com.apple.quarantine ~/.cache/pre-commit
|
||||
```
|
||||
8
LICENSE
8
LICENSE
@@ -1,10 +1,6 @@
|
||||
Copyright (c) 2023-present DanswerAI, Inc.
|
||||
MIT License
|
||||
|
||||
Portions of this software are licensed as follows:
|
||||
|
||||
* All content that resides under "ee" directories of this repository, if that directory exists, is licensed under the license defined in "backend/ee/LICENSE". Specifically all content under "backend/ee" and "web/src/app/ee" is licensed under the license defined in "backend/ee/LICENSE".
|
||||
* All third party components incorporated into the Danswer Software are licensed under the original license provided by the owner of the applicable component.
|
||||
* Content outside of the above mentioned directories or restrictions above is available under the "MIT Expat" license as defined below.
|
||||
Copyright (c) 2023 Yuhong Sun, Chris Weaver
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
|
||||
22
README.md
22
README.md
@@ -11,7 +11,7 @@
|
||||
<a href="https://docs.danswer.dev/" target="_blank">
|
||||
<img src="https://img.shields.io/badge/docs-view-blue" alt="Documentation">
|
||||
</a>
|
||||
<a href="https://join.slack.com/t/danswer/shared_invite/zt-2lcmqw703-071hBuZBfNEOGUsLa5PXvQ" target="_blank">
|
||||
<a href="https://join.slack.com/t/danswer/shared_invite/zt-2afut44lv-Rw3kSWu6_OmdAXRpCv80DQ" target="_blank">
|
||||
<img src="https://img.shields.io/badge/slack-join-blue.svg?logo=slack" alt="Slack">
|
||||
</a>
|
||||
<a href="https://discord.gg/TDJ59cGV2X" target="_blank">
|
||||
@@ -105,25 +105,5 @@ Efficiently pulls the latest changes from:
|
||||
* Websites
|
||||
* And more ...
|
||||
|
||||
## 📚 Editions
|
||||
|
||||
There are two editions of Danswer:
|
||||
|
||||
* Danswer Community Edition (CE) is available freely under the MIT Expat license. This version has ALL the core features discussed above. This is the version of Danswer you will get if you follow the Deployment guide above.
|
||||
* Danswer Enterprise Edition (EE) includes extra features that are primarily useful for larger organizations. Specifically, this includes:
|
||||
* Single Sign-On (SSO), with support for both SAML and OIDC
|
||||
* Role-based access control
|
||||
* Document permission inheritance from connected sources
|
||||
* Usage analytics and query history accessible to admins
|
||||
* Whitelabeling
|
||||
* API key authentication
|
||||
* Encryption of secrets
|
||||
* Any many more! Checkout [our website](https://www.danswer.ai/) for the latest.
|
||||
|
||||
To try the Danswer Enterprise Edition:
|
||||
|
||||
1. Checkout our [Cloud product](https://app.danswer.ai/signup).
|
||||
2. For self-hosting, contact us at [founders@danswer.ai](mailto:founders@danswer.ai) or book a call with us on our [Cal](https://cal.com/team/danswer/founders).
|
||||
|
||||
## 💡 Contributing
|
||||
Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md) for more details.
|
||||
|
||||
2
backend/.gitignore
vendored
2
backend/.gitignore
vendored
@@ -5,7 +5,7 @@ site_crawls/
|
||||
.ipynb_checkpoints/
|
||||
api_keys.py
|
||||
*ipynb
|
||||
.env*
|
||||
.env
|
||||
vespa-app.zip
|
||||
dynamic_config_storage/
|
||||
celerybeat-schedule*
|
||||
|
||||
@@ -1,17 +1,15 @@
|
||||
FROM python:3.11.7-slim-bookworm
|
||||
|
||||
LABEL com.danswer.maintainer="founders@danswer.ai"
|
||||
LABEL com.danswer.description="This image is the web/frontend container of Danswer which \
|
||||
contains code for both the Community and Enterprise editions of Danswer. If you do not \
|
||||
have a contract or agreement with DanswerAI, you are not permitted to use the Enterprise \
|
||||
Edition features outside of personal development or testing purposes. Please reach out to \
|
||||
founders@danswer.ai for more information. Please visit https://github.com/danswer-ai/danswer"
|
||||
LABEL com.danswer.description="This image is for the backend of Danswer. It is MIT Licensed and \
|
||||
free for all to use. You can find it at https://hub.docker.com/r/danswer/danswer-backend. For \
|
||||
more details, visit https://github.com/danswer-ai/danswer."
|
||||
|
||||
# Default DANSWER_VERSION, typically overriden during builds by GitHub Actions.
|
||||
ARG DANSWER_VERSION=0.3-dev
|
||||
ENV DANSWER_VERSION=${DANSWER_VERSION}
|
||||
|
||||
RUN echo "DANSWER_VERSION: ${DANSWER_VERSION}"
|
||||
|
||||
# Install system dependencies
|
||||
# cmake needed for psycopg (postgres)
|
||||
# libpq-dev needed for psycopg (postgres)
|
||||
@@ -19,32 +17,18 @@ RUN echo "DANSWER_VERSION: ${DANSWER_VERSION}"
|
||||
# zip for Vespa step futher down
|
||||
# ca-certificates for HTTPS
|
||||
RUN apt-get update && \
|
||||
apt-get install -y \
|
||||
cmake \
|
||||
curl \
|
||||
zip \
|
||||
ca-certificates \
|
||||
libgnutls30=3.7.9-2+deb12u3 \
|
||||
libblkid1=2.38.1-5+deb12u1 \
|
||||
libmount1=2.38.1-5+deb12u1 \
|
||||
libsmartcols1=2.38.1-5+deb12u1 \
|
||||
libuuid1=2.38.1-5+deb12u1 \
|
||||
libxmlsec1-dev \
|
||||
pkg-config \
|
||||
gcc && \
|
||||
apt-get install -y cmake curl zip ca-certificates libgnutls30=3.7.9-2+deb12u2 \
|
||||
libblkid1=2.38.1-5+deb12u1 libmount1=2.38.1-5+deb12u1 libsmartcols1=2.38.1-5+deb12u1 \
|
||||
libuuid1=2.38.1-5+deb12u1 && \
|
||||
rm -rf /var/lib/apt/lists/* && \
|
||||
apt-get clean
|
||||
|
||||
# Install Python dependencies
|
||||
# Remove py which is pulled in by retry, py is not needed and is a CVE
|
||||
COPY ./requirements/default.txt /tmp/requirements.txt
|
||||
COPY ./requirements/ee.txt /tmp/ee-requirements.txt
|
||||
RUN pip install --no-cache-dir --upgrade \
|
||||
-r /tmp/requirements.txt \
|
||||
-r /tmp/ee-requirements.txt && \
|
||||
RUN pip install --no-cache-dir --upgrade -r /tmp/requirements.txt && \
|
||||
pip uninstall -y py && \
|
||||
playwright install chromium && \
|
||||
playwright install-deps chromium && \
|
||||
playwright install chromium && playwright install-deps chromium && \
|
||||
ln -s /usr/local/bin/supervisord /usr/bin/supervisord
|
||||
|
||||
# Cleanup for CVEs and size reduction
|
||||
@@ -52,52 +36,29 @@ RUN pip install --no-cache-dir --upgrade \
|
||||
# xserver-common and xvfb included by playwright installation but not needed after
|
||||
# perl-base is part of the base Python Debian image but not needed for Danswer functionality
|
||||
# perl-base could only be removed with --allow-remove-essential
|
||||
RUN apt-get update && \
|
||||
apt-get remove -y --allow-remove-essential \
|
||||
perl-base \
|
||||
xserver-common \
|
||||
xvfb \
|
||||
cmake \
|
||||
libldap-2.5-0 \
|
||||
libxmlsec1-dev \
|
||||
pkg-config \
|
||||
gcc && \
|
||||
apt-get install -y libxmlsec1-openssl && \
|
||||
RUN apt-get remove -y --allow-remove-essential perl-base xserver-common xvfb cmake \
|
||||
libldap-2.5-0 libldap-2.5-0 && \
|
||||
apt-get autoremove -y && \
|
||||
rm -rf /var/lib/apt/lists/* && \
|
||||
rm -f /usr/local/lib/python3.11/site-packages/tornado/test/test.key
|
||||
rm /usr/local/lib/python3.11/site-packages/tornado/test/test.key
|
||||
|
||||
# Pre-downloading models for setups with limited egress
|
||||
RUN python -c "from tokenizers import Tokenizer; \
|
||||
Tokenizer.from_pretrained('nomic-ai/nomic-embed-text-v1')"
|
||||
|
||||
RUN python -c "from transformers import AutoTokenizer; AutoTokenizer.from_pretrained('intfloat/e5-base-v2')"
|
||||
|
||||
# Pre-downloading NLTK for setups with limited egress
|
||||
RUN python -c "import nltk; \
|
||||
nltk.download('stopwords', quiet=True); \
|
||||
nltk.download('wordnet', quiet=True); \
|
||||
nltk.download('punkt', quiet=True);"
|
||||
# nltk.download('wordnet', quiet=True); introduce this back if lemmatization is needed
|
||||
|
||||
# Set up application files
|
||||
WORKDIR /app
|
||||
|
||||
# Enterprise Version Files
|
||||
COPY ./ee /app/ee
|
||||
COPY supervisord.conf /etc/supervisor/conf.d/supervisord.conf
|
||||
|
||||
# Set up application files
|
||||
COPY ./danswer /app/danswer
|
||||
COPY ./shared_configs /app/shared_configs
|
||||
COPY ./alembic /app/alembic
|
||||
COPY ./alembic.ini /app/alembic.ini
|
||||
COPY supervisord.conf /usr/etc/supervisord.conf
|
||||
|
||||
# Escape hatch
|
||||
COPY ./scripts/force_delete_connector_by_id.py /app/scripts/force_delete_connector_by_id.py
|
||||
|
||||
# Put logo in assets
|
||||
COPY ./assets /app/assets
|
||||
|
||||
ENV PYTHONPATH /app
|
||||
|
||||
# Default command which does nothing
|
||||
|
||||
@@ -18,22 +18,14 @@ RUN apt-get remove -y --allow-remove-essential perl-base && \
|
||||
apt-get autoremove -y
|
||||
|
||||
# Pre-downloading models for setups with limited egress
|
||||
# Download tokenizers, distilbert for the Danswer model
|
||||
# Download model weights
|
||||
# Run Nomic to pull in the custom architecture and have it cached locally
|
||||
RUN python -c "from transformers import AutoTokenizer; \
|
||||
AutoTokenizer.from_pretrained('distilbert-base-uncased'); \
|
||||
AutoTokenizer.from_pretrained('mixedbread-ai/mxbai-rerank-xsmall-v1'); \
|
||||
RUN python -c "from transformers import AutoModel, AutoTokenizer, TFDistilBertForSequenceClassification; \
|
||||
from huggingface_hub import snapshot_download; \
|
||||
snapshot_download(repo_id='danswer/hybrid-intent-token-classifier', revision='v1.0.3'); \
|
||||
snapshot_download('nomic-ai/nomic-embed-text-v1'); \
|
||||
snapshot_download('mixedbread-ai/mxbai-rerank-xsmall-v1'); \
|
||||
from sentence_transformers import SentenceTransformer; \
|
||||
SentenceTransformer(model_name_or_path='nomic-ai/nomic-embed-text-v1', trust_remote_code=True);"
|
||||
|
||||
# In case the user has volumes mounted to /root/.cache/huggingface that they've downloaded while
|
||||
# running Danswer, don't overwrite it with the built in cache folder
|
||||
RUN mv /root/.cache/huggingface /root/.cache/temp_huggingface
|
||||
AutoTokenizer.from_pretrained('danswer/intent-model'); \
|
||||
AutoTokenizer.from_pretrained('intfloat/e5-base-v2'); \
|
||||
AutoTokenizer.from_pretrained('mixedbread-ai/mxbai-rerank-xsmall-v1'); \
|
||||
snapshot_download('danswer/intent-model'); \
|
||||
snapshot_download('intfloat/e5-base-v2'); \
|
||||
snapshot_download('mixedbread-ai/mxbai-rerank-xsmall-v1')"
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
|
||||
@@ -8,7 +8,6 @@ from sqlalchemy import pool
|
||||
from sqlalchemy.engine import Connection
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from celery.backends.database.session import ResultModelBase # type: ignore
|
||||
from sqlalchemy.schema import SchemaItem
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
@@ -16,9 +15,7 @@ config = context.config
|
||||
|
||||
# Interpret the config file for Python logging.
|
||||
# This line sets up loggers basically.
|
||||
if config.config_file_name is not None and config.attributes.get(
|
||||
"configure_logger", True
|
||||
):
|
||||
if config.config_file_name is not None:
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
# add your model's MetaData object here
|
||||
@@ -32,20 +29,6 @@ target_metadata = [Base.metadata, ResultModelBase.metadata]
|
||||
# my_important_option = config.get_main_option("my_important_option")
|
||||
# ... etc.
|
||||
|
||||
EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
|
||||
|
||||
|
||||
def include_object(
|
||||
object: SchemaItem,
|
||||
name: str,
|
||||
type_: str,
|
||||
reflected: bool,
|
||||
compare_to: SchemaItem | None,
|
||||
) -> bool:
|
||||
if type_ == "table" and name in EXCLUDE_TABLES:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
"""Run migrations in 'offline' mode.
|
||||
@@ -72,11 +55,7 @@ def run_migrations_offline() -> None:
|
||||
|
||||
|
||||
def do_run_migrations(connection: Connection) -> None:
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
include_object=include_object,
|
||||
) # type: ignore
|
||||
context.configure(connection=connection, target_metadata=target_metadata) # type: ignore
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
"""Add thread specific model selection
|
||||
|
||||
Revision ID: 0568ccf46a6b
|
||||
Revises: e209dc5a8156
|
||||
Create Date: 2024-06-19 14:25:36.376046
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "0568ccf46a6b"
|
||||
down_revision = "e209dc5a8156"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"chat_session",
|
||||
sa.Column("current_alternate_model", sa.String(), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("chat_session", "current_alternate_model")
|
||||
@@ -1,32 +0,0 @@
|
||||
"""add search doc relevance details
|
||||
|
||||
Revision ID: 05c07bf07c00
|
||||
Revises: b896bbd0d5a7
|
||||
Create Date: 2024-07-10 17:48:15.886653
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "05c07bf07c00"
|
||||
down_revision = "b896bbd0d5a7"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"search_doc",
|
||||
sa.Column("is_relevant", sa.Boolean(), nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
"search_doc",
|
||||
sa.Column("relevance_explanation", sa.String(), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("search_doc", "relevance_explanation")
|
||||
op.drop_column("search_doc", "is_relevant")
|
||||
@@ -1,26 +0,0 @@
|
||||
"""add_indexing_start_to_connector
|
||||
|
||||
Revision ID: 08a1eda20fe1
|
||||
Revises: 8a87bd6ec550
|
||||
Create Date: 2024-07-23 11:12:39.462397
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "08a1eda20fe1"
|
||||
down_revision = "8a87bd6ec550"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"connector", sa.Column("indexing_start", sa.DateTime(), nullable=True)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("connector", "indexing_start")
|
||||
@@ -1,135 +0,0 @@
|
||||
"""embedding model -> search settings
|
||||
|
||||
Revision ID: 1f60f60c3401
|
||||
Revises: f17bf3b0d9f1
|
||||
Create Date: 2024-08-25 12:39:51.731632
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
from danswer.configs.chat_configs import NUM_POSTPROCESSED_RESULTS
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "1f60f60c3401"
|
||||
down_revision = "f17bf3b0d9f1"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.drop_constraint(
|
||||
"index_attempt__embedding_model_fk", "index_attempt", type_="foreignkey"
|
||||
)
|
||||
# Rename the table
|
||||
op.rename_table("embedding_model", "search_settings")
|
||||
|
||||
# Add new columns
|
||||
op.add_column(
|
||||
"search_settings",
|
||||
sa.Column(
|
||||
"multipass_indexing", sa.Boolean(), nullable=False, server_default="true"
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"search_settings",
|
||||
sa.Column(
|
||||
"multilingual_expansion",
|
||||
postgresql.ARRAY(sa.String()),
|
||||
nullable=False,
|
||||
server_default="{}",
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"search_settings",
|
||||
sa.Column(
|
||||
"disable_rerank_for_streaming",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default="false",
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"search_settings", sa.Column("rerank_model_name", sa.String(), nullable=True)
|
||||
)
|
||||
op.add_column(
|
||||
"search_settings", sa.Column("rerank_provider_type", sa.String(), nullable=True)
|
||||
)
|
||||
op.add_column(
|
||||
"search_settings", sa.Column("rerank_api_key", sa.String(), nullable=True)
|
||||
)
|
||||
op.add_column(
|
||||
"search_settings",
|
||||
sa.Column(
|
||||
"num_rerank",
|
||||
sa.Integer(),
|
||||
nullable=False,
|
||||
server_default=str(NUM_POSTPROCESSED_RESULTS),
|
||||
),
|
||||
)
|
||||
|
||||
# Add the new column as nullable initially
|
||||
op.add_column(
|
||||
"index_attempt", sa.Column("search_settings_id", sa.Integer(), nullable=True)
|
||||
)
|
||||
|
||||
# Populate the new column with data from the existing embedding_model_id
|
||||
op.execute("UPDATE index_attempt SET search_settings_id = embedding_model_id")
|
||||
|
||||
# Create the foreign key constraint
|
||||
op.create_foreign_key(
|
||||
"fk_index_attempt_search_settings",
|
||||
"index_attempt",
|
||||
"search_settings",
|
||||
["search_settings_id"],
|
||||
["id"],
|
||||
)
|
||||
|
||||
# Make the new column non-nullable
|
||||
op.alter_column("index_attempt", "search_settings_id", nullable=False)
|
||||
|
||||
# Drop the old embedding_model_id column
|
||||
op.drop_column("index_attempt", "embedding_model_id")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Add back the embedding_model_id column
|
||||
op.add_column(
|
||||
"index_attempt", sa.Column("embedding_model_id", sa.Integer(), nullable=True)
|
||||
)
|
||||
|
||||
# Populate the old column with data from search_settings_id
|
||||
op.execute("UPDATE index_attempt SET embedding_model_id = search_settings_id")
|
||||
|
||||
# Make the old column non-nullable
|
||||
op.alter_column("index_attempt", "embedding_model_id", nullable=False)
|
||||
|
||||
# Drop the foreign key constraint
|
||||
op.drop_constraint(
|
||||
"fk_index_attempt_search_settings", "index_attempt", type_="foreignkey"
|
||||
)
|
||||
|
||||
# Drop the new search_settings_id column
|
||||
op.drop_column("index_attempt", "search_settings_id")
|
||||
|
||||
# Rename the table back
|
||||
op.rename_table("search_settings", "embedding_model")
|
||||
|
||||
# Remove added columns
|
||||
op.drop_column("embedding_model", "num_rerank")
|
||||
op.drop_column("embedding_model", "rerank_api_key")
|
||||
op.drop_column("embedding_model", "rerank_provider_type")
|
||||
op.drop_column("embedding_model", "rerank_model_name")
|
||||
op.drop_column("embedding_model", "disable_rerank_for_streaming")
|
||||
op.drop_column("embedding_model", "multilingual_expansion")
|
||||
op.drop_column("embedding_model", "multipass_indexing")
|
||||
|
||||
op.create_foreign_key(
|
||||
"index_attempt__embedding_model_fk",
|
||||
"index_attempt",
|
||||
"embedding_model",
|
||||
["embedding_model_id"],
|
||||
["id"],
|
||||
)
|
||||
@@ -1,44 +0,0 @@
|
||||
"""notifications
|
||||
|
||||
Revision ID: 213fd978c6d8
|
||||
Revises: 5fc1f54cc252
|
||||
Create Date: 2024-08-10 11:13:36.070790
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "213fd978c6d8"
|
||||
down_revision = "5fc1f54cc252"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"notification",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column(
|
||||
"notif_type",
|
||||
sa.String(),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"user_id",
|
||||
sa.UUID(),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("dismissed", sa.Boolean(), nullable=False),
|
||||
sa.Column("last_shown", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("first_shown", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["user.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("notification")
|
||||
@@ -1,86 +0,0 @@
|
||||
"""remove-feedback-foreignkey-constraint
|
||||
|
||||
Revision ID: 23957775e5f5
|
||||
Revises: bc9771dccadf
|
||||
Create Date: 2024-06-27 16:04:51.480437
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "23957775e5f5"
|
||||
down_revision = "bc9771dccadf"
|
||||
branch_labels = None # type: ignore
|
||||
depends_on = None # type: ignore
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.drop_constraint(
|
||||
"chat_feedback__chat_message_fk", "chat_feedback", type_="foreignkey"
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"chat_feedback__chat_message_fk",
|
||||
"chat_feedback",
|
||||
"chat_message",
|
||||
["chat_message_id"],
|
||||
["id"],
|
||||
ondelete="SET NULL",
|
||||
)
|
||||
op.alter_column(
|
||||
"chat_feedback", "chat_message_id", existing_type=sa.Integer(), nullable=True
|
||||
)
|
||||
op.drop_constraint(
|
||||
"document_retrieval_feedback__chat_message_fk",
|
||||
"document_retrieval_feedback",
|
||||
type_="foreignkey",
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"document_retrieval_feedback__chat_message_fk",
|
||||
"document_retrieval_feedback",
|
||||
"chat_message",
|
||||
["chat_message_id"],
|
||||
["id"],
|
||||
ondelete="SET NULL",
|
||||
)
|
||||
op.alter_column(
|
||||
"document_retrieval_feedback",
|
||||
"chat_message_id",
|
||||
existing_type=sa.Integer(),
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.alter_column(
|
||||
"chat_feedback", "chat_message_id", existing_type=sa.Integer(), nullable=False
|
||||
)
|
||||
op.drop_constraint(
|
||||
"chat_feedback__chat_message_fk", "chat_feedback", type_="foreignkey"
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"chat_feedback__chat_message_fk",
|
||||
"chat_feedback",
|
||||
"chat_message",
|
||||
["chat_message_id"],
|
||||
["id"],
|
||||
)
|
||||
|
||||
op.alter_column(
|
||||
"document_retrieval_feedback",
|
||||
"chat_message_id",
|
||||
existing_type=sa.Integer(),
|
||||
nullable=False,
|
||||
)
|
||||
op.drop_constraint(
|
||||
"document_retrieval_feedback__chat_message_fk",
|
||||
"document_retrieval_feedback",
|
||||
type_="foreignkey",
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"document_retrieval_feedback__chat_message_fk",
|
||||
"document_retrieval_feedback",
|
||||
"chat_message",
|
||||
["chat_message_id"],
|
||||
["id"],
|
||||
)
|
||||
@@ -160,28 +160,12 @@ def downgrade() -> None:
|
||||
nullable=False,
|
||||
),
|
||||
)
|
||||
|
||||
# Check if the constraint exists before dropping
|
||||
conn = op.get_bind()
|
||||
inspector = sa.inspect(conn)
|
||||
constraints = inspector.get_foreign_keys("index_attempt")
|
||||
|
||||
if any(
|
||||
constraint["name"] == "fk_index_attempt_credential_id"
|
||||
for constraint in constraints
|
||||
):
|
||||
op.drop_constraint(
|
||||
"fk_index_attempt_credential_id", "index_attempt", type_="foreignkey"
|
||||
)
|
||||
|
||||
if any(
|
||||
constraint["name"] == "fk_index_attempt_connector_id"
|
||||
for constraint in constraints
|
||||
):
|
||||
op.drop_constraint(
|
||||
"fk_index_attempt_connector_id", "index_attempt", type_="foreignkey"
|
||||
)
|
||||
|
||||
op.drop_constraint(
|
||||
"fk_index_attempt_credential_id", "index_attempt", type_="foreignkey"
|
||||
)
|
||||
op.drop_constraint(
|
||||
"fk_index_attempt_connector_id", "index_attempt", type_="foreignkey"
|
||||
)
|
||||
op.drop_column("index_attempt", "credential_id")
|
||||
op.drop_column("index_attempt", "connector_id")
|
||||
op.drop_table("connector_credential_pair")
|
||||
|
||||
@@ -1,32 +0,0 @@
|
||||
"""Add Above Below to Persona
|
||||
|
||||
Revision ID: 2d2304e27d8c
|
||||
Revises: 4b08d97e175a
|
||||
Create Date: 2024-08-21 19:15:15.762948
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "2d2304e27d8c"
|
||||
down_revision = "4b08d97e175a"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column("persona", sa.Column("chunks_above", sa.Integer(), nullable=True))
|
||||
op.add_column("persona", sa.Column("chunks_below", sa.Integer(), nullable=True))
|
||||
|
||||
op.execute(
|
||||
"UPDATE persona SET chunks_above = 1, chunks_below = 1 WHERE chunks_above IS NULL AND chunks_below IS NULL"
|
||||
)
|
||||
|
||||
op.alter_column("persona", "chunks_above", nullable=False)
|
||||
op.alter_column("persona", "chunks_below", nullable=False)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("persona", "chunks_below")
|
||||
op.drop_column("persona", "chunks_above")
|
||||
@@ -1,70 +0,0 @@
|
||||
"""Add icon_color and icon_shape to Persona
|
||||
|
||||
Revision ID: 325975216eb3
|
||||
Revises: 91ffac7e65b3
|
||||
Create Date: 2024-07-24 21:29:31.784562
|
||||
|
||||
"""
|
||||
import random
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.sql import table, column, select
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "325975216eb3"
|
||||
down_revision = "91ffac7e65b3"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
colorOptions = [
|
||||
"#FF6FBF",
|
||||
"#6FB1FF",
|
||||
"#B76FFF",
|
||||
"#FFB56F",
|
||||
"#6FFF8D",
|
||||
"#FF6F6F",
|
||||
"#6FFFFF",
|
||||
]
|
||||
|
||||
|
||||
# Function to generate a random shape ensuring at least 3 of the middle 4 squares are filled
|
||||
def generate_random_shape() -> int:
|
||||
center_squares = [12, 10, 6, 14, 13, 11, 7, 15]
|
||||
center_fill = random.choice(center_squares)
|
||||
remaining_squares = [i for i in range(16) if not (center_fill & (1 << i))]
|
||||
random.shuffle(remaining_squares)
|
||||
for i in range(10 - bin(center_fill).count("1")):
|
||||
center_fill |= 1 << remaining_squares[i]
|
||||
return center_fill
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column("persona", sa.Column("icon_color", sa.String(), nullable=True))
|
||||
op.add_column("persona", sa.Column("icon_shape", sa.Integer(), nullable=True))
|
||||
op.add_column("persona", sa.Column("uploaded_image_id", sa.String(), nullable=True))
|
||||
|
||||
persona = table(
|
||||
"persona",
|
||||
column("id", sa.Integer),
|
||||
column("icon_color", sa.String),
|
||||
column("icon_shape", sa.Integer),
|
||||
)
|
||||
|
||||
conn = op.get_bind()
|
||||
personas = conn.execute(select(persona.c.id))
|
||||
|
||||
for persona_id in personas:
|
||||
random_color = random.choice(colorOptions)
|
||||
random_shape = generate_random_shape()
|
||||
conn.execute(
|
||||
persona.update()
|
||||
.where(persona.c.id == persona_id[0])
|
||||
.values(icon_color=random_color, icon_shape=random_shape)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("persona", "icon_shape")
|
||||
op.drop_column("persona", "uploaded_image_id")
|
||||
op.drop_column("persona", "icon_color")
|
||||
@@ -1,90 +0,0 @@
|
||||
"""Add curator fields
|
||||
|
||||
Revision ID: 351faebd379d
|
||||
Revises: ee3f4b47fad5
|
||||
Create Date: 2024-08-15 22:37:08.397052
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "351faebd379d"
|
||||
down_revision = "ee3f4b47fad5"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add is_curator column to User__UserGroup table
|
||||
op.add_column(
|
||||
"user__user_group",
|
||||
sa.Column("is_curator", sa.Boolean(), nullable=False, server_default="false"),
|
||||
)
|
||||
|
||||
# Use batch mode to modify the enum type
|
||||
with op.batch_alter_table("user", schema=None) as batch_op:
|
||||
batch_op.alter_column( # type: ignore[attr-defined]
|
||||
"role",
|
||||
type_=sa.Enum(
|
||||
"BASIC",
|
||||
"ADMIN",
|
||||
"CURATOR",
|
||||
"GLOBAL_CURATOR",
|
||||
name="userrole",
|
||||
native_enum=False,
|
||||
),
|
||||
existing_type=sa.Enum("BASIC", "ADMIN", name="userrole", native_enum=False),
|
||||
existing_nullable=False,
|
||||
)
|
||||
# Create the association table
|
||||
op.create_table(
|
||||
"credential__user_group",
|
||||
sa.Column("credential_id", sa.Integer(), nullable=False),
|
||||
sa.Column("user_group_id", sa.Integer(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["credential_id"],
|
||||
["credential.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_group_id"],
|
||||
["user_group.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("credential_id", "user_group_id"),
|
||||
)
|
||||
op.add_column(
|
||||
"credential",
|
||||
sa.Column(
|
||||
"curator_public", sa.Boolean(), nullable=False, server_default="false"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Update existing records to ensure they fit within the BASIC/ADMIN roles
|
||||
op.execute(
|
||||
"UPDATE \"user\" SET role = 'ADMIN' WHERE role IN ('CURATOR', 'GLOBAL_CURATOR')"
|
||||
)
|
||||
|
||||
# Remove is_curator column from User__UserGroup table
|
||||
op.drop_column("user__user_group", "is_curator")
|
||||
|
||||
with op.batch_alter_table("user", schema=None) as batch_op:
|
||||
batch_op.alter_column( # type: ignore[attr-defined]
|
||||
"role",
|
||||
type_=sa.Enum(
|
||||
"BASIC", "ADMIN", name="userrole", native_enum=False, length=20
|
||||
),
|
||||
existing_type=sa.Enum(
|
||||
"BASIC",
|
||||
"ADMIN",
|
||||
"CURATOR",
|
||||
"GLOBAL_CURATOR",
|
||||
name="userrole",
|
||||
native_enum=False,
|
||||
),
|
||||
existing_nullable=False,
|
||||
)
|
||||
# Drop the association table
|
||||
op.drop_table("credential__user_group")
|
||||
op.drop_column("credential", "curator_public")
|
||||
@@ -11,8 +11,8 @@ import sqlalchemy as sa
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "3879338f8ba1"
|
||||
down_revision = "f1c6478c3fd8"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
||||
@@ -1,35 +0,0 @@
|
||||
"""add alternate assistant to chat message
|
||||
|
||||
Revision ID: 3a7802814195
|
||||
Revises: 23957775e5f5
|
||||
Create Date: 2024-06-05 11:18:49.966333
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "3a7802814195"
|
||||
down_revision = "23957775e5f5"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"chat_message", sa.Column("alternate_assistant_id", sa.Integer(), nullable=True)
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"fk_chat_message_persona",
|
||||
"chat_message",
|
||||
"persona",
|
||||
["alternate_assistant_id"],
|
||||
["id"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_constraint("fk_chat_message_persona", "chat_message", type_="foreignkey")
|
||||
op.drop_column("chat_message", "alternate_assistant_id")
|
||||
@@ -1,42 +0,0 @@
|
||||
"""Rename index_origin to index_recursively
|
||||
|
||||
Revision ID: 1d6ad76d1f37
|
||||
Revises: e1392f05e840
|
||||
Create Date: 2024-08-01 12:38:54.466081
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "1d6ad76d1f37"
|
||||
down_revision = "e1392f05e840"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE connector
|
||||
SET connector_specific_config = jsonb_set(
|
||||
connector_specific_config,
|
||||
'{index_recursively}',
|
||||
'true'::jsonb
|
||||
) - 'index_origin'
|
||||
WHERE connector_specific_config ? 'index_origin'
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE connector
|
||||
SET connector_specific_config = jsonb_set(
|
||||
connector_specific_config,
|
||||
'{index_origin}',
|
||||
connector_specific_config->'index_recursively'
|
||||
) - 'index_recursively'
|
||||
WHERE connector_specific_config ? 'index_recursively'
|
||||
"""
|
||||
)
|
||||
@@ -1,65 +0,0 @@
|
||||
"""add cloud embedding model and update embedding_model
|
||||
|
||||
Revision ID: 44f856ae2a4a
|
||||
Revises: d716b0791ddd
|
||||
Create Date: 2024-06-28 20:01:05.927647
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "44f856ae2a4a"
|
||||
down_revision = "d716b0791ddd"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create embedding_provider table
|
||||
op.create_table(
|
||||
"embedding_provider",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("name", sa.String(), nullable=False),
|
||||
sa.Column("api_key", sa.LargeBinary(), nullable=True),
|
||||
sa.Column("default_model_id", sa.Integer(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("name"),
|
||||
)
|
||||
|
||||
# Add cloud_provider_id to embedding_model table
|
||||
op.add_column(
|
||||
"embedding_model", sa.Column("cloud_provider_id", sa.Integer(), nullable=True)
|
||||
)
|
||||
|
||||
# Add foreign key constraints
|
||||
op.create_foreign_key(
|
||||
"fk_embedding_model_cloud_provider",
|
||||
"embedding_model",
|
||||
"embedding_provider",
|
||||
["cloud_provider_id"],
|
||||
["id"],
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"fk_embedding_provider_default_model",
|
||||
"embedding_provider",
|
||||
"embedding_model",
|
||||
["default_model_id"],
|
||||
["id"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove foreign key constraints
|
||||
op.drop_constraint(
|
||||
"fk_embedding_model_cloud_provider", "embedding_model", type_="foreignkey"
|
||||
)
|
||||
op.drop_constraint(
|
||||
"fk_embedding_provider_default_model", "embedding_provider", type_="foreignkey"
|
||||
)
|
||||
|
||||
# Remove cloud_provider_id column
|
||||
op.drop_column("embedding_model", "cloud_provider_id")
|
||||
|
||||
# Drop embedding_provider table
|
||||
op.drop_table("embedding_provider")
|
||||
@@ -1,23 +0,0 @@
|
||||
"""added is_internet to DBDoc
|
||||
|
||||
Revision ID: 4505fd7302e1
|
||||
Revises: c18cdf4b497e
|
||||
Create Date: 2024-06-18 20:46:09.095034
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "4505fd7302e1"
|
||||
down_revision = "c18cdf4b497e"
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column("search_doc", sa.Column("is_internet", sa.Boolean(), nullable=True))
|
||||
op.add_column("tool", sa.Column("display_name", sa.String(), nullable=True))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("tool", "display_name")
|
||||
op.drop_column("search_doc", "is_internet")
|
||||
@@ -1,49 +0,0 @@
|
||||
"""Add display_model_names to llm_provider
|
||||
|
||||
Revision ID: 473a1a7ca408
|
||||
Revises: 325975216eb3
|
||||
Create Date: 2024-07-25 14:31:02.002917
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "473a1a7ca408"
|
||||
down_revision = "325975216eb3"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
default_models_by_provider = {
|
||||
"openai": ["gpt-4", "gpt-4o", "gpt-4o-mini"],
|
||||
"bedrock": [
|
||||
"meta.llama3-1-70b-instruct-v1:0",
|
||||
"meta.llama3-1-8b-instruct-v1:0",
|
||||
"anthropic.claude-3-opus-20240229-v1:0",
|
||||
"mistral.mistral-large-2402-v1:0",
|
||||
"anthropic.claude-3-5-sonnet-20240620-v1:0",
|
||||
],
|
||||
"anthropic": ["claude-3-opus-20240229", "claude-3-5-sonnet-20240620"],
|
||||
}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"llm_provider",
|
||||
sa.Column("display_model_names", postgresql.ARRAY(sa.String()), nullable=True),
|
||||
)
|
||||
|
||||
connection = op.get_bind()
|
||||
for provider, models in default_models_by_provider.items():
|
||||
connection.execute(
|
||||
sa.text(
|
||||
"UPDATE llm_provider SET display_model_names = :models WHERE provider = :provider"
|
||||
),
|
||||
{"models": models, "provider": provider},
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("llm_provider", "display_model_names")
|
||||
@@ -1,61 +0,0 @@
|
||||
"""Add support for custom tools
|
||||
|
||||
Revision ID: 48d14957fe80
|
||||
Revises: b85f02ec1308
|
||||
Create Date: 2024-06-09 14:58:19.946509
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import fastapi_users_db_sqlalchemy
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "48d14957fe80"
|
||||
down_revision = "b85f02ec1308"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"tool",
|
||||
sa.Column(
|
||||
"openapi_schema",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"tool",
|
||||
sa.Column(
|
||||
"user_id",
|
||||
fastapi_users_db_sqlalchemy.generics.GUID(),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
op.create_foreign_key("tool_user_fk", "tool", "user", ["user_id"], ["id"])
|
||||
|
||||
op.create_table(
|
||||
"tool_call",
|
||||
sa.Column("id", sa.Integer(), primary_key=True),
|
||||
sa.Column("tool_id", sa.Integer(), nullable=False),
|
||||
sa.Column("tool_name", sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
"tool_arguments", postgresql.JSONB(astext_type=sa.Text()), nullable=False
|
||||
),
|
||||
sa.Column(
|
||||
"tool_result", postgresql.JSONB(astext_type=sa.Text()), nullable=False
|
||||
),
|
||||
sa.Column(
|
||||
"message_id", sa.Integer(), sa.ForeignKey("chat_message.id"), nullable=False
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("tool_call")
|
||||
|
||||
op.drop_constraint("tool_user_fk", "tool", type_="foreignkey")
|
||||
op.drop_column("tool", "user_id")
|
||||
op.drop_column("tool", "openapi_schema")
|
||||
@@ -1,80 +0,0 @@
|
||||
"""Moved status to connector credential pair
|
||||
|
||||
Revision ID: 4a951134c801
|
||||
Revises: 7477a5f5d728
|
||||
Create Date: 2024-08-10 19:20:34.527559
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "4a951134c801"
|
||||
down_revision = "7477a5f5d728"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"connector_credential_pair",
|
||||
sa.Column(
|
||||
"status",
|
||||
sa.Enum(
|
||||
"ACTIVE",
|
||||
"PAUSED",
|
||||
"DELETING",
|
||||
name="connectorcredentialpairstatus",
|
||||
native_enum=False,
|
||||
),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
|
||||
# Update status of connector_credential_pair based on connector's disabled status
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE connector_credential_pair
|
||||
SET status = CASE
|
||||
WHEN (
|
||||
SELECT disabled
|
||||
FROM connector
|
||||
WHERE connector.id = connector_credential_pair.connector_id
|
||||
) = FALSE THEN 'ACTIVE'
|
||||
ELSE 'PAUSED'
|
||||
END
|
||||
"""
|
||||
)
|
||||
|
||||
# Make the status column not nullable after setting values
|
||||
op.alter_column("connector_credential_pair", "status", nullable=False)
|
||||
|
||||
op.drop_column("connector", "disabled")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.add_column(
|
||||
"connector",
|
||||
sa.Column("disabled", sa.BOOLEAN(), autoincrement=False, nullable=True),
|
||||
)
|
||||
|
||||
# Update disabled status of connector based on connector_credential_pair's status
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE connector
|
||||
SET disabled = CASE
|
||||
WHEN EXISTS (
|
||||
SELECT 1
|
||||
FROM connector_credential_pair
|
||||
WHERE connector_credential_pair.connector_id = connector.id
|
||||
AND connector_credential_pair.status = 'ACTIVE'
|
||||
) THEN FALSE
|
||||
ELSE TRUE
|
||||
END
|
||||
"""
|
||||
)
|
||||
|
||||
# Make the disabled column not nullable after setting values
|
||||
op.alter_column("connector", "disabled", nullable=False)
|
||||
|
||||
op.drop_column("connector_credential_pair", "status")
|
||||
@@ -1,34 +0,0 @@
|
||||
"""change default prune_freq
|
||||
|
||||
Revision ID: 4b08d97e175a
|
||||
Revises: d9ec13955951
|
||||
Create Date: 2024-08-20 15:28:52.993827
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "4b08d97e175a"
|
||||
down_revision = "d9ec13955951"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE connector
|
||||
SET prune_freq = 2592000
|
||||
WHERE prune_freq = 86400
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE connector
|
||||
SET prune_freq = 86400
|
||||
WHERE prune_freq = 2592000
|
||||
"""
|
||||
)
|
||||
@@ -1,72 +0,0 @@
|
||||
"""Add type to credentials
|
||||
|
||||
Revision ID: 4ea2c93919c1
|
||||
Revises: 473a1a7ca408
|
||||
Create Date: 2024-07-18 13:07:13.655895
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "4ea2c93919c1"
|
||||
down_revision = "473a1a7ca408"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add the new 'source' column to the 'credential' table
|
||||
op.add_column(
|
||||
"credential",
|
||||
sa.Column(
|
||||
"source",
|
||||
sa.String(length=100), # Use String instead of Enum
|
||||
nullable=True, # Initially allow NULL values
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"credential",
|
||||
sa.Column(
|
||||
"name",
|
||||
sa.String(),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
|
||||
# Create a temporary table that maps each credential to a single connector source.
|
||||
# This is needed because a credential can be associated with multiple connectors,
|
||||
# but we want to assign a single source to each credential.
|
||||
# We use DISTINCT ON to ensure we only get one row per credential_id.
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TEMPORARY TABLE temp_connector_credential AS
|
||||
SELECT DISTINCT ON (cc.credential_id)
|
||||
cc.credential_id,
|
||||
c.source AS connector_source
|
||||
FROM connector_credential_pair cc
|
||||
JOIN connector c ON cc.connector_id = c.id
|
||||
"""
|
||||
)
|
||||
|
||||
# Update the 'source' column in the 'credential' table
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE credential cred
|
||||
SET source = COALESCE(
|
||||
(SELECT connector_source
|
||||
FROM temp_connector_credential temp
|
||||
WHERE cred.id = temp.credential_id),
|
||||
'NOT_APPLICABLE'
|
||||
)
|
||||
"""
|
||||
)
|
||||
# If no exception was raised, alter the column
|
||||
op.alter_column("credential", "source", nullable=True) # TODO modify
|
||||
# # ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("credential", "source")
|
||||
op.drop_column("credential", "name")
|
||||
@@ -1,66 +0,0 @@
|
||||
"""Add last synced and last modified to document table
|
||||
|
||||
Revision ID: 52a219fb5233
|
||||
Revises: f17bf3b0d9f1
|
||||
Create Date: 2024-08-28 17:40:46.077470
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "52a219fb5233"
|
||||
down_revision = "f7e58d357687"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# last modified represents the last time anything needing syncing to vespa changed
|
||||
# including row metadata and the document itself. This obviously does not include
|
||||
# the last_synced column.
|
||||
op.add_column(
|
||||
"document",
|
||||
sa.Column(
|
||||
"last_modified",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=func.now(),
|
||||
),
|
||||
)
|
||||
|
||||
# last synced represents the last time this document was synced to Vespa
|
||||
op.add_column(
|
||||
"document",
|
||||
sa.Column("last_synced", sa.DateTime(timezone=True), nullable=True),
|
||||
)
|
||||
|
||||
# Set last_synced to the same value as last_modified for existing rows
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE document
|
||||
SET last_synced = last_modified
|
||||
"""
|
||||
)
|
||||
|
||||
op.create_index(
|
||||
op.f("ix_document_last_modified"),
|
||||
"document",
|
||||
["last_modified"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
op.create_index(
|
||||
op.f("ix_document_last_synced"),
|
||||
"document",
|
||||
["last_synced"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index(op.f("ix_document_last_synced"), table_name="document")
|
||||
op.drop_index(op.f("ix_document_last_modified"), table_name="document")
|
||||
op.drop_column("document", "last_synced")
|
||||
op.drop_column("document", "last_modified")
|
||||
@@ -1,25 +0,0 @@
|
||||
"""hybrid-enum
|
||||
|
||||
Revision ID: 5fc1f54cc252
|
||||
Revises: 1d6ad76d1f37
|
||||
Create Date: 2024-08-06 15:35:40.278485
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "5fc1f54cc252"
|
||||
down_revision = "1d6ad76d1f37"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.drop_column("persona", "search_type")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.add_column("persona", sa.Column("search_type", sa.String(), nullable=True))
|
||||
op.execute("UPDATE persona SET search_type = 'SEMANTIC'")
|
||||
op.alter_column("persona", "search_type", nullable=False)
|
||||
@@ -1,68 +0,0 @@
|
||||
"""More Descriptive Filestore
|
||||
|
||||
Revision ID: 70f00c45c0f2
|
||||
Revises: 3879338f8ba1
|
||||
Create Date: 2024-05-17 17:51:41.926893
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "70f00c45c0f2"
|
||||
down_revision = "3879338f8ba1"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column("file_store", sa.Column("display_name", sa.String(), nullable=True))
|
||||
op.add_column(
|
||||
"file_store",
|
||||
sa.Column(
|
||||
"file_origin",
|
||||
sa.String(),
|
||||
nullable=False,
|
||||
server_default="connector", # Default to connector
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"file_store",
|
||||
sa.Column(
|
||||
"file_type", sa.String(), nullable=False, server_default="text/plain"
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"file_store",
|
||||
sa.Column(
|
||||
"file_metadata",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE file_store
|
||||
SET file_origin = CASE
|
||||
WHEN file_name LIKE 'chat__%' THEN 'chat_upload'
|
||||
ELSE 'connector'
|
||||
END,
|
||||
file_name = CASE
|
||||
WHEN file_name LIKE 'chat__%' THEN SUBSTR(file_name, 7)
|
||||
ELSE file_name
|
||||
END,
|
||||
file_type = CASE
|
||||
WHEN file_name LIKE 'chat__%' THEN 'image/png'
|
||||
ELSE 'text/plain'
|
||||
END
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("file_store", "file_metadata")
|
||||
op.drop_column("file_store", "file_type")
|
||||
op.drop_column("file_store", "file_origin")
|
||||
op.drop_column("file_store", "display_name")
|
||||
@@ -1,24 +0,0 @@
|
||||
"""Added model defaults for users
|
||||
|
||||
Revision ID: 7477a5f5d728
|
||||
Revises: 213fd978c6d8
|
||||
Create Date: 2024-08-04 19:00:04.512634
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "7477a5f5d728"
|
||||
down_revision = "213fd978c6d8"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column("user", sa.Column("default_model", sa.Text(), nullable=True))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("user", "default_model")
|
||||
@@ -28,9 +28,5 @@ def upgrade() -> None:
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.create_unique_constraint(
|
||||
"connector_credential_pair__name__key", "connector_credential_pair", ["name"]
|
||||
)
|
||||
op.alter_column(
|
||||
"connector_credential_pair", "name", existing_type=sa.String(), nullable=True
|
||||
)
|
||||
# This wasn't really required by the code either, no good reason to make it unique again
|
||||
pass
|
||||
|
||||
@@ -10,7 +10,7 @@ import sqlalchemy as sa
|
||||
|
||||
from danswer.db.models import IndexModelStatus
|
||||
from danswer.search.enums import RecencyBiasSetting
|
||||
from danswer.search.enums import SearchType
|
||||
from danswer.search.models import SearchType
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "776b3bbe9092"
|
||||
|
||||
@@ -1,41 +0,0 @@
|
||||
"""add_llm_group_permissions_control
|
||||
|
||||
Revision ID: 795b20b85b4b
|
||||
Revises: 05c07bf07c00
|
||||
Create Date: 2024-07-19 11:54:35.701558
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
revision = "795b20b85b4b"
|
||||
down_revision = "05c07bf07c00"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"llm_provider__user_group",
|
||||
sa.Column("llm_provider_id", sa.Integer(), nullable=False),
|
||||
sa.Column("user_group_id", sa.Integer(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["llm_provider_id"],
|
||||
["llm_provider.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_group_id"],
|
||||
["user_group.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("llm_provider_id", "user_group_id"),
|
||||
)
|
||||
op.add_column(
|
||||
"llm_provider",
|
||||
sa.Column("is_public", sa.Boolean(), nullable=False, server_default="true"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("llm_provider__user_group")
|
||||
op.drop_column("llm_provider", "is_public")
|
||||
@@ -1,35 +0,0 @@
|
||||
"""added slack_auto_filter
|
||||
|
||||
Revision ID: 7aea705850d5
|
||||
Revises: 4505fd7302e1
|
||||
Create Date: 2024-07-10 11:01:23.581015
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "7aea705850d5"
|
||||
down_revision = "4505fd7302e1"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"slack_bot_config",
|
||||
sa.Column("enable_auto_filters", sa.Boolean(), nullable=True),
|
||||
)
|
||||
op.execute(
|
||||
"UPDATE slack_bot_config SET enable_auto_filters = FALSE WHERE enable_auto_filters IS NULL"
|
||||
)
|
||||
op.alter_column(
|
||||
"slack_bot_config",
|
||||
"enable_auto_filters",
|
||||
existing_type=sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.false(),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("slack_bot_config", "enable_auto_filters")
|
||||
@@ -1,107 +0,0 @@
|
||||
"""associate index attempts with ccpair
|
||||
|
||||
Revision ID: 8a87bd6ec550
|
||||
Revises: 4ea2c93919c1
|
||||
Create Date: 2024-07-22 15:15:52.558451
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "8a87bd6ec550"
|
||||
down_revision = "4ea2c93919c1"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add the new connector_credential_pair_id column
|
||||
op.add_column(
|
||||
"index_attempt",
|
||||
sa.Column("connector_credential_pair_id", sa.Integer(), nullable=True),
|
||||
)
|
||||
|
||||
# Create a foreign key constraint to the connector_credential_pair table
|
||||
op.create_foreign_key(
|
||||
"fk_index_attempt_connector_credential_pair_id",
|
||||
"index_attempt",
|
||||
"connector_credential_pair",
|
||||
["connector_credential_pair_id"],
|
||||
["id"],
|
||||
)
|
||||
|
||||
# Populate the new connector_credential_pair_id column using existing connector_id and credential_id
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE index_attempt ia
|
||||
SET connector_credential_pair_id = (
|
||||
SELECT id FROM connector_credential_pair ccp
|
||||
WHERE
|
||||
(ia.connector_id IS NULL OR ccp.connector_id = ia.connector_id)
|
||||
AND (ia.credential_id IS NULL OR ccp.credential_id = ia.credential_id)
|
||||
LIMIT 1
|
||||
)
|
||||
WHERE ia.connector_id IS NOT NULL OR ia.credential_id IS NOT NULL
|
||||
"""
|
||||
)
|
||||
|
||||
# For good measure
|
||||
op.execute(
|
||||
"""
|
||||
DELETE FROM index_attempt
|
||||
WHERE connector_credential_pair_id IS NULL
|
||||
"""
|
||||
)
|
||||
|
||||
# Make the new connector_credential_pair_id column non-nullable
|
||||
op.alter_column("index_attempt", "connector_credential_pair_id", nullable=False)
|
||||
|
||||
# Drop the old connector_id and credential_id columns
|
||||
op.drop_column("index_attempt", "connector_id")
|
||||
op.drop_column("index_attempt", "credential_id")
|
||||
|
||||
# Update the index to use connector_credential_pair_id
|
||||
op.create_index(
|
||||
"ix_index_attempt_latest_for_connector_credential_pair",
|
||||
"index_attempt",
|
||||
["connector_credential_pair_id", "time_created"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Add back the old connector_id and credential_id columns
|
||||
op.add_column(
|
||||
"index_attempt", sa.Column("connector_id", sa.Integer(), nullable=True)
|
||||
)
|
||||
op.add_column(
|
||||
"index_attempt", sa.Column("credential_id", sa.Integer(), nullable=True)
|
||||
)
|
||||
|
||||
# Populate the old connector_id and credential_id columns using the connector_credential_pair_id
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE index_attempt ia
|
||||
SET connector_id = ccp.connector_id, credential_id = ccp.credential_id
|
||||
FROM connector_credential_pair ccp
|
||||
WHERE ia.connector_credential_pair_id = ccp.id
|
||||
"""
|
||||
)
|
||||
|
||||
# Make the old connector_id and credential_id columns non-nullable
|
||||
op.alter_column("index_attempt", "connector_id", nullable=False)
|
||||
op.alter_column("index_attempt", "credential_id", nullable=False)
|
||||
|
||||
# Drop the new connector_credential_pair_id column
|
||||
op.drop_constraint(
|
||||
"fk_index_attempt_connector_credential_pair_id",
|
||||
"index_attempt",
|
||||
type_="foreignkey",
|
||||
)
|
||||
op.drop_column("index_attempt", "connector_credential_pair_id")
|
||||
|
||||
op.create_index(
|
||||
"ix_index_attempt_latest_for_connector_credential_pair",
|
||||
"index_attempt",
|
||||
["connector_id", "credential_id", "time_created"],
|
||||
)
|
||||
@@ -1,26 +0,0 @@
|
||||
"""add expiry time
|
||||
|
||||
Revision ID: 91ffac7e65b3
|
||||
Revises: bc9771dccadf
|
||||
Create Date: 2024-06-24 09:39:56.462242
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "91ffac7e65b3"
|
||||
down_revision = "795b20b85b4b"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"user", sa.Column("oidc_expiry", sa.DateTime(timezone=True), nullable=True)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("user", "oidc_expiry")
|
||||
@@ -1,158 +0,0 @@
|
||||
"""migration confluence to be explicit
|
||||
|
||||
Revision ID: a3795dce87be
|
||||
Revises: 1f60f60c3401
|
||||
Create Date: 2024-09-01 13:52:12.006740
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from sqlalchemy.sql import table, column
|
||||
|
||||
revision = "a3795dce87be"
|
||||
down_revision = "1f60f60c3401"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def extract_confluence_keys_from_url(wiki_url: str) -> tuple[str, str, str, bool]:
|
||||
from urllib.parse import urlparse
|
||||
|
||||
def _extract_confluence_keys_from_cloud_url(wiki_url: str) -> tuple[str, str, str]:
|
||||
parsed_url = urlparse(wiki_url)
|
||||
wiki_base = f"{parsed_url.scheme}://{parsed_url.netloc}{parsed_url.path.split('/spaces')[0]}"
|
||||
path_parts = parsed_url.path.split("/")
|
||||
space = path_parts[3]
|
||||
page_id = path_parts[5] if len(path_parts) > 5 else ""
|
||||
return wiki_base, space, page_id
|
||||
|
||||
def _extract_confluence_keys_from_datacenter_url(
|
||||
wiki_url: str,
|
||||
) -> tuple[str, str, str]:
|
||||
DISPLAY = "/display/"
|
||||
PAGE = "/pages/"
|
||||
parsed_url = urlparse(wiki_url)
|
||||
wiki_base = f"{parsed_url.scheme}://{parsed_url.netloc}{parsed_url.path.split(DISPLAY)[0]}"
|
||||
space = DISPLAY.join(parsed_url.path.split(DISPLAY)[1:]).split("/")[0]
|
||||
page_id = ""
|
||||
if (content := parsed_url.path.split(PAGE)) and len(content) > 1:
|
||||
page_id = content[1]
|
||||
return wiki_base, space, page_id
|
||||
|
||||
is_confluence_cloud = (
|
||||
".atlassian.net/wiki/spaces/" in wiki_url
|
||||
or ".jira.com/wiki/spaces/" in wiki_url
|
||||
)
|
||||
|
||||
if is_confluence_cloud:
|
||||
wiki_base, space, page_id = _extract_confluence_keys_from_cloud_url(wiki_url)
|
||||
else:
|
||||
wiki_base, space, page_id = _extract_confluence_keys_from_datacenter_url(
|
||||
wiki_url
|
||||
)
|
||||
|
||||
return wiki_base, space, page_id, is_confluence_cloud
|
||||
|
||||
|
||||
def reconstruct_confluence_url(
|
||||
wiki_base: str, space: str, page_id: str, is_cloud: bool
|
||||
) -> str:
|
||||
if is_cloud:
|
||||
url = f"{wiki_base}/spaces/{space}"
|
||||
if page_id:
|
||||
url += f"/pages/{page_id}"
|
||||
else:
|
||||
url = f"{wiki_base}/display/{space}"
|
||||
if page_id:
|
||||
url += f"/pages/{page_id}"
|
||||
return url
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
connector = table(
|
||||
"connector",
|
||||
column("id", sa.Integer),
|
||||
column("source", sa.String()),
|
||||
column("input_type", sa.String()),
|
||||
column("connector_specific_config", postgresql.JSONB),
|
||||
)
|
||||
|
||||
# Fetch all Confluence connectors
|
||||
connection = op.get_bind()
|
||||
confluence_connectors = connection.execute(
|
||||
sa.select(connector).where(
|
||||
sa.and_(
|
||||
connector.c.source == "CONFLUENCE", connector.c.input_type == "POLL"
|
||||
)
|
||||
)
|
||||
).fetchall()
|
||||
|
||||
for row in confluence_connectors:
|
||||
config = row.connector_specific_config
|
||||
wiki_page_url = config["wiki_page_url"]
|
||||
wiki_base, space, page_id, is_cloud = extract_confluence_keys_from_url(
|
||||
wiki_page_url
|
||||
)
|
||||
|
||||
new_config = {
|
||||
"wiki_base": wiki_base,
|
||||
"space": space,
|
||||
"page_id": page_id,
|
||||
"is_cloud": is_cloud,
|
||||
}
|
||||
|
||||
for key, value in config.items():
|
||||
if key not in ["wiki_page_url"]:
|
||||
new_config[key] = value
|
||||
|
||||
op.execute(
|
||||
connector.update()
|
||||
.where(connector.c.id == row.id)
|
||||
.values(connector_specific_config=new_config)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
connector = table(
|
||||
"connector",
|
||||
column("id", sa.Integer),
|
||||
column("source", sa.String()),
|
||||
column("input_type", sa.String()),
|
||||
column("connector_specific_config", postgresql.JSONB),
|
||||
)
|
||||
|
||||
confluence_connectors = (
|
||||
op.get_bind()
|
||||
.execute(
|
||||
sa.select(connector).where(
|
||||
connector.c.source == "CONFLUENCE", connector.c.input_type == "POLL"
|
||||
)
|
||||
)
|
||||
.fetchall()
|
||||
)
|
||||
|
||||
for row in confluence_connectors:
|
||||
config = row.connector_specific_config
|
||||
if all(key in config for key in ["wiki_base", "space", "is_cloud"]):
|
||||
wiki_page_url = reconstruct_confluence_url(
|
||||
config["wiki_base"],
|
||||
config["space"],
|
||||
config.get("page_id", ""),
|
||||
config["is_cloud"],
|
||||
)
|
||||
|
||||
new_config = {"wiki_page_url": wiki_page_url}
|
||||
new_config.update(
|
||||
{
|
||||
k: v
|
||||
for k, v in config.items()
|
||||
if k not in ["wiki_base", "space", "page_id", "is_cloud"]
|
||||
}
|
||||
)
|
||||
|
||||
op.execute(
|
||||
connector.update()
|
||||
.where(connector.c.id == row.id)
|
||||
.values(connector_specific_config=new_config)
|
||||
)
|
||||
@@ -1,27 +0,0 @@
|
||||
"""Add chosen_assistants to User table
|
||||
|
||||
Revision ID: a3bfd0d64902
|
||||
Revises: ec85f2b3c544
|
||||
Create Date: 2024-05-26 17:22:24.834741
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "a3bfd0d64902"
|
||||
down_revision = "ec85f2b3c544"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column("chosen_assistants", postgresql.ARRAY(sa.Integer()), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("user", "chosen_assistants")
|
||||
@@ -16,6 +16,7 @@ depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.alter_column(
|
||||
"connector_credential_pair",
|
||||
"last_attempt_status",
|
||||
@@ -28,9 +29,11 @@ def upgrade() -> None:
|
||||
),
|
||||
nullable=True,
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.alter_column(
|
||||
"connector_credential_pair",
|
||||
"last_attempt_status",
|
||||
@@ -43,3 +46,4 @@ def downgrade() -> None:
|
||||
),
|
||||
nullable=False,
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
@@ -1,28 +0,0 @@
|
||||
"""fix-file-type-migration
|
||||
|
||||
Revision ID: b85f02ec1308
|
||||
Revises: a3bfd0d64902
|
||||
Create Date: 2024-05-31 18:09:26.658164
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "b85f02ec1308"
|
||||
down_revision = "a3bfd0d64902"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE file_store
|
||||
SET file_origin = UPPER(file_origin)
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Let's not break anything on purpose :)
|
||||
pass
|
||||
@@ -1,23 +0,0 @@
|
||||
"""backfill is_internet data to False
|
||||
|
||||
Revision ID: b896bbd0d5a7
|
||||
Revises: 44f856ae2a4a
|
||||
Create Date: 2024-07-16 15:21:05.718571
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "b896bbd0d5a7"
|
||||
down_revision = "44f856ae2a4a"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute("UPDATE search_doc SET is_internet = FALSE WHERE is_internet IS NULL")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
pass
|
||||
@@ -1,26 +0,0 @@
|
||||
"""add support for litellm proxy in reranking
|
||||
|
||||
Revision ID: ba98eba0f66a
|
||||
Revises: bceb1e139447
|
||||
Create Date: 2024-09-06 10:36:04.507332
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "ba98eba0f66a"
|
||||
down_revision = "bceb1e139447"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"search_settings", sa.Column("rerank_api_url", sa.String(), nullable=True)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("search_settings", "rerank_api_url")
|
||||
@@ -1,51 +0,0 @@
|
||||
"""create usage reports table
|
||||
|
||||
Revision ID: bc9771dccadf
|
||||
Revises: 0568ccf46a6b
|
||||
Create Date: 2024-06-18 10:04:26.800282
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import fastapi_users_db_sqlalchemy
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "bc9771dccadf"
|
||||
down_revision = "0568ccf46a6b"
|
||||
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"usage_reports",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("report_name", sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
"requestor_user_id",
|
||||
fastapi_users_db_sqlalchemy.generics.GUID(),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column(
|
||||
"time_created",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("period_from", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("period_to", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.ForeignKeyConstraint(
|
||||
["report_name"],
|
||||
["file_store.file_name"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["requestor_user_id"],
|
||||
["user.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("usage_reports")
|
||||
@@ -1,26 +0,0 @@
|
||||
"""Add base_url to CloudEmbeddingProvider
|
||||
|
||||
Revision ID: bceb1e139447
|
||||
Revises: a3795dce87be
|
||||
Create Date: 2024-08-28 17:00:52.554580
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "bceb1e139447"
|
||||
down_revision = "a3795dce87be"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"embedding_provider", sa.Column("api_url", sa.String(), nullable=True)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("embedding_provider", "api_url")
|
||||
@@ -1,75 +0,0 @@
|
||||
"""Add standard_answer tables
|
||||
|
||||
Revision ID: c18cdf4b497e
|
||||
Revises: 3a7802814195
|
||||
Create Date: 2024-06-06 15:15:02.000648
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "c18cdf4b497e"
|
||||
down_revision = "3a7802814195"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"standard_answer",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("keyword", sa.String(), nullable=False),
|
||||
sa.Column("answer", sa.String(), nullable=False),
|
||||
sa.Column("active", sa.Boolean(), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("keyword"),
|
||||
)
|
||||
op.create_table(
|
||||
"standard_answer_category",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("name", sa.String(), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("name"),
|
||||
)
|
||||
op.create_table(
|
||||
"standard_answer__standard_answer_category",
|
||||
sa.Column("standard_answer_id", sa.Integer(), nullable=False),
|
||||
sa.Column("standard_answer_category_id", sa.Integer(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["standard_answer_category_id"],
|
||||
["standard_answer_category.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["standard_answer_id"],
|
||||
["standard_answer.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("standard_answer_id", "standard_answer_category_id"),
|
||||
)
|
||||
op.create_table(
|
||||
"slack_bot_config__standard_answer_category",
|
||||
sa.Column("slack_bot_config_id", sa.Integer(), nullable=False),
|
||||
sa.Column("standard_answer_category_id", sa.Integer(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["slack_bot_config_id"],
|
||||
["slack_bot_config.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["standard_answer_category_id"],
|
||||
["standard_answer_category.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("slack_bot_config_id", "standard_answer_category_id"),
|
||||
)
|
||||
|
||||
op.add_column(
|
||||
"chat_session", sa.Column("slack_thread_id", sa.String(), nullable=True)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("chat_session", "slack_thread_id")
|
||||
|
||||
op.drop_table("slack_bot_config__standard_answer_category")
|
||||
op.drop_table("standard_answer__standard_answer_category")
|
||||
op.drop_table("standard_answer_category")
|
||||
op.drop_table("standard_answer")
|
||||
@@ -1,57 +0,0 @@
|
||||
"""Add index_attempt_errors table
|
||||
|
||||
Revision ID: c5b692fa265c
|
||||
Revises: 4a951134c801
|
||||
Create Date: 2024-08-08 14:06:39.581972
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "c5b692fa265c"
|
||||
down_revision = "4a951134c801"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"index_attempt_errors",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("index_attempt_id", sa.Integer(), nullable=True),
|
||||
sa.Column("batch", sa.Integer(), nullable=True),
|
||||
sa.Column(
|
||||
"doc_summaries",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("error_msg", sa.Text(), nullable=True),
|
||||
sa.Column("traceback", sa.Text(), nullable=True),
|
||||
sa.Column(
|
||||
"time_created",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["index_attempt_id"],
|
||||
["index_attempt.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
"index_attempt_id",
|
||||
"index_attempt_errors",
|
||||
["time_created"],
|
||||
unique=False,
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index("index_attempt_id", table_name="index_attempt_errors")
|
||||
op.drop_table("index_attempt_errors")
|
||||
# ### end Alembic commands ###
|
||||
@@ -19,9 +19,6 @@ depends_on: None = None
|
||||
def upgrade() -> None:
|
||||
op.drop_table("deletion_attempt")
|
||||
|
||||
# Remove the DeletionStatus enum
|
||||
op.execute("DROP TYPE IF EXISTS deletionstatus;")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.create_table(
|
||||
|
||||
@@ -1,45 +0,0 @@
|
||||
"""combined slack id fields
|
||||
|
||||
Revision ID: d716b0791ddd
|
||||
Revises: 7aea705850d5
|
||||
Create Date: 2024-07-10 17:57:45.630550
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "d716b0791ddd"
|
||||
down_revision = "7aea705850d5"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE slack_bot_config
|
||||
SET channel_config = jsonb_set(
|
||||
channel_config,
|
||||
'{respond_member_group_list}',
|
||||
coalesce(channel_config->'respond_team_member_list', '[]'::jsonb) ||
|
||||
coalesce(channel_config->'respond_slack_group_list', '[]'::jsonb)
|
||||
) - 'respond_team_member_list' - 'respond_slack_group_list'
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE slack_bot_config
|
||||
SET channel_config = jsonb_set(
|
||||
jsonb_set(
|
||||
channel_config - 'respond_member_group_list',
|
||||
'{respond_team_member_list}',
|
||||
'[]'::jsonb
|
||||
),
|
||||
'{respond_slack_group_list}',
|
||||
'[]'::jsonb
|
||||
)
|
||||
"""
|
||||
)
|
||||
@@ -1,31 +0,0 @@
|
||||
"""Remove _alt suffix from model_name
|
||||
|
||||
Revision ID: d9ec13955951
|
||||
Revises: da4c21c69164
|
||||
Create Date: 2024-08-20 16:31:32.955686
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "d9ec13955951"
|
||||
down_revision = "da4c21c69164"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE embedding_model
|
||||
SET model_name = regexp_replace(model_name, '__danswer_alt_index$', '')
|
||||
WHERE model_name LIKE '%__danswer_alt_index'
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# We can't reliably add the __danswer_alt_index suffix back, so we'll leave this empty
|
||||
pass
|
||||
@@ -1,65 +0,0 @@
|
||||
"""chosen_assistants changed to jsonb
|
||||
|
||||
Revision ID: da4c21c69164
|
||||
Revises: c5b692fa265c
|
||||
Create Date: 2024-08-18 19:06:47.291491
|
||||
|
||||
"""
|
||||
import json
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "da4c21c69164"
|
||||
down_revision = "c5b692fa265c"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
existing_ids_and_chosen_assistants = conn.execute(
|
||||
sa.text("select id, chosen_assistants from public.user")
|
||||
)
|
||||
op.drop_column(
|
||||
"user",
|
||||
"chosen_assistants",
|
||||
)
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"chosen_assistants",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
for id, chosen_assistants in existing_ids_and_chosen_assistants:
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"update public.user set chosen_assistants = :chosen_assistants where id = :id"
|
||||
),
|
||||
{"chosen_assistants": json.dumps(chosen_assistants), "id": id},
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
existing_ids_and_chosen_assistants = conn.execute(
|
||||
sa.text("select id, chosen_assistants from public.user")
|
||||
)
|
||||
op.drop_column(
|
||||
"user",
|
||||
"chosen_assistants",
|
||||
)
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column("chosen_assistants", postgresql.ARRAY(sa.Integer()), nullable=True),
|
||||
)
|
||||
for id, chosen_assistants in existing_ids_and_chosen_assistants:
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"update public.user set chosen_assistants = :chosen_assistants where id = :id"
|
||||
),
|
||||
{"chosen_assistants": chosen_assistants, "id": id},
|
||||
)
|
||||
@@ -9,7 +9,7 @@ from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import table, column, String, Integer, Boolean
|
||||
|
||||
from danswer.db.search_settings import (
|
||||
from danswer.db.embedding_model import (
|
||||
get_new_default_embedding_model,
|
||||
get_old_default_embedding_model,
|
||||
user_has_overridden_embedding_model,
|
||||
@@ -71,14 +71,14 @@ def upgrade() -> None:
|
||||
"query_prefix": old_embedding_model.query_prefix,
|
||||
"passage_prefix": old_embedding_model.passage_prefix,
|
||||
"index_name": old_embedding_model.index_name,
|
||||
"status": IndexModelStatus.PRESENT,
|
||||
"status": old_embedding_model.status,
|
||||
}
|
||||
],
|
||||
)
|
||||
# if the user has not overridden the default embedding model via env variables,
|
||||
# insert the new default model into the database to auto-upgrade them
|
||||
if not user_has_overridden_embedding_model():
|
||||
new_embedding_model = get_new_default_embedding_model()
|
||||
new_embedding_model = get_new_default_embedding_model(is_present=False)
|
||||
op.bulk_insert(
|
||||
EmbeddingModel,
|
||||
[
|
||||
@@ -136,4 +136,4 @@ def downgrade() -> None:
|
||||
)
|
||||
op.drop_column("index_attempt", "embedding_model_id")
|
||||
op.drop_table("embedding_model")
|
||||
op.execute("DROP TYPE IF EXISTS indexmodelstatus;")
|
||||
op.execute("DROP TYPE indexmodelstatus;")
|
||||
|
||||
@@ -1,58 +0,0 @@
|
||||
"""Added input prompts
|
||||
|
||||
Revision ID: e1392f05e840
|
||||
Revises: 08a1eda20fe1
|
||||
Create Date: 2024-07-13 19:09:22.556224
|
||||
|
||||
"""
|
||||
|
||||
import fastapi_users_db_sqlalchemy
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "e1392f05e840"
|
||||
down_revision = "08a1eda20fe1"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"inputprompt",
|
||||
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
|
||||
sa.Column("prompt", sa.String(), nullable=False),
|
||||
sa.Column("content", sa.String(), nullable=False),
|
||||
sa.Column("active", sa.Boolean(), nullable=False),
|
||||
sa.Column("is_public", sa.Boolean(), nullable=False),
|
||||
sa.Column(
|
||||
"user_id",
|
||||
fastapi_users_db_sqlalchemy.generics.GUID(),
|
||||
nullable=True,
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["user.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_table(
|
||||
"inputprompt__user",
|
||||
sa.Column("input_prompt_id", sa.Integer(), nullable=False),
|
||||
sa.Column("user_id", sa.Integer(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["input_prompt_id"],
|
||||
["inputprompt.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["inputprompt.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("input_prompt_id", "user_id"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("inputprompt__user")
|
||||
op.drop_table("inputprompt")
|
||||
@@ -1,22 +0,0 @@
|
||||
"""added-prune-frequency
|
||||
|
||||
Revision ID: e209dc5a8156
|
||||
Revises: 48d14957fe80
|
||||
Create Date: 2024-06-16 16:02:35.273231
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "e209dc5a8156"
|
||||
down_revision = "48d14957fe80"
|
||||
branch_labels = None # type: ignore
|
||||
depends_on = None # type: ignore
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column("connector", sa.Column("prune_freq", sa.Integer(), nullable=True))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("connector", "prune_freq")
|
||||
@@ -1,31 +0,0 @@
|
||||
"""Remove Last Attempt Status from CC Pair
|
||||
|
||||
Revision ID: ec85f2b3c544
|
||||
Revises: 3879338f8ba1
|
||||
Create Date: 2024-05-23 21:39:46.126010
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "ec85f2b3c544"
|
||||
down_revision = "70f00c45c0f2"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.drop_column("connector_credential_pair", "last_attempt_status")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.add_column(
|
||||
"connector_credential_pair",
|
||||
sa.Column(
|
||||
"last_attempt_status",
|
||||
sa.VARCHAR(),
|
||||
autoincrement=False,
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
@@ -1,28 +0,0 @@
|
||||
"""Added alternate model to chat message
|
||||
|
||||
Revision ID: ee3f4b47fad5
|
||||
Revises: 2d2304e27d8c
|
||||
Create Date: 2024-08-12 00:11:50.915845
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "ee3f4b47fad5"
|
||||
down_revision = "2d2304e27d8c"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"chat_message",
|
||||
sa.Column("overridden_model", sa.String(length=255), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("chat_message", "overridden_model")
|
||||
@@ -1,172 +0,0 @@
|
||||
"""embedding provider by provider type
|
||||
|
||||
Revision ID: f17bf3b0d9f1
|
||||
Revises: 351faebd379d
|
||||
Create Date: 2024-08-21 13:13:31.120460
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "f17bf3b0d9f1"
|
||||
down_revision = "351faebd379d"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add provider_type column to embedding_provider
|
||||
op.add_column(
|
||||
"embedding_provider",
|
||||
sa.Column("provider_type", sa.String(50), nullable=True),
|
||||
)
|
||||
|
||||
# Update provider_type with existing name values
|
||||
op.execute("UPDATE embedding_provider SET provider_type = UPPER(name)")
|
||||
|
||||
# Make provider_type not nullable
|
||||
op.alter_column("embedding_provider", "provider_type", nullable=False)
|
||||
|
||||
# Drop the foreign key constraint in embedding_model table
|
||||
op.drop_constraint(
|
||||
"fk_embedding_model_cloud_provider", "embedding_model", type_="foreignkey"
|
||||
)
|
||||
|
||||
# Drop the existing primary key constraint
|
||||
op.drop_constraint("embedding_provider_pkey", "embedding_provider", type_="primary")
|
||||
|
||||
# Create a new primary key constraint on provider_type
|
||||
op.create_primary_key(
|
||||
"embedding_provider_pkey", "embedding_provider", ["provider_type"]
|
||||
)
|
||||
|
||||
# Add provider_type column to embedding_model
|
||||
op.add_column(
|
||||
"embedding_model",
|
||||
sa.Column("provider_type", sa.String(50), nullable=True),
|
||||
)
|
||||
|
||||
# Update provider_type for existing embedding models
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE embedding_model
|
||||
SET provider_type = (
|
||||
SELECT provider_type
|
||||
FROM embedding_provider
|
||||
WHERE embedding_provider.id = embedding_model.cloud_provider_id
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
# Drop the old id column from embedding_provider
|
||||
op.drop_column("embedding_provider", "id")
|
||||
|
||||
# Drop the name column from embedding_provider
|
||||
op.drop_column("embedding_provider", "name")
|
||||
|
||||
# Drop the default_model_id column from embedding_provider
|
||||
op.drop_column("embedding_provider", "default_model_id")
|
||||
|
||||
# Drop the old cloud_provider_id column from embedding_model
|
||||
op.drop_column("embedding_model", "cloud_provider_id")
|
||||
|
||||
# Create the new foreign key constraint
|
||||
op.create_foreign_key(
|
||||
"fk_embedding_model_cloud_provider",
|
||||
"embedding_model",
|
||||
"embedding_provider",
|
||||
["provider_type"],
|
||||
["provider_type"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop the foreign key constraint in embedding_model table
|
||||
op.drop_constraint(
|
||||
"fk_embedding_model_cloud_provider", "embedding_model", type_="foreignkey"
|
||||
)
|
||||
|
||||
# Add back the cloud_provider_id column to embedding_model
|
||||
op.add_column(
|
||||
"embedding_model", sa.Column("cloud_provider_id", sa.Integer(), nullable=True)
|
||||
)
|
||||
op.add_column("embedding_provider", sa.Column("id", sa.Integer(), nullable=True))
|
||||
|
||||
# Assign incrementing IDs to embedding providers
|
||||
op.execute(
|
||||
"""
|
||||
CREATE SEQUENCE IF NOT EXISTS embedding_provider_id_seq;"""
|
||||
)
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE embedding_provider SET id = nextval('embedding_provider_id_seq');
|
||||
"""
|
||||
)
|
||||
|
||||
# Update cloud_provider_id based on provider_type
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE embedding_model
|
||||
SET cloud_provider_id = CASE
|
||||
WHEN provider_type IS NULL THEN NULL
|
||||
ELSE (
|
||||
SELECT id
|
||||
FROM embedding_provider
|
||||
WHERE embedding_provider.provider_type = embedding_model.provider_type
|
||||
)
|
||||
END
|
||||
"""
|
||||
)
|
||||
|
||||
# Drop the provider_type column from embedding_model
|
||||
op.drop_column("embedding_model", "provider_type")
|
||||
|
||||
# Add back the columns to embedding_provider
|
||||
op.add_column("embedding_provider", sa.Column("name", sa.String(50), nullable=True))
|
||||
op.add_column(
|
||||
"embedding_provider", sa.Column("default_model_id", sa.Integer(), nullable=True)
|
||||
)
|
||||
|
||||
# Drop the existing primary key constraint on provider_type
|
||||
op.drop_constraint("embedding_provider_pkey", "embedding_provider", type_="primary")
|
||||
|
||||
# Create the original primary key constraint on id
|
||||
op.create_primary_key("embedding_provider_pkey", "embedding_provider", ["id"])
|
||||
|
||||
# Update name with existing provider_type values
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE embedding_provider
|
||||
SET name = CASE
|
||||
WHEN provider_type = 'OPENAI' THEN 'OpenAI'
|
||||
WHEN provider_type = 'COHERE' THEN 'Cohere'
|
||||
WHEN provider_type = 'GOOGLE' THEN 'Google'
|
||||
WHEN provider_type = 'VOYAGE' THEN 'Voyage'
|
||||
ELSE provider_type
|
||||
END
|
||||
"""
|
||||
)
|
||||
|
||||
# Drop the provider_type column from embedding_provider
|
||||
op.drop_column("embedding_provider", "provider_type")
|
||||
|
||||
# Recreate the foreign key constraint in embedding_model table
|
||||
op.create_foreign_key(
|
||||
"fk_embedding_model_cloud_provider",
|
||||
"embedding_model",
|
||||
"embedding_provider",
|
||||
["cloud_provider_id"],
|
||||
["id"],
|
||||
)
|
||||
|
||||
# Recreate the foreign key constraint in embedding_model table
|
||||
op.create_foreign_key(
|
||||
"fk_embedding_provider_default_model",
|
||||
"embedding_provider",
|
||||
"embedding_model",
|
||||
["default_model_id"],
|
||||
["id"],
|
||||
)
|
||||
@@ -1,26 +0,0 @@
|
||||
"""add has_web_login column to user
|
||||
|
||||
Revision ID: f7e58d357687
|
||||
Revises: bceb1e139447
|
||||
Create Date: 2024-09-07 20:20:54.522620
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "f7e58d357687"
|
||||
down_revision = "ba98eba0f66a"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column("has_web_login", sa.Boolean(), nullable=False, server_default="true"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("user", "has_web_login")
|
||||
2
backend/assets/.gitignore
vendored
2
backend/assets/.gitignore
vendored
@@ -1,2 +0,0 @@
|
||||
*
|
||||
!.gitignore
|
||||
@@ -1,51 +1,25 @@
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.access.models import DocumentAccess
|
||||
from danswer.access.utils import prefix_user
|
||||
from danswer.configs.constants import PUBLIC_DOC_PAT
|
||||
from danswer.db.document import get_access_info_for_document
|
||||
from danswer.db.document import get_access_info_for_documents
|
||||
from danswer.db.document import get_acccess_info_for_documents
|
||||
from danswer.db.models import User
|
||||
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
|
||||
|
||||
def _get_access_for_document(
|
||||
document_id: str,
|
||||
db_session: Session,
|
||||
) -> DocumentAccess:
|
||||
info = get_access_info_for_document(
|
||||
db_session=db_session,
|
||||
document_id=document_id,
|
||||
)
|
||||
|
||||
if not info:
|
||||
return DocumentAccess.build(user_ids=[], user_groups=[], is_public=False)
|
||||
|
||||
return DocumentAccess.build(user_ids=info[1], user_groups=[], is_public=info[2])
|
||||
|
||||
|
||||
def get_access_for_document(
|
||||
document_id: str,
|
||||
db_session: Session,
|
||||
) -> DocumentAccess:
|
||||
versioned_get_access_for_document_fn = fetch_versioned_implementation(
|
||||
"danswer.access.access", "_get_access_for_document"
|
||||
)
|
||||
return versioned_get_access_for_document_fn(document_id, db_session) # type: ignore
|
||||
|
||||
|
||||
def _get_access_for_documents(
|
||||
document_ids: list[str],
|
||||
db_session: Session,
|
||||
cc_pair_to_delete: ConnectorCredentialPairIdentifier | None = None,
|
||||
) -> dict[str, DocumentAccess]:
|
||||
document_access_info = get_access_info_for_documents(
|
||||
document_access_info = get_acccess_info_for_documents(
|
||||
db_session=db_session,
|
||||
document_ids=document_ids,
|
||||
cc_pair_to_delete=cc_pair_to_delete,
|
||||
)
|
||||
return {
|
||||
document_id: DocumentAccess.build(
|
||||
user_ids=user_ids, user_groups=[], is_public=is_public
|
||||
)
|
||||
document_id: DocumentAccess.build(user_ids, is_public)
|
||||
for document_id, user_ids, is_public in document_access_info
|
||||
}
|
||||
|
||||
@@ -53,16 +27,23 @@ def _get_access_for_documents(
|
||||
def get_access_for_documents(
|
||||
document_ids: list[str],
|
||||
db_session: Session,
|
||||
cc_pair_to_delete: ConnectorCredentialPairIdentifier | None = None,
|
||||
) -> dict[str, DocumentAccess]:
|
||||
"""Fetches all access information for the given documents."""
|
||||
versioned_get_access_for_documents_fn = fetch_versioned_implementation(
|
||||
"danswer.access.access", "_get_access_for_documents"
|
||||
)
|
||||
return versioned_get_access_for_documents_fn(
|
||||
document_ids, db_session
|
||||
document_ids, db_session, cc_pair_to_delete
|
||||
) # type: ignore
|
||||
|
||||
|
||||
def prefix_user(user_id: str) -> str:
|
||||
"""Prefixes a user ID to eliminate collision with group names.
|
||||
This assumes that groups are prefixed with a different prefix."""
|
||||
return f"user_id:{user_id}"
|
||||
|
||||
|
||||
def _get_acl_for_user(user: User | None, db_session: Session) -> set[str]:
|
||||
"""Returns a list of ACL entries that the user has access to. This is meant to be
|
||||
used downstream to filter out documents that the user does not have access to. The
|
||||
|
||||
@@ -1,30 +1,20 @@
|
||||
from dataclasses import dataclass
|
||||
from uuid import UUID
|
||||
|
||||
from danswer.access.utils import prefix_user
|
||||
from danswer.access.utils import prefix_user_group
|
||||
from danswer.configs.constants import PUBLIC_DOC_PAT
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DocumentAccess:
|
||||
user_ids: set[str] # stringified UUIDs
|
||||
user_groups: set[str] # names of user groups associated with this document
|
||||
is_public: bool
|
||||
|
||||
def to_acl(self) -> list[str]:
|
||||
return (
|
||||
[prefix_user(user_id) for user_id in self.user_ids]
|
||||
+ [prefix_user_group(group_name) for group_name in self.user_groups]
|
||||
+ ([PUBLIC_DOC_PAT] if self.is_public else [])
|
||||
)
|
||||
return list(self.user_ids) + ([PUBLIC_DOC_PAT] if self.is_public else [])
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls, user_ids: list[UUID | None], user_groups: list[str], is_public: bool
|
||||
) -> "DocumentAccess":
|
||||
def build(cls, user_ids: list[UUID | None], is_public: bool) -> "DocumentAccess":
|
||||
return cls(
|
||||
user_ids={str(user_id) for user_id in user_ids if user_id},
|
||||
user_groups=set(user_groups),
|
||||
is_public=is_public,
|
||||
)
|
||||
|
||||
@@ -1,10 +0,0 @@
|
||||
def prefix_user(user_id: str) -> str:
|
||||
"""Prefixes a user ID to eliminate collision with group names.
|
||||
This assumes that groups are prefixed with a different prefix."""
|
||||
return f"user_id:{user_id}"
|
||||
|
||||
|
||||
def prefix_user_group(user_group_name: str) -> str:
|
||||
"""Prefixes a user group name to eliminate collision with user IDs.
|
||||
This assumes that user ids are prefixed with a different prefix."""
|
||||
return f"group:{user_group_name}"
|
||||
@@ -1,20 +0,0 @@
|
||||
from typing import cast
|
||||
|
||||
from danswer.configs.constants import KV_USER_STORE_KEY
|
||||
from danswer.dynamic_configs.factory import get_dynamic_config_store
|
||||
from danswer.dynamic_configs.interface import ConfigNotFoundError
|
||||
from danswer.dynamic_configs.interface import JSON_ro
|
||||
|
||||
|
||||
def get_invited_users() -> list[str]:
|
||||
try:
|
||||
store = get_dynamic_config_store()
|
||||
return cast(list, store.load(KV_USER_STORE_KEY))
|
||||
except ConfigNotFoundError:
|
||||
return list()
|
||||
|
||||
|
||||
def write_invited_users(emails: list[str]) -> int:
|
||||
store = get_dynamic_config_store()
|
||||
store.store(KV_USER_STORE_KEY, cast(JSON_ro, emails))
|
||||
return len(emails)
|
||||
@@ -1,38 +0,0 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from danswer.auth.schemas import UserRole
|
||||
from danswer.configs.constants import KV_NO_AUTH_USER_PREFERENCES_KEY
|
||||
from danswer.dynamic_configs.store import ConfigNotFoundError
|
||||
from danswer.dynamic_configs.store import DynamicConfigStore
|
||||
from danswer.server.manage.models import UserInfo
|
||||
from danswer.server.manage.models import UserPreferences
|
||||
|
||||
|
||||
def set_no_auth_user_preferences(
|
||||
store: DynamicConfigStore, preferences: UserPreferences
|
||||
) -> None:
|
||||
store.store(KV_NO_AUTH_USER_PREFERENCES_KEY, preferences.model_dump())
|
||||
|
||||
|
||||
def load_no_auth_user_preferences(store: DynamicConfigStore) -> UserPreferences:
|
||||
try:
|
||||
preferences_data = cast(
|
||||
Mapping[str, Any], store.load(KV_NO_AUTH_USER_PREFERENCES_KEY)
|
||||
)
|
||||
return UserPreferences(**preferences_data)
|
||||
except ConfigNotFoundError:
|
||||
return UserPreferences(chosen_assistants=None, default_model=None)
|
||||
|
||||
|
||||
def fetch_no_auth_user(store: DynamicConfigStore) -> UserInfo:
|
||||
return UserInfo(
|
||||
id="__no_auth_user__",
|
||||
email="anonymous@danswer.ai",
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
is_verified=True,
|
||||
role=UserRole.ADMIN,
|
||||
preferences=load_no_auth_user_preferences(store),
|
||||
)
|
||||
@@ -5,26 +5,8 @@ from fastapi_users import schemas
|
||||
|
||||
|
||||
class UserRole(str, Enum):
|
||||
"""
|
||||
User roles
|
||||
- Basic can't perform any admin actions
|
||||
- Admin can perform all admin actions
|
||||
- Curator can perform admin actions for
|
||||
groups they are curators of
|
||||
- Global Curator can perform admin actions
|
||||
for all groups they are a member of
|
||||
"""
|
||||
|
||||
BASIC = "basic"
|
||||
ADMIN = "admin"
|
||||
CURATOR = "curator"
|
||||
GLOBAL_CURATOR = "global_curator"
|
||||
|
||||
|
||||
class UserStatus(str, Enum):
|
||||
LIVE = "live"
|
||||
INVITED = "invited"
|
||||
DEACTIVATED = "deactivated"
|
||||
|
||||
|
||||
class UserRead(schemas.BaseUser[uuid.UUID]):
|
||||
@@ -33,9 +15,7 @@ class UserRead(schemas.BaseUser[uuid.UUID]):
|
||||
|
||||
class UserCreate(schemas.BaseUserCreate):
|
||||
role: UserRole = UserRole.BASIC
|
||||
has_web_login: bool | None = True
|
||||
|
||||
|
||||
class UserUpdate(schemas.BaseUserUpdate):
|
||||
role: UserRole
|
||||
has_web_login: bool | None = True
|
||||
|
||||
@@ -1,24 +1,19 @@
|
||||
import os
|
||||
import smtplib
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from email.mime.text import MIMEText
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
|
||||
from email_validator import EmailNotValidError
|
||||
from email_validator import validate_email
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
from fastapi import Response
|
||||
from fastapi import status
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from fastapi_users import BaseUserManager
|
||||
from fastapi_users import exceptions
|
||||
from fastapi_users import FastAPIUsers
|
||||
from fastapi_users import models
|
||||
from fastapi_users import schemas
|
||||
@@ -32,10 +27,8 @@ from fastapi_users.openapi import OpenAPIResponseType
|
||||
from fastapi_users_db_sqlalchemy import SQLAlchemyUserDatabase
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.invited_users import get_invited_users
|
||||
from danswer.auth.schemas import UserCreate
|
||||
from danswer.auth.schemas import UserRole
|
||||
from danswer.auth.schemas import UserUpdate
|
||||
from danswer.configs.app_configs import AUTH_TYPE
|
||||
from danswer.configs.app_configs import DISABLE_AUTH
|
||||
from danswer.configs.app_configs import EMAIL_FROM
|
||||
@@ -45,7 +38,6 @@ from danswer.configs.app_configs import SMTP_PASS
|
||||
from danswer.configs.app_configs import SMTP_PORT
|
||||
from danswer.configs.app_configs import SMTP_SERVER
|
||||
from danswer.configs.app_configs import SMTP_USER
|
||||
from danswer.configs.app_configs import TRACK_EXTERNAL_IDP_EXPIRY
|
||||
from danswer.configs.app_configs import USER_AUTH_SECRET
|
||||
from danswer.configs.app_configs import VALID_EMAIL_DOMAINS
|
||||
from danswer.configs.app_configs import WEB_DOMAIN
|
||||
@@ -54,28 +46,21 @@ from danswer.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
|
||||
from danswer.configs.constants import DANSWER_API_KEY_PREFIX
|
||||
from danswer.configs.constants import UNNAMED_KEY_PLACEHOLDER
|
||||
from danswer.db.auth import get_access_token_db
|
||||
from danswer.db.auth import get_default_admin_user_emails
|
||||
from danswer.db.auth import get_user_count
|
||||
from danswer.db.auth import get_user_db
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.models import AccessToken
|
||||
from danswer.db.models import User
|
||||
from danswer.db.users import get_user_by_email
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.telemetry import optional_telemetry
|
||||
from danswer.utils.telemetry import RecordType
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def is_user_admin(user: User | None) -> bool:
|
||||
if AUTH_TYPE == AuthType.DISABLED:
|
||||
return True
|
||||
if user and user.role == UserRole.ADMIN:
|
||||
return True
|
||||
return False
|
||||
USER_WHITELIST_FILE = "/home/danswer_whitelist.txt"
|
||||
_user_whitelist: list[str] | None = None
|
||||
|
||||
|
||||
def verify_auth_setting() -> None:
|
||||
@@ -84,7 +69,7 @@ def verify_auth_setting() -> None:
|
||||
"User must choose a valid user authentication method: "
|
||||
"disabled, basic, or google_oauth"
|
||||
)
|
||||
logger.notice(f"Using Auth Type: {AUTH_TYPE.value}")
|
||||
logger.info(f"Using Auth Type: {AUTH_TYPE.value}")
|
||||
|
||||
|
||||
def get_display_email(email: str | None, space_less: bool = False) -> str:
|
||||
@@ -107,36 +92,22 @@ def user_needs_to_be_verified() -> bool:
|
||||
return AUTH_TYPE != AuthType.BASIC or REQUIRE_EMAIL_VERIFICATION
|
||||
|
||||
|
||||
def verify_email_is_invited(email: str) -> None:
|
||||
whitelist = get_invited_users()
|
||||
if not whitelist:
|
||||
return
|
||||
def get_user_whitelist() -> list[str]:
|
||||
global _user_whitelist
|
||||
if _user_whitelist is None:
|
||||
if os.path.exists(USER_WHITELIST_FILE):
|
||||
with open(USER_WHITELIST_FILE, "r") as file:
|
||||
_user_whitelist = [line.strip() for line in file]
|
||||
else:
|
||||
_user_whitelist = []
|
||||
|
||||
if not email:
|
||||
raise PermissionError("Email must be specified")
|
||||
|
||||
email_info = validate_email(email) # can raise EmailNotValidError
|
||||
|
||||
for email_whitelist in whitelist:
|
||||
try:
|
||||
# normalized emails are now being inserted into the db
|
||||
# we can remove this normalization on read after some time has passed
|
||||
email_info_whitelist = validate_email(email_whitelist)
|
||||
except EmailNotValidError:
|
||||
continue
|
||||
|
||||
# oddly, normalization does not include lowercasing the user part of the
|
||||
# email address ... which we want to allow
|
||||
if email_info.normalized.lower() == email_info_whitelist.normalized.lower():
|
||||
return
|
||||
|
||||
raise PermissionError("User not on allowed user whitelist")
|
||||
return _user_whitelist
|
||||
|
||||
|
||||
def verify_email_in_whitelist(email: str) -> None:
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
if not get_user_by_email(email, db_session):
|
||||
verify_email_is_invited(email)
|
||||
whitelist = get_user_whitelist()
|
||||
if (whitelist and email not in whitelist) or not email:
|
||||
raise PermissionError("User not on allowed user whitelist")
|
||||
|
||||
|
||||
def verify_email_domain(email: str) -> None:
|
||||
@@ -187,36 +158,16 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
user_create: schemas.UC | UserCreate,
|
||||
safe: bool = False,
|
||||
request: Optional[Request] = None,
|
||||
) -> User:
|
||||
verify_email_is_invited(user_create.email)
|
||||
) -> models.UP:
|
||||
verify_email_in_whitelist(user_create.email)
|
||||
verify_email_domain(user_create.email)
|
||||
if hasattr(user_create, "role"):
|
||||
user_count = await get_user_count()
|
||||
if user_count == 0 or user_create.email in get_default_admin_user_emails():
|
||||
if user_count == 0:
|
||||
user_create.role = UserRole.ADMIN
|
||||
else:
|
||||
user_create.role = UserRole.BASIC
|
||||
user = None
|
||||
try:
|
||||
user = await super().create(user_create, safe=safe, request=request) # type: ignore
|
||||
except exceptions.UserAlreadyExists:
|
||||
user = await self.get_by_email(user_create.email)
|
||||
# Handle case where user has used product outside of web and is now creating an account through web
|
||||
if (
|
||||
not user.has_web_login
|
||||
and hasattr(user_create, "has_web_login")
|
||||
and user_create.has_web_login
|
||||
):
|
||||
user_update = UserUpdate(
|
||||
password=user_create.password,
|
||||
has_web_login=True,
|
||||
role=user_create.role,
|
||||
is_verified=user_create.is_verified,
|
||||
)
|
||||
user = await self.update(user_update, user)
|
||||
else:
|
||||
raise exceptions.UserAlreadyExists()
|
||||
return user
|
||||
return await super().create(user_create, safe=safe, request=request) # type: ignore
|
||||
|
||||
async def oauth_callback(
|
||||
self: "BaseUserManager[models.UOAP, models.ID]",
|
||||
@@ -234,7 +185,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
verify_email_in_whitelist(account_email)
|
||||
verify_email_domain(account_email)
|
||||
|
||||
user = await super().oauth_callback( # type: ignore
|
||||
return await super().oauth_callback( # type: ignore
|
||||
oauth_name=oauth_name,
|
||||
access_token=access_token,
|
||||
account_id=account_id,
|
||||
@@ -246,34 +197,10 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
is_verified_by_default=is_verified_by_default,
|
||||
)
|
||||
|
||||
# NOTE: Most IdPs have very short expiry times, and we don't want to force the user to
|
||||
# re-authenticate that frequently, so by default this is disabled
|
||||
if expires_at and TRACK_EXTERNAL_IDP_EXPIRY:
|
||||
oidc_expiry = datetime.fromtimestamp(expires_at, tz=timezone.utc)
|
||||
await self.user_db.update(user, update_dict={"oidc_expiry": oidc_expiry})
|
||||
|
||||
# this is needed if an organization goes from `TRACK_EXTERNAL_IDP_EXPIRY=true` to `false`
|
||||
# otherwise, the oidc expiry will always be old, and the user will never be able to login
|
||||
if user.oidc_expiry and not TRACK_EXTERNAL_IDP_EXPIRY:
|
||||
await self.user_db.update(user, update_dict={"oidc_expiry": None})
|
||||
|
||||
# Handle case where user has used product outside of web and is now creating an account through web
|
||||
if not user.has_web_login:
|
||||
await self.user_db.update(
|
||||
user,
|
||||
update_dict={
|
||||
"is_verified": is_verified_by_default,
|
||||
"has_web_login": True,
|
||||
},
|
||||
)
|
||||
user.is_verified = is_verified_by_default
|
||||
user.has_web_login = True
|
||||
return user
|
||||
|
||||
async def on_after_register(
|
||||
self, user: User, request: Optional[Request] = None
|
||||
) -> None:
|
||||
logger.notice(f"User {user.id} has registered.")
|
||||
logger.info(f"User {user.id} has registered.")
|
||||
optional_telemetry(
|
||||
record_type=RecordType.SIGN_UP,
|
||||
data={"action": "create"},
|
||||
@@ -283,35 +210,19 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
async def on_after_forgot_password(
|
||||
self, user: User, token: str, request: Optional[Request] = None
|
||||
) -> None:
|
||||
logger.notice(f"User {user.id} has forgot their password. Reset token: {token}")
|
||||
logger.info(f"User {user.id} has forgot their password. Reset token: {token}")
|
||||
|
||||
async def on_after_request_verify(
|
||||
self, user: User, token: str, request: Optional[Request] = None
|
||||
) -> None:
|
||||
verify_email_domain(user.email)
|
||||
|
||||
logger.notice(
|
||||
logger.info(
|
||||
f"Verification requested for user {user.id}. Verification token: {token}"
|
||||
)
|
||||
|
||||
send_user_verification_email(user.email, token)
|
||||
|
||||
async def authenticate(
|
||||
self, credentials: OAuth2PasswordRequestForm
|
||||
) -> Optional[User]:
|
||||
user = await super().authenticate(credentials)
|
||||
if user is None:
|
||||
try:
|
||||
user = await self.get_by_email(credentials.username)
|
||||
if not user.has_web_login:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="NO_WEB_LOGIN_AND_HAS_NO_PASSWORD",
|
||||
)
|
||||
except exceptions.UserNotExists:
|
||||
pass
|
||||
return user
|
||||
|
||||
|
||||
async def get_user_manager(
|
||||
user_db: SQLAlchemyUserDatabase = Depends(get_user_db),
|
||||
@@ -328,12 +239,10 @@ cookie_transport = CookieTransport(
|
||||
def get_database_strategy(
|
||||
access_token_db: AccessTokenDatabase[AccessToken] = Depends(get_access_token_db),
|
||||
) -> DatabaseStrategy:
|
||||
strategy = DatabaseStrategy(
|
||||
return DatabaseStrategy(
|
||||
access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS # type: ignore
|
||||
)
|
||||
|
||||
return strategy
|
||||
|
||||
|
||||
auth_backend = AuthenticationBackend(
|
||||
name="database",
|
||||
@@ -430,12 +339,6 @@ async def double_check_user(
|
||||
detail="Access denied. User is not verified.",
|
||||
)
|
||||
|
||||
if user.oidc_expiry and user.oidc_expiry < datetime.now(timezone.utc):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied. User's OIDC token has expired.",
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
@@ -445,28 +348,6 @@ async def current_user(
|
||||
return await double_check_user(user)
|
||||
|
||||
|
||||
async def current_curator_or_admin_user(
|
||||
user: User | None = Depends(current_user),
|
||||
) -> User | None:
|
||||
if DISABLE_AUTH:
|
||||
return None
|
||||
|
||||
if not user or not hasattr(user, "role"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied. User is not authenticated or lacks role information.",
|
||||
)
|
||||
|
||||
allowed_roles = {UserRole.GLOBAL_CURATOR, UserRole.CURATOR, UserRole.ADMIN}
|
||||
if user.role not in allowed_roles:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied. User is not a curator or admin.",
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
async def current_admin_user(user: User | None = Depends(current_user)) -> User | None:
|
||||
if DISABLE_AUTH:
|
||||
return None
|
||||
@@ -474,12 +355,6 @@ async def current_admin_user(user: User | None = Depends(current_user)) -> User
|
||||
if not user or not hasattr(user, "role") or user.role != UserRole.ADMIN:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied. User must be an admin to perform this action.",
|
||||
detail="Access denied. User is not an admin.",
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
def get_default_admin_user_emails_() -> list[str]:
|
||||
# No default seeding available for Danswer MIT
|
||||
return []
|
||||
|
||||
212
backend/danswer/background/celery/celery.py
Normal file
212
backend/danswer/background/celery/celery.py
Normal file
@@ -0,0 +1,212 @@
|
||||
from datetime import timedelta
|
||||
from typing import cast
|
||||
|
||||
from celery import Celery # type: ignore
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.connector_deletion import delete_connector_credential_pair
|
||||
from danswer.background.task_utils import build_celery_task_wrapper
|
||||
from danswer.background.task_utils import name_cc_cleanup_task
|
||||
from danswer.background.task_utils import name_document_set_sync_task
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair
|
||||
from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
|
||||
from danswer.db.document import prepare_to_modify_documents
|
||||
from danswer.db.document_set import delete_document_set
|
||||
from danswer.db.document_set import fetch_document_sets
|
||||
from danswer.db.document_set import fetch_document_sets_for_documents
|
||||
from danswer.db.document_set import fetch_documents_for_document_set_paginated
|
||||
from danswer.db.document_set import get_document_set_by_id
|
||||
from danswer.db.document_set import mark_document_set_as_synced
|
||||
from danswer.db.engine import build_connection_string
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.engine import SYNC_DB_API
|
||||
from danswer.db.models import DocumentSet
|
||||
from danswer.db.tasks import check_live_task_not_timed_out
|
||||
from danswer.db.tasks import get_latest_task
|
||||
from danswer.document_index.document_index_utils import get_both_index_names
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.document_index.interfaces import UpdateRequest
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
connection_string = build_connection_string(db_api=SYNC_DB_API)
|
||||
celery_broker_url = f"sqla+{connection_string}"
|
||||
celery_backend_url = f"db+{connection_string}"
|
||||
celery_app = Celery(__name__, broker=celery_broker_url, backend=celery_backend_url)
|
||||
|
||||
|
||||
_SYNC_BATCH_SIZE = 100
|
||||
|
||||
|
||||
#####
|
||||
# Tasks that need to be run in job queue, registered via APIs
|
||||
#
|
||||
# If imports from this module are needed, use local imports to avoid circular importing
|
||||
#####
|
||||
@build_celery_task_wrapper(name_cc_cleanup_task)
|
||||
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
|
||||
def cleanup_connector_credential_pair_task(
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
) -> int:
|
||||
"""Connector deletion task. This is run as an async task because it is a somewhat slow job.
|
||||
Needs to potentially update a large number of Postgres and Vespa docs, including deleting them
|
||||
or updating the ACL"""
|
||||
engine = get_sqlalchemy_engine()
|
||||
with Session(engine) as db_session:
|
||||
# validate that the connector / credential pair is deletable
|
||||
cc_pair = get_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
if not cc_pair:
|
||||
raise ValueError(
|
||||
f"Cannot run deletion attempt - connector_credential_pair with Connector ID: "
|
||||
f"{connector_id} and Credential ID: {credential_id} does not exist."
|
||||
)
|
||||
|
||||
deletion_attempt_disallowed_reason = check_deletion_attempt_is_allowed(cc_pair)
|
||||
if deletion_attempt_disallowed_reason:
|
||||
raise ValueError(deletion_attempt_disallowed_reason)
|
||||
|
||||
try:
|
||||
# The bulk of the work is in here, updates Postgres and Vespa
|
||||
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
|
||||
document_index = get_default_document_index(
|
||||
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
|
||||
)
|
||||
return delete_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
document_index=document_index,
|
||||
cc_pair=cc_pair,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to run connector_deletion due to {e}")
|
||||
raise e
|
||||
|
||||
|
||||
@build_celery_task_wrapper(name_document_set_sync_task)
|
||||
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
|
||||
def sync_document_set_task(document_set_id: int) -> None:
|
||||
"""For document sets marked as not up to date, sync the state from postgres
|
||||
into the datastore. Also handles deletions."""
|
||||
|
||||
def _sync_document_batch(document_ids: list[str], db_session: Session) -> None:
|
||||
logger.debug(f"Syncing document sets for: {document_ids}")
|
||||
|
||||
# Acquires a lock on the documents so that no other process can modify them
|
||||
with prepare_to_modify_documents(
|
||||
db_session=db_session, document_ids=document_ids
|
||||
):
|
||||
# get current state of document sets for these documents
|
||||
document_set_map = {
|
||||
document_id: document_sets
|
||||
for document_id, document_sets in fetch_document_sets_for_documents(
|
||||
document_ids=document_ids, db_session=db_session
|
||||
)
|
||||
}
|
||||
|
||||
# update Vespa
|
||||
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
|
||||
document_index = get_default_document_index(
|
||||
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
|
||||
)
|
||||
update_requests = [
|
||||
UpdateRequest(
|
||||
document_ids=[document_id],
|
||||
document_sets=set(document_set_map.get(document_id, [])),
|
||||
)
|
||||
for document_id in document_ids
|
||||
]
|
||||
document_index.update(update_requests=update_requests)
|
||||
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
try:
|
||||
cursor = None
|
||||
while True:
|
||||
document_batch, cursor = fetch_documents_for_document_set_paginated(
|
||||
document_set_id=document_set_id,
|
||||
db_session=db_session,
|
||||
current_only=False,
|
||||
last_document_id=cursor,
|
||||
limit=_SYNC_BATCH_SIZE,
|
||||
)
|
||||
_sync_document_batch(
|
||||
document_ids=[document.id for document in document_batch],
|
||||
db_session=db_session,
|
||||
)
|
||||
if cursor is None:
|
||||
break
|
||||
|
||||
# if there are no connectors, then delete the document set. Otherwise, just
|
||||
# mark it as successfully synced.
|
||||
document_set = cast(
|
||||
DocumentSet,
|
||||
get_document_set_by_id(
|
||||
db_session=db_session, document_set_id=document_set_id
|
||||
),
|
||||
) # casting since we "know" a document set with this ID exists
|
||||
if not document_set.connector_credential_pairs:
|
||||
delete_document_set(
|
||||
document_set_row=document_set, db_session=db_session
|
||||
)
|
||||
logger.info(
|
||||
f"Successfully deleted document set with ID: '{document_set_id}'!"
|
||||
)
|
||||
else:
|
||||
mark_document_set_as_synced(
|
||||
document_set_id=document_set_id, db_session=db_session
|
||||
)
|
||||
logger.info(f"Document set sync for '{document_set_id}' complete!")
|
||||
|
||||
except Exception:
|
||||
logger.exception("Failed to sync document set %s", document_set_id)
|
||||
raise
|
||||
|
||||
|
||||
#####
|
||||
# Periodic Tasks
|
||||
#####
|
||||
@celery_app.task(
|
||||
name="check_for_document_sets_sync_task",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
)
|
||||
def check_for_document_sets_sync_task() -> None:
|
||||
"""Runs periodically to check if any document sets are out of sync
|
||||
Creates a task to sync the set if needed"""
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
# check if any document sets are not synced
|
||||
document_set_info = fetch_document_sets(
|
||||
user_id=None, db_session=db_session, include_outdated=True
|
||||
)
|
||||
for document_set, _ in document_set_info:
|
||||
if not document_set.is_up_to_date:
|
||||
task_name = name_document_set_sync_task(document_set.id)
|
||||
latest_sync = get_latest_task(task_name, db_session)
|
||||
|
||||
if latest_sync and check_live_task_not_timed_out(
|
||||
latest_sync, db_session
|
||||
):
|
||||
logger.info(
|
||||
f"Document set '{document_set.id}' is already syncing. Skipping."
|
||||
)
|
||||
continue
|
||||
|
||||
logger.info(f"Document set {document_set.id} syncing now!")
|
||||
sync_document_set_task.apply_async(
|
||||
kwargs=dict(document_set_id=document_set.id),
|
||||
)
|
||||
|
||||
|
||||
#####
|
||||
# Celery Beat (Periodic Tasks) Settings
|
||||
#####
|
||||
celery_app.conf.beat_schedule = {
|
||||
"check-for-document-set-sync": {
|
||||
"task": "check_for_document_sets_sync_task",
|
||||
"schedule": timedelta(seconds=5),
|
||||
},
|
||||
}
|
||||
@@ -1,921 +0,0 @@
|
||||
import json
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
import redis
|
||||
from celery import Celery
|
||||
from celery import signals
|
||||
from celery import Task
|
||||
from celery.contrib.abortable import AbortableTask # type: ignore
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from celery.exceptions import TaskRevokedError
|
||||
from celery.signals import beat_init
|
||||
from celery.signals import worker_init
|
||||
from celery.states import READY_STATES
|
||||
from celery.utils.log import get_task_logger
|
||||
from redis import Redis
|
||||
from sqlalchemy import inspect
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.access.access import get_access_for_document
|
||||
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
|
||||
from danswer.background.celery.celery_redis import RedisDocumentSet
|
||||
from danswer.background.celery.celery_redis import RedisUserGroup
|
||||
from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector
|
||||
from danswer.background.celery.celery_utils import should_kick_off_deletion_of_cc_pair
|
||||
from danswer.background.celery.celery_utils import should_prune_cc_pair
|
||||
from danswer.background.connector_deletion import delete_connector_credential_pair
|
||||
from danswer.background.connector_deletion import delete_connector_credential_pair_batch
|
||||
from danswer.background.task_utils import build_celery_task_wrapper
|
||||
from danswer.background.task_utils import name_cc_cleanup_task
|
||||
from danswer.background.task_utils import name_cc_prune_task
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME
|
||||
from danswer.configs.constants import POSTGRES_CELERY_WORKER_APP_NAME
|
||||
from danswer.configs.constants import PostgresAdvisoryLocks
|
||||
from danswer.connectors.factory import instantiate_connector
|
||||
from danswer.connectors.models import InputType
|
||||
from danswer.db.connector_credential_pair import (
|
||||
get_connector_credential_pair,
|
||||
)
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
|
||||
from danswer.db.document import count_documents_by_needs_sync
|
||||
from danswer.db.document import get_document
|
||||
from danswer.db.document import get_documents_for_connector_credential_pair
|
||||
from danswer.db.document import mark_document_as_synced
|
||||
from danswer.db.document_set import delete_document_set
|
||||
from danswer.db.document_set import fetch_document_set_for_document
|
||||
from danswer.db.document_set import fetch_document_sets
|
||||
from danswer.db.document_set import get_document_set_by_id
|
||||
from danswer.db.document_set import mark_document_set_as_synced
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.engine import init_sqlalchemy_engine
|
||||
from danswer.db.models import DocumentSet
|
||||
from danswer.db.models import UserGroup
|
||||
from danswer.document_index.document_index_utils import get_both_index_names
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.document_index.interfaces import UpdateRequest
|
||||
from danswer.redis.redis_pool import RedisPool
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
from danswer.utils.variable_functionality import (
|
||||
fetch_versioned_implementation_with_fallback,
|
||||
)
|
||||
from danswer.utils.variable_functionality import noop_fallback
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# use this within celery tasks to get celery task specific logging
|
||||
task_logger = get_task_logger(__name__)
|
||||
|
||||
redis_pool = RedisPool()
|
||||
|
||||
celery_app = Celery(__name__)
|
||||
celery_app.config_from_object(
|
||||
"danswer.background.celery.celeryconfig"
|
||||
) # Load configuration from 'celeryconfig.py'
|
||||
|
||||
|
||||
#####
|
||||
# Tasks that need to be run in job queue, registered via APIs
|
||||
#
|
||||
# If imports from this module are needed, use local imports to avoid circular importing
|
||||
#####
|
||||
@build_celery_task_wrapper(name_cc_cleanup_task)
|
||||
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
|
||||
def cleanup_connector_credential_pair_task(
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
) -> int:
|
||||
"""Connector deletion task. This is run as an async task because it is a somewhat slow job.
|
||||
Needs to potentially update a large number of Postgres and Vespa docs, including deleting them
|
||||
or updating the ACL"""
|
||||
engine = get_sqlalchemy_engine()
|
||||
with Session(engine) as db_session:
|
||||
# validate that the connector / credential pair is deletable
|
||||
cc_pair = get_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
if not cc_pair:
|
||||
raise ValueError(
|
||||
f"Cannot run deletion attempt - connector_credential_pair with Connector ID: "
|
||||
f"{connector_id} and Credential ID: {credential_id} does not exist."
|
||||
)
|
||||
|
||||
deletion_attempt_disallowed_reason = check_deletion_attempt_is_allowed(
|
||||
connector_credential_pair=cc_pair, db_session=db_session
|
||||
)
|
||||
if deletion_attempt_disallowed_reason:
|
||||
raise ValueError(deletion_attempt_disallowed_reason)
|
||||
|
||||
try:
|
||||
# The bulk of the work is in here, updates Postgres and Vespa
|
||||
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
|
||||
document_index = get_default_document_index(
|
||||
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
|
||||
)
|
||||
return delete_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
document_index=document_index,
|
||||
cc_pair=cc_pair,
|
||||
)
|
||||
except Exception as e:
|
||||
task_logger.exception(
|
||||
f"Failed to run connector_deletion. "
|
||||
f"connector_id={connector_id} credential_id={credential_id}"
|
||||
)
|
||||
raise e
|
||||
|
||||
|
||||
@build_celery_task_wrapper(name_cc_prune_task)
|
||||
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
|
||||
def prune_documents_task(connector_id: int, credential_id: int) -> None:
|
||||
"""connector pruning task. For a cc pair, this task pulls all document IDs from the source
|
||||
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
|
||||
from the most recently pulled document ID list"""
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
try:
|
||||
cc_pair = get_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
|
||||
if not cc_pair:
|
||||
task_logger.warning(
|
||||
f"ccpair not found for {connector_id} {credential_id}"
|
||||
)
|
||||
return
|
||||
|
||||
runnable_connector = instantiate_connector(
|
||||
cc_pair.connector.source,
|
||||
InputType.PRUNE,
|
||||
cc_pair.connector.connector_specific_config,
|
||||
cc_pair.credential,
|
||||
db_session,
|
||||
)
|
||||
|
||||
all_connector_doc_ids: set[str] = extract_ids_from_runnable_connector(
|
||||
runnable_connector
|
||||
)
|
||||
|
||||
all_indexed_document_ids = {
|
||||
doc.id
|
||||
for doc in get_documents_for_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
}
|
||||
|
||||
doc_ids_to_remove = list(all_indexed_document_ids - all_connector_doc_ids)
|
||||
|
||||
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
|
||||
document_index = get_default_document_index(
|
||||
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
|
||||
)
|
||||
|
||||
if len(doc_ids_to_remove) == 0:
|
||||
task_logger.info(
|
||||
f"No docs to prune from {cc_pair.connector.source} connector"
|
||||
)
|
||||
return
|
||||
|
||||
task_logger.info(
|
||||
f"pruning {len(doc_ids_to_remove)} doc(s) from {cc_pair.connector.source} connector"
|
||||
)
|
||||
delete_connector_credential_pair_batch(
|
||||
document_ids=doc_ids_to_remove,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
document_index=document_index,
|
||||
)
|
||||
except Exception as e:
|
||||
task_logger.exception(
|
||||
f"Failed to run pruning for connector id {connector_id}."
|
||||
)
|
||||
raise e
|
||||
|
||||
|
||||
def try_generate_stale_document_sync_tasks(
|
||||
db_session: Session, r: Redis, lock_beat: redis.lock.Lock
|
||||
) -> int | None:
|
||||
# the fence is up, do nothing
|
||||
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
|
||||
return None
|
||||
|
||||
r.delete(RedisConnectorCredentialPair.get_taskset_key()) # delete the taskset
|
||||
|
||||
# add tasks to celery and build up the task set to monitor in redis
|
||||
stale_doc_count = count_documents_by_needs_sync(db_session)
|
||||
if stale_doc_count == 0:
|
||||
return None
|
||||
|
||||
task_logger.info(
|
||||
f"Stale documents found (at least {stale_doc_count}). Generating sync tasks by cc pair."
|
||||
)
|
||||
|
||||
# rkuo: we could technically sync all stale docs in one big pass.
|
||||
# but I feel it's more understandable to group the docs by cc_pair
|
||||
total_tasks_generated = 0
|
||||
cc_pairs = get_connector_credential_pairs(db_session)
|
||||
for cc_pair in cc_pairs:
|
||||
rc = RedisConnectorCredentialPair(cc_pair.id)
|
||||
tasks_generated = rc.generate_tasks(celery_app, db_session, r, lock_beat)
|
||||
|
||||
if tasks_generated is None:
|
||||
continue
|
||||
|
||||
if tasks_generated == 0:
|
||||
continue
|
||||
|
||||
task_logger.info(
|
||||
f"RedisConnector.generate_tasks finished. "
|
||||
f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}"
|
||||
)
|
||||
|
||||
total_tasks_generated += tasks_generated
|
||||
|
||||
task_logger.info(
|
||||
f"All per connector generate_tasks finished. total_tasks_generated={total_tasks_generated}"
|
||||
)
|
||||
|
||||
r.set(RedisConnectorCredentialPair.get_fence_key(), total_tasks_generated)
|
||||
return total_tasks_generated
|
||||
|
||||
|
||||
def try_generate_document_set_sync_tasks(
|
||||
document_set: DocumentSet, db_session: Session, r: Redis, lock_beat: redis.lock.Lock
|
||||
) -> int | None:
|
||||
lock_beat.reacquire()
|
||||
|
||||
rds = RedisDocumentSet(document_set.id)
|
||||
|
||||
# don't generate document set sync tasks if tasks are still pending
|
||||
if r.exists(rds.fence_key):
|
||||
return None
|
||||
|
||||
# don't generate sync tasks if we're up to date
|
||||
if document_set.is_up_to_date:
|
||||
return None
|
||||
|
||||
# add tasks to celery and build up the task set to monitor in redis
|
||||
r.delete(rds.taskset_key)
|
||||
|
||||
task_logger.info(
|
||||
f"RedisDocumentSet.generate_tasks starting. document_set_id={document_set.id}"
|
||||
)
|
||||
|
||||
# Add all documents that need to be updated into the queue
|
||||
tasks_generated = rds.generate_tasks(celery_app, db_session, r, lock_beat)
|
||||
if tasks_generated is None:
|
||||
return None
|
||||
|
||||
# Currently we are allowing the sync to proceed with 0 tasks.
|
||||
# It's possible for sets/groups to be generated initially with no entries
|
||||
# and they still need to be marked as up to date.
|
||||
# if tasks_generated == 0:
|
||||
# return 0
|
||||
|
||||
task_logger.info(
|
||||
f"RedisDocumentSet.generate_tasks finished. "
|
||||
f"document_set_id={document_set.id} tasks_generated={tasks_generated}"
|
||||
)
|
||||
|
||||
# set this only after all tasks have been added
|
||||
r.set(rds.fence_key, tasks_generated)
|
||||
return tasks_generated
|
||||
|
||||
|
||||
def try_generate_user_group_sync_tasks(
|
||||
usergroup: UserGroup, db_session: Session, r: Redis, lock_beat: redis.lock.Lock
|
||||
) -> int | None:
|
||||
lock_beat.reacquire()
|
||||
|
||||
rug = RedisUserGroup(usergroup.id)
|
||||
|
||||
# don't generate sync tasks if tasks are still pending
|
||||
if r.exists(rug.fence_key):
|
||||
return None
|
||||
|
||||
if usergroup.is_up_to_date:
|
||||
return None
|
||||
|
||||
# add tasks to celery and build up the task set to monitor in redis
|
||||
r.delete(rug.taskset_key)
|
||||
|
||||
# Add all documents that need to be updated into the queue
|
||||
task_logger.info(f"generate_tasks starting. usergroup_id={usergroup.id}")
|
||||
tasks_generated = rug.generate_tasks(celery_app, db_session, r, lock_beat)
|
||||
if tasks_generated is None:
|
||||
return None
|
||||
|
||||
# Currently we are allowing the sync to proceed with 0 tasks.
|
||||
# It's possible for sets/groups to be generated initially with no entries
|
||||
# and they still need to be marked as up to date.
|
||||
# if tasks_generated == 0:
|
||||
# return 0
|
||||
|
||||
task_logger.info(
|
||||
f"generate_tasks finished. "
|
||||
f"usergroup_id={usergroup.id} tasks_generated={tasks_generated}"
|
||||
)
|
||||
|
||||
# set this only after all tasks have been added
|
||||
r.set(rug.fence_key, tasks_generated)
|
||||
return tasks_generated
|
||||
|
||||
|
||||
#####
|
||||
# Periodic Tasks
|
||||
#####
|
||||
@celery_app.task(
|
||||
name="check_for_vespa_sync_task",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
)
|
||||
def check_for_vespa_sync_task() -> None:
|
||||
"""Runs periodically to check if any document needs syncing.
|
||||
Generates sets of tasks for Celery if syncing is needed."""
|
||||
|
||||
r = redis_pool.get_client()
|
||||
|
||||
lock_beat = r.lock(
|
||||
DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
try:
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return
|
||||
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
try_generate_stale_document_sync_tasks(db_session, r, lock_beat)
|
||||
|
||||
# check if any document sets are not synced
|
||||
document_set_info = fetch_document_sets(
|
||||
user_id=None, db_session=db_session, include_outdated=True
|
||||
)
|
||||
for document_set, _ in document_set_info:
|
||||
try_generate_document_set_sync_tasks(
|
||||
document_set, db_session, r, lock_beat
|
||||
)
|
||||
|
||||
# check if any user groups are not synced
|
||||
try:
|
||||
fetch_user_groups = fetch_versioned_implementation(
|
||||
"danswer.db.user_group", "fetch_user_groups"
|
||||
)
|
||||
|
||||
user_groups = fetch_user_groups(
|
||||
db_session=db_session, only_up_to_date=False
|
||||
)
|
||||
for usergroup in user_groups:
|
||||
try_generate_user_group_sync_tasks(
|
||||
usergroup, db_session, r, lock_beat
|
||||
)
|
||||
except ModuleNotFoundError:
|
||||
# Always exceptions on the MIT version, which is expected
|
||||
pass
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception("Unexpected exception")
|
||||
finally:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
name="check_for_cc_pair_deletion_task",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
)
|
||||
def check_for_cc_pair_deletion_task() -> None:
|
||||
"""Runs periodically to check if any deletion tasks should be run"""
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
# check if any cc pairs are up for deletion
|
||||
cc_pairs = get_connector_credential_pairs(db_session)
|
||||
for cc_pair in cc_pairs:
|
||||
if should_kick_off_deletion_of_cc_pair(cc_pair, db_session):
|
||||
task_logger.info(
|
||||
f"Deleting the {cc_pair.name} connector credential pair"
|
||||
)
|
||||
cleanup_connector_credential_pair_task.apply_async(
|
||||
kwargs=dict(
|
||||
connector_id=cc_pair.connector.id,
|
||||
credential_id=cc_pair.credential.id,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
name="kombu_message_cleanup_task",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
base=AbortableTask,
|
||||
)
|
||||
def kombu_message_cleanup_task(self: Any) -> int:
|
||||
"""Runs periodically to clean up the kombu_message table"""
|
||||
|
||||
# we will select messages older than this amount to clean up
|
||||
KOMBU_MESSAGE_CLEANUP_AGE = 7 # days
|
||||
KOMBU_MESSAGE_CLEANUP_PAGE_LIMIT = 1000
|
||||
|
||||
ctx = {}
|
||||
ctx["last_processed_id"] = 0
|
||||
ctx["deleted"] = 0
|
||||
ctx["cleanup_age"] = KOMBU_MESSAGE_CLEANUP_AGE
|
||||
ctx["page_limit"] = KOMBU_MESSAGE_CLEANUP_PAGE_LIMIT
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
# Exit the task if we can't take the advisory lock
|
||||
result = db_session.execute(
|
||||
text("SELECT pg_try_advisory_lock(:id)"),
|
||||
{"id": PostgresAdvisoryLocks.KOMBU_MESSAGE_CLEANUP_LOCK_ID.value},
|
||||
).scalar()
|
||||
if not result:
|
||||
return 0
|
||||
|
||||
while True:
|
||||
if self.is_aborted():
|
||||
raise TaskRevokedError("kombu_message_cleanup_task was aborted.")
|
||||
|
||||
b = kombu_message_cleanup_task_helper(ctx, db_session)
|
||||
if not b:
|
||||
break
|
||||
|
||||
db_session.commit()
|
||||
|
||||
if ctx["deleted"] > 0:
|
||||
task_logger.info(
|
||||
f"Deleted {ctx['deleted']} orphaned messages from kombu_message."
|
||||
)
|
||||
|
||||
return ctx["deleted"]
|
||||
|
||||
|
||||
def kombu_message_cleanup_task_helper(ctx: dict, db_session: Session) -> bool:
|
||||
"""
|
||||
Helper function to clean up old messages from the `kombu_message` table that are no longer relevant.
|
||||
|
||||
This function retrieves messages from the `kombu_message` table that are no longer visible and
|
||||
older than a specified interval. It checks if the corresponding task_id exists in the
|
||||
`celery_taskmeta` table. If the task_id does not exist, the message is deleted.
|
||||
|
||||
Args:
|
||||
ctx (dict): A context dictionary containing configuration parameters such as:
|
||||
- 'cleanup_age' (int): The age in days after which messages are considered old.
|
||||
- 'page_limit' (int): The maximum number of messages to process in one batch.
|
||||
- 'last_processed_id' (int): The ID of the last processed message to handle pagination.
|
||||
- 'deleted' (int): A counter to track the number of deleted messages.
|
||||
db_session (Session): The SQLAlchemy database session for executing queries.
|
||||
|
||||
Returns:
|
||||
bool: Returns True if there are more rows to process, False if not.
|
||||
"""
|
||||
|
||||
inspector = inspect(db_session.bind)
|
||||
if not inspector:
|
||||
return False
|
||||
|
||||
# With the move to redis as celery's broker and backend, kombu tables may not even exist.
|
||||
# We can fail silently.
|
||||
if not inspector.has_table("kombu_message"):
|
||||
return False
|
||||
|
||||
query = text(
|
||||
"""
|
||||
SELECT id, timestamp, payload
|
||||
FROM kombu_message WHERE visible = 'false'
|
||||
AND timestamp < CURRENT_TIMESTAMP - INTERVAL :interval_days
|
||||
AND id > :last_processed_id
|
||||
ORDER BY id
|
||||
LIMIT :page_limit
|
||||
"""
|
||||
)
|
||||
kombu_messages = db_session.execute(
|
||||
query,
|
||||
{
|
||||
"interval_days": f"{ctx['cleanup_age']} days",
|
||||
"page_limit": ctx["page_limit"],
|
||||
"last_processed_id": ctx["last_processed_id"],
|
||||
},
|
||||
).fetchall()
|
||||
|
||||
if len(kombu_messages) == 0:
|
||||
return False
|
||||
|
||||
for msg in kombu_messages:
|
||||
payload = json.loads(msg[2])
|
||||
task_id = payload["headers"]["id"]
|
||||
|
||||
# Check if task_id exists in celery_taskmeta
|
||||
task_exists = db_session.execute(
|
||||
text("SELECT 1 FROM celery_taskmeta WHERE task_id = :task_id"),
|
||||
{"task_id": task_id},
|
||||
).fetchone()
|
||||
|
||||
# If task_id does not exist, delete the message
|
||||
if not task_exists:
|
||||
result = db_session.execute(
|
||||
text("DELETE FROM kombu_message WHERE id = :message_id"),
|
||||
{"message_id": msg[0]},
|
||||
)
|
||||
if result.rowcount > 0: # type: ignore
|
||||
ctx["deleted"] += 1
|
||||
|
||||
ctx["last_processed_id"] = msg[0]
|
||||
|
||||
return True
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
name="check_for_prune_task",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
)
|
||||
def check_for_prune_task() -> None:
|
||||
"""Runs periodically to check if any prune tasks should be run and adds them
|
||||
to the queue"""
|
||||
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
all_cc_pairs = get_connector_credential_pairs(db_session)
|
||||
|
||||
for cc_pair in all_cc_pairs:
|
||||
if should_prune_cc_pair(
|
||||
connector=cc_pair.connector,
|
||||
credential=cc_pair.credential,
|
||||
db_session=db_session,
|
||||
):
|
||||
task_logger.info(f"Pruning the {cc_pair.connector.name} connector")
|
||||
|
||||
prune_documents_task.apply_async(
|
||||
kwargs=dict(
|
||||
connector_id=cc_pair.connector.id,
|
||||
credential_id=cc_pair.credential.id,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
name="vespa_metadata_sync_task",
|
||||
bind=True,
|
||||
soft_time_limit=45,
|
||||
time_limit=60,
|
||||
max_retries=3,
|
||||
)
|
||||
def vespa_metadata_sync_task(self: Task, document_id: str) -> bool:
|
||||
task_logger.info(f"document_id={document_id}")
|
||||
|
||||
try:
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
|
||||
document_index = get_default_document_index(
|
||||
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
|
||||
)
|
||||
|
||||
doc = get_document(document_id, db_session)
|
||||
if not doc:
|
||||
return False
|
||||
|
||||
# document set sync
|
||||
doc_sets = fetch_document_set_for_document(document_id, db_session)
|
||||
update_doc_sets: set[str] = set(doc_sets)
|
||||
|
||||
# User group sync
|
||||
doc_access = get_access_for_document(
|
||||
document_id=document_id, db_session=db_session
|
||||
)
|
||||
update_request = UpdateRequest(
|
||||
document_ids=[document_id],
|
||||
document_sets=update_doc_sets,
|
||||
access=doc_access,
|
||||
boost=doc.boost,
|
||||
hidden=doc.hidden,
|
||||
)
|
||||
|
||||
# update Vespa
|
||||
document_index.update(update_requests=[update_request])
|
||||
|
||||
# update db last. Worst case = we crash right before this and
|
||||
# the sync might repeat again later
|
||||
mark_document_as_synced(document_id, db_session)
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(f"SoftTimeLimitExceeded exception. doc_id={document_id}")
|
||||
except Exception as e:
|
||||
task_logger.exception("Unexpected exception")
|
||||
|
||||
# Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
|
||||
countdown = 2 ** (self.request.retries + 4)
|
||||
self.retry(exc=e, countdown=countdown)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
@signals.task_postrun.connect
|
||||
def celery_task_postrun(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple | None = None,
|
||||
kwargs: dict | None = None,
|
||||
retval: Any | None = None,
|
||||
state: str | None = None,
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
"""We handle this signal in order to remove completed tasks
|
||||
from their respective tasksets. This allows us to track the progress of document set
|
||||
and user group syncs.
|
||||
|
||||
This function runs after any task completes (both success and failure)
|
||||
Note that this signal does not fire on a task that failed to complete and is going
|
||||
to be retried.
|
||||
"""
|
||||
if not task:
|
||||
return
|
||||
|
||||
task_logger.debug(f"Task {task.name} (ID: {task_id}) completed with state: {state}")
|
||||
# logger.debug(f"Result: {retval}")
|
||||
|
||||
if state not in READY_STATES:
|
||||
return
|
||||
|
||||
if not task_id:
|
||||
return
|
||||
|
||||
if task_id.startswith(RedisConnectorCredentialPair.PREFIX):
|
||||
r = redis_pool.get_client()
|
||||
r.srem(RedisConnectorCredentialPair.get_taskset_key(), task_id)
|
||||
return
|
||||
|
||||
if task_id.startswith(RedisDocumentSet.PREFIX):
|
||||
r = redis_pool.get_client()
|
||||
document_set_id = RedisDocumentSet.get_id_from_task_id(task_id)
|
||||
if document_set_id is not None:
|
||||
rds = RedisDocumentSet(document_set_id)
|
||||
r.srem(rds.taskset_key, task_id)
|
||||
return
|
||||
|
||||
if task_id.startswith(RedisUserGroup.PREFIX):
|
||||
r = redis_pool.get_client()
|
||||
usergroup_id = RedisUserGroup.get_id_from_task_id(task_id)
|
||||
if usergroup_id is not None:
|
||||
rug = RedisUserGroup(usergroup_id)
|
||||
r.srem(rug.taskset_key, task_id)
|
||||
return
|
||||
|
||||
|
||||
def monitor_connector_taskset(r: Redis) -> None:
|
||||
fence_value = r.get(RedisConnectorCredentialPair.get_fence_key())
|
||||
if fence_value is None:
|
||||
return
|
||||
|
||||
try:
|
||||
initial_count = int(cast(int, fence_value))
|
||||
except ValueError:
|
||||
task_logger.error("The value is not an integer.")
|
||||
return
|
||||
|
||||
count = r.scard(RedisConnectorCredentialPair.get_taskset_key())
|
||||
task_logger.info(f"Stale documents: remaining={count} initial={initial_count}")
|
||||
if count == 0:
|
||||
r.delete(RedisConnectorCredentialPair.get_taskset_key())
|
||||
r.delete(RedisConnectorCredentialPair.get_fence_key())
|
||||
task_logger.info(f"Successfully synced stale documents. count={initial_count}")
|
||||
|
||||
|
||||
def monitor_document_set_taskset(
|
||||
key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
document_set_id = RedisDocumentSet.get_id_from_fence_key(fence_key)
|
||||
if document_set_id is None:
|
||||
task_logger.warning("could not parse document set id from {key}")
|
||||
return
|
||||
|
||||
rds = RedisDocumentSet(document_set_id)
|
||||
|
||||
fence_value = r.get(rds.fence_key)
|
||||
if fence_value is None:
|
||||
return
|
||||
|
||||
try:
|
||||
initial_count = int(cast(int, fence_value))
|
||||
except ValueError:
|
||||
task_logger.error("The value is not an integer.")
|
||||
return
|
||||
|
||||
count = cast(int, r.scard(rds.taskset_key))
|
||||
task_logger.info(
|
||||
f"document_set_id={document_set_id} remaining={count} initial={initial_count}"
|
||||
)
|
||||
if count > 0:
|
||||
return
|
||||
|
||||
document_set = cast(
|
||||
DocumentSet,
|
||||
get_document_set_by_id(db_session=db_session, document_set_id=document_set_id),
|
||||
) # casting since we "know" a document set with this ID exists
|
||||
if document_set:
|
||||
if not document_set.connector_credential_pairs:
|
||||
# if there are no connectors, then delete the document set.
|
||||
delete_document_set(document_set_row=document_set, db_session=db_session)
|
||||
task_logger.info(
|
||||
f"Successfully deleted document set with ID: '{document_set_id}'!"
|
||||
)
|
||||
else:
|
||||
mark_document_set_as_synced(document_set_id, db_session)
|
||||
task_logger.info(
|
||||
f"Successfully synced document set with ID: '{document_set_id}'!"
|
||||
)
|
||||
|
||||
r.delete(rds.taskset_key)
|
||||
r.delete(rds.fence_key)
|
||||
|
||||
|
||||
def monitor_usergroup_taskset(key_bytes: bytes, r: Redis, db_session: Session) -> None:
|
||||
key = key_bytes.decode("utf-8")
|
||||
usergroup_id = RedisUserGroup.get_id_from_fence_key(key)
|
||||
if not usergroup_id:
|
||||
task_logger.warning("Could not parse usergroup id from {key}")
|
||||
return
|
||||
|
||||
rug = RedisUserGroup(usergroup_id)
|
||||
fence_value = r.get(rug.fence_key)
|
||||
if fence_value is None:
|
||||
return
|
||||
|
||||
try:
|
||||
initial_count = int(cast(int, fence_value))
|
||||
except ValueError:
|
||||
task_logger.error("The value is not an integer.")
|
||||
return
|
||||
|
||||
count = cast(int, r.scard(rug.taskset_key))
|
||||
task_logger.info(
|
||||
f"usergroup_id={usergroup_id} remaining={count} initial={initial_count}"
|
||||
)
|
||||
if count > 0:
|
||||
return
|
||||
|
||||
try:
|
||||
fetch_user_group = fetch_versioned_implementation(
|
||||
"danswer.db.user_group", "fetch_user_group"
|
||||
)
|
||||
except ModuleNotFoundError:
|
||||
task_logger.exception(
|
||||
"fetch_versioned_implementation failed to look up fetch_user_group."
|
||||
)
|
||||
return
|
||||
|
||||
user_group: UserGroup | None = fetch_user_group(
|
||||
db_session=db_session, user_group_id=usergroup_id
|
||||
)
|
||||
if user_group:
|
||||
if user_group.is_up_for_deletion:
|
||||
delete_user_group = fetch_versioned_implementation_with_fallback(
|
||||
"danswer.db.user_group", "delete_user_group", noop_fallback
|
||||
)
|
||||
|
||||
delete_user_group(db_session=db_session, user_group=user_group)
|
||||
task_logger.info(f" Deleted usergroup. id='{usergroup_id}'")
|
||||
else:
|
||||
mark_user_group_as_synced = fetch_versioned_implementation_with_fallback(
|
||||
"danswer.db.user_group", "mark_user_group_as_synced", noop_fallback
|
||||
)
|
||||
|
||||
mark_user_group_as_synced(db_session=db_session, user_group=user_group)
|
||||
task_logger.info(f"Synced usergroup. id='{usergroup_id}'")
|
||||
|
||||
r.delete(rug.taskset_key)
|
||||
r.delete(rug.fence_key)
|
||||
|
||||
|
||||
@celery_app.task(name="monitor_vespa_sync", soft_time_limit=300)
|
||||
def monitor_vespa_sync() -> None:
|
||||
"""This is a celery beat task that monitors and finalizes metadata sync tasksets.
|
||||
It scans for fence values and then gets the counts of any associated tasksets.
|
||||
If the count is 0, that means all tasks finished and we should clean up.
|
||||
|
||||
This task lock timeout is CELERY_METADATA_SYNC_BEAT_LOCK_TIMEOUT seconds, so don't
|
||||
do anything too expensive in this function!
|
||||
"""
|
||||
r = redis_pool.get_client()
|
||||
|
||||
lock_beat = r.lock(
|
||||
DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
try:
|
||||
# prevent overlapping tasks
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return
|
||||
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
|
||||
monitor_connector_taskset(r)
|
||||
|
||||
for key_bytes in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
|
||||
monitor_document_set_taskset(key_bytes, r, db_session)
|
||||
|
||||
for key_bytes in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
|
||||
monitor_usergroup_taskset(key_bytes, r, db_session)
|
||||
|
||||
#
|
||||
# r_celery = celery_app.broker_connection().channel().client
|
||||
# length = celery_get_queue_length(DanswerCeleryQueues.VESPA_METADATA_SYNC, r_celery)
|
||||
# task_logger.warning(f"queue={DanswerCeleryQueues.VESPA_METADATA_SYNC} length={length}")
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
)
|
||||
finally:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
|
||||
|
||||
@beat_init.connect
|
||||
def on_beat_init(sender: Any, **kwargs: Any) -> None:
|
||||
init_sqlalchemy_engine(POSTGRES_CELERY_BEAT_APP_NAME)
|
||||
|
||||
|
||||
@worker_init.connect
|
||||
def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
init_sqlalchemy_engine(POSTGRES_CELERY_WORKER_APP_NAME)
|
||||
|
||||
# TODO(rkuo): this is singleton work that should be done on startup exactly once
|
||||
# if we run multiple workers, we'll need to centralize where this cleanup happens
|
||||
r = redis_pool.get_client()
|
||||
|
||||
r.delete(DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK)
|
||||
r.delete(DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
|
||||
|
||||
r.delete(RedisConnectorCredentialPair.get_taskset_key())
|
||||
r.delete(RedisConnectorCredentialPair.get_fence_key())
|
||||
|
||||
for key in r.scan_iter(RedisDocumentSet.TASKSET_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisUserGroup.TASKSET_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
|
||||
#####
|
||||
# Celery Beat (Periodic Tasks) Settings
|
||||
#####
|
||||
celery_app.conf.beat_schedule = {
|
||||
"check-for-vespa-sync": {
|
||||
"task": "check_for_vespa_sync_task",
|
||||
"schedule": timedelta(seconds=5),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
"check-for-cc-pair-deletion": {
|
||||
"task": "check_for_cc_pair_deletion_task",
|
||||
# don't need to check too often, since we kick off a deletion initially
|
||||
# during the API call that actually marks the CC pair for deletion
|
||||
"schedule": timedelta(minutes=1),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
}
|
||||
celery_app.conf.beat_schedule.update(
|
||||
{
|
||||
"check-for-prune": {
|
||||
"task": "check_for_prune_task",
|
||||
"schedule": timedelta(seconds=5),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
}
|
||||
)
|
||||
celery_app.conf.beat_schedule.update(
|
||||
{
|
||||
"kombu-message-cleanup": {
|
||||
"task": "kombu_message_cleanup_task",
|
||||
"schedule": timedelta(seconds=3600),
|
||||
"options": {"priority": DanswerCeleryPriority.LOWEST},
|
||||
},
|
||||
}
|
||||
)
|
||||
celery_app.conf.beat_schedule.update(
|
||||
{
|
||||
"monitor-vespa-sync": {
|
||||
"task": "monitor_vespa_sync",
|
||||
"schedule": timedelta(seconds=5),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
}
|
||||
)
|
||||
@@ -1,299 +0,0 @@
|
||||
# These are helper objects for tracking the keys we need to write in redis
|
||||
import time
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
import redis
|
||||
from celery import Celery
|
||||
from redis import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.celeryconfig import CELERY_SEPARATOR
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.document import (
|
||||
construct_document_select_for_connector_credential_pair_by_needs_sync,
|
||||
)
|
||||
from danswer.db.document_set import construct_document_select_by_docset
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
|
||||
|
||||
class RedisObjectHelper(ABC):
|
||||
PREFIX = "base"
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def __init__(self, id: int):
|
||||
self._id: int = id
|
||||
|
||||
@property
|
||||
def task_id_prefix(self) -> str:
|
||||
return f"{self.PREFIX}_{self._id}"
|
||||
|
||||
@property
|
||||
def fence_key(self) -> str:
|
||||
# example: documentset_fence_1
|
||||
return f"{self.FENCE_PREFIX}_{self._id}"
|
||||
|
||||
@property
|
||||
def taskset_key(self) -> str:
|
||||
# example: documentset_taskset_1
|
||||
return f"{self.TASKSET_PREFIX}_{self._id}"
|
||||
|
||||
@staticmethod
|
||||
def get_id_from_fence_key(key: str) -> int | None:
|
||||
"""
|
||||
Extracts the object ID from a fence key in the format `PREFIX_fence_X`.
|
||||
|
||||
Args:
|
||||
key (str): The fence key string.
|
||||
|
||||
Returns:
|
||||
Optional[int]: The extracted ID if the key is in the correct format, otherwise None.
|
||||
"""
|
||||
parts = key.split("_")
|
||||
if len(parts) != 3:
|
||||
return None
|
||||
|
||||
try:
|
||||
object_id = int(parts[2])
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
return object_id
|
||||
|
||||
@staticmethod
|
||||
def get_id_from_task_id(task_id: str) -> int | None:
|
||||
"""
|
||||
Extracts the object ID from a task ID string.
|
||||
|
||||
This method assumes the task ID is formatted as `prefix_objectid_suffix`, where:
|
||||
- `prefix` is an arbitrary string (e.g., the name of the task or entity),
|
||||
- `objectid` is the ID you want to extract,
|
||||
- `suffix` is another arbitrary string (e.g., a UUID).
|
||||
|
||||
Example:
|
||||
If the input `task_id` is `documentset_1_cbfdc96a-80ca-4312-a242-0bb68da3c1dc`,
|
||||
this method will return the string `"1"`.
|
||||
|
||||
Args:
|
||||
task_id (str): The task ID string from which to extract the object ID.
|
||||
|
||||
Returns:
|
||||
str | None: The extracted object ID if the task ID is in the correct format, otherwise None.
|
||||
"""
|
||||
# example: task_id=documentset_1_cbfdc96a-80ca-4312-a242-0bb68da3c1dc
|
||||
parts = task_id.split("_")
|
||||
if len(parts) != 3:
|
||||
return None
|
||||
|
||||
try:
|
||||
object_id = int(parts[1])
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
return object_id
|
||||
|
||||
@abstractmethod
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock,
|
||||
) -> int | None:
|
||||
pass
|
||||
|
||||
|
||||
class RedisDocumentSet(RedisObjectHelper):
|
||||
PREFIX = "documentset"
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock,
|
||||
) -> int | None:
|
||||
last_lock_time = time.monotonic()
|
||||
|
||||
async_results = []
|
||||
stmt = construct_document_select_by_docset(self._id)
|
||||
for doc in db_session.scalars(stmt).yield_per(1):
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
|
||||
):
|
||||
lock.reacquire()
|
||||
last_lock_time = current_time
|
||||
|
||||
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# we prefix the task id so it's easier to keep track of who created the task
|
||||
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
|
||||
|
||||
# add to the set BEFORE creating the task.
|
||||
redis_client.sadd(self.taskset_key, custom_task_id)
|
||||
|
||||
result = celery_app.send_task(
|
||||
"vespa_metadata_sync_task",
|
||||
kwargs=dict(document_id=doc.id),
|
||||
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.LOW,
|
||||
)
|
||||
|
||||
async_results.append(result)
|
||||
|
||||
return len(async_results)
|
||||
|
||||
|
||||
class RedisUserGroup(RedisObjectHelper):
|
||||
PREFIX = "usergroup"
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock,
|
||||
) -> int | None:
|
||||
last_lock_time = time.monotonic()
|
||||
|
||||
async_results = []
|
||||
|
||||
try:
|
||||
construct_document_select_by_usergroup = fetch_versioned_implementation(
|
||||
"danswer.db.user_group",
|
||||
"construct_document_select_by_usergroup",
|
||||
)
|
||||
except ModuleNotFoundError:
|
||||
return 0
|
||||
|
||||
stmt = construct_document_select_by_usergroup(self._id)
|
||||
for doc in db_session.scalars(stmt).yield_per(1):
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
|
||||
):
|
||||
lock.reacquire()
|
||||
last_lock_time = current_time
|
||||
|
||||
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# we prefix the task id so it's easier to keep track of who created the task
|
||||
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
|
||||
|
||||
# add to the set BEFORE creating the task.
|
||||
redis_client.sadd(self.taskset_key, custom_task_id)
|
||||
|
||||
result = celery_app.send_task(
|
||||
"vespa_metadata_sync_task",
|
||||
kwargs=dict(document_id=doc.id),
|
||||
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.LOW,
|
||||
)
|
||||
|
||||
async_results.append(result)
|
||||
|
||||
return len(async_results)
|
||||
|
||||
|
||||
class RedisConnectorCredentialPair(RedisObjectHelper):
|
||||
PREFIX = "connectorsync"
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
@classmethod
|
||||
def get_fence_key(cls) -> str:
|
||||
return RedisConnectorCredentialPair.FENCE_PREFIX
|
||||
|
||||
@classmethod
|
||||
def get_taskset_key(cls) -> str:
|
||||
return RedisConnectorCredentialPair.TASKSET_PREFIX
|
||||
|
||||
@property
|
||||
def taskset_key(self) -> str:
|
||||
"""Notice that this is intentionally reusing the same taskset for all
|
||||
connector syncs"""
|
||||
# example: connector_taskset
|
||||
return f"{self.TASKSET_PREFIX}"
|
||||
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock,
|
||||
) -> int | None:
|
||||
last_lock_time = time.monotonic()
|
||||
|
||||
async_results = []
|
||||
cc_pair = get_connector_credential_pair_from_id(self._id, db_session)
|
||||
if not cc_pair:
|
||||
return None
|
||||
|
||||
stmt = construct_document_select_for_connector_credential_pair_by_needs_sync(
|
||||
cc_pair.connector_id, cc_pair.credential_id
|
||||
)
|
||||
for doc in db_session.scalars(stmt).yield_per(1):
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
|
||||
):
|
||||
lock.reacquire()
|
||||
last_lock_time = current_time
|
||||
|
||||
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# we prefix the task id so it's easier to keep track of who created the task
|
||||
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
|
||||
|
||||
# add to the tracking taskset in redis BEFORE creating the celery task.
|
||||
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
|
||||
redis_client.sadd(
|
||||
RedisConnectorCredentialPair.get_taskset_key(), custom_task_id
|
||||
)
|
||||
|
||||
# Priority on sync's triggered by new indexing should be medium
|
||||
result = celery_app.send_task(
|
||||
"vespa_metadata_sync_task",
|
||||
kwargs=dict(document_id=doc.id),
|
||||
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.MEDIUM,
|
||||
)
|
||||
|
||||
async_results.append(result)
|
||||
|
||||
return len(async_results)
|
||||
|
||||
|
||||
def celery_get_queue_length(queue: str, r: Redis) -> int:
|
||||
"""This is a redis specific way to get the length of a celery queue.
|
||||
It is priority aware and knows how to count across the multiple redis lists
|
||||
used to implement task prioritization.
|
||||
This operation is not atomic."""
|
||||
total_length = 0
|
||||
for i in range(len(DanswerCeleryPriority)):
|
||||
queue_name = queue
|
||||
if i > 0:
|
||||
queue_name += CELERY_SEPARATOR
|
||||
queue_name += str(i)
|
||||
|
||||
length = r.llen(queue_name)
|
||||
total_length += cast(int, length)
|
||||
|
||||
return total_length
|
||||
@@ -1,9 +0,0 @@
|
||||
"""Entry point for running celery worker / celery beat."""
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
|
||||
|
||||
set_is_ee_based_on_env_variable()
|
||||
celery_app = fetch_versioned_implementation(
|
||||
"danswer.background.celery.celery_app", "celery_app"
|
||||
)
|
||||
@@ -1,153 +1,23 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.task_utils import name_cc_cleanup_task
|
||||
from danswer.background.task_utils import name_cc_prune_task
|
||||
from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
|
||||
from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
|
||||
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||
rate_limit_builder,
|
||||
)
|
||||
from danswer.connectors.interfaces import BaseConnector
|
||||
from danswer.connectors.interfaces import IdConnector
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
|
||||
from danswer.db.engine import get_db_current_time
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.models import Connector
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.models import Credential
|
||||
from danswer.db.models import TaskQueueState
|
||||
from danswer.db.tasks import check_task_is_live_and_not_timed_out
|
||||
from danswer.db.tasks import get_latest_task
|
||||
from danswer.db.tasks import get_latest_task_by_type
|
||||
from danswer.server.documents.models import DeletionAttemptSnapshot
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _get_deletion_status(
|
||||
def get_deletion_status(
|
||||
connector_id: int, credential_id: int, db_session: Session
|
||||
) -> TaskQueueState | None:
|
||||
) -> DeletionAttemptSnapshot | None:
|
||||
cleanup_task_name = name_cc_cleanup_task(
|
||||
connector_id=connector_id, credential_id=credential_id
|
||||
)
|
||||
return get_latest_task(task_name=cleanup_task_name, db_session=db_session)
|
||||
task_state = get_latest_task(task_name=cleanup_task_name, db_session=db_session)
|
||||
|
||||
|
||||
def get_deletion_attempt_snapshot(
|
||||
connector_id: int, credential_id: int, db_session: Session
|
||||
) -> DeletionAttemptSnapshot | None:
|
||||
deletion_task = _get_deletion_status(connector_id, credential_id, db_session)
|
||||
if not deletion_task:
|
||||
if not task_state:
|
||||
return None
|
||||
|
||||
return DeletionAttemptSnapshot(
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
status=deletion_task.status,
|
||||
status=task_state.status,
|
||||
)
|
||||
|
||||
|
||||
def should_kick_off_deletion_of_cc_pair(
|
||||
cc_pair: ConnectorCredentialPair, db_session: Session
|
||||
) -> bool:
|
||||
if cc_pair.status != ConnectorCredentialPairStatus.DELETING:
|
||||
return False
|
||||
|
||||
if check_deletion_attempt_is_allowed(cc_pair, db_session):
|
||||
return False
|
||||
|
||||
deletion_task = _get_deletion_status(
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
if deletion_task and check_task_is_live_and_not_timed_out(
|
||||
deletion_task,
|
||||
db_session,
|
||||
# 1 hour timeout
|
||||
timeout=60 * 60,
|
||||
):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def should_prune_cc_pair(
|
||||
connector: Connector, credential: Credential, db_session: Session
|
||||
) -> bool:
|
||||
if not connector.prune_freq:
|
||||
return False
|
||||
|
||||
pruning_task_name = name_cc_prune_task(
|
||||
connector_id=connector.id, credential_id=credential.id
|
||||
)
|
||||
last_pruning_task = get_latest_task(pruning_task_name, db_session)
|
||||
current_db_time = get_db_current_time(db_session)
|
||||
|
||||
if not last_pruning_task:
|
||||
time_since_initialization = current_db_time - connector.time_created
|
||||
if time_since_initialization.total_seconds() >= connector.prune_freq:
|
||||
return True
|
||||
return False
|
||||
|
||||
if not ALLOW_SIMULTANEOUS_PRUNING:
|
||||
pruning_type_task_name = name_cc_prune_task()
|
||||
last_pruning_type_task = get_latest_task_by_type(
|
||||
pruning_type_task_name, db_session
|
||||
)
|
||||
|
||||
if last_pruning_type_task and check_task_is_live_and_not_timed_out(
|
||||
last_pruning_type_task, db_session
|
||||
):
|
||||
return False
|
||||
|
||||
if check_task_is_live_and_not_timed_out(last_pruning_task, db_session):
|
||||
return False
|
||||
|
||||
if not last_pruning_task.start_time:
|
||||
return False
|
||||
|
||||
time_since_last_pruning = current_db_time - last_pruning_task.start_time
|
||||
return time_since_last_pruning.total_seconds() >= connector.prune_freq
|
||||
|
||||
|
||||
def document_batch_to_ids(doc_batch: list[Document]) -> set[str]:
|
||||
return {doc.id for doc in doc_batch}
|
||||
|
||||
|
||||
def extract_ids_from_runnable_connector(runnable_connector: BaseConnector) -> set[str]:
|
||||
"""
|
||||
If the PruneConnector hasnt been implemented for the given connector, just pull
|
||||
all docs using the load_from_state and grab out the IDs
|
||||
"""
|
||||
all_connector_doc_ids: set[str] = set()
|
||||
|
||||
doc_batch_generator = None
|
||||
if isinstance(runnable_connector, IdConnector):
|
||||
all_connector_doc_ids = runnable_connector.retrieve_all_source_ids()
|
||||
elif isinstance(runnable_connector, LoadConnector):
|
||||
doc_batch_generator = runnable_connector.load_from_state()
|
||||
elif isinstance(runnable_connector, PollConnector):
|
||||
start = datetime(1970, 1, 1, tzinfo=timezone.utc).timestamp()
|
||||
end = datetime.now(timezone.utc).timestamp()
|
||||
doc_batch_generator = runnable_connector.poll_source(start=start, end=end)
|
||||
else:
|
||||
raise RuntimeError("Pruning job could not find a valid runnable_connector.")
|
||||
|
||||
if doc_batch_generator:
|
||||
doc_batch_processing_func = document_batch_to_ids
|
||||
if MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE:
|
||||
doc_batch_processing_func = rate_limit_builder(
|
||||
max_calls=MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE, period=60
|
||||
)(document_batch_to_ids)
|
||||
for doc_batch in doc_batch_generator:
|
||||
all_connector_doc_ids.update(doc_batch_processing_func(doc_batch))
|
||||
|
||||
return all_connector_doc_ids
|
||||
|
||||
@@ -1,35 +0,0 @@
|
||||
# docs: https://docs.celeryq.dev/en/stable/userguide/configuration.html
|
||||
from danswer.configs.app_configs import REDIS_DB_NUMBER_CELERY
|
||||
from danswer.configs.app_configs import REDIS_HOST
|
||||
from danswer.configs.app_configs import REDIS_PASSWORD
|
||||
from danswer.configs.app_configs import REDIS_PORT
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
|
||||
CELERY_SEPARATOR = ":"
|
||||
|
||||
CELERY_PASSWORD_PART = ""
|
||||
if REDIS_PASSWORD:
|
||||
CELERY_PASSWORD_PART = f":{REDIS_PASSWORD}@"
|
||||
|
||||
# example celery_broker_url: "redis://:password@localhost:6379/15"
|
||||
broker_url = (
|
||||
f"redis://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY}"
|
||||
)
|
||||
|
||||
result_backend = (
|
||||
f"redis://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY}"
|
||||
)
|
||||
|
||||
# NOTE: prefetch 4 is significantly faster than prefetch 1
|
||||
# however, prefetching is bad when tasks are lengthy as those tasks
|
||||
# can stall other tasks.
|
||||
worker_prefetch_multiplier = 4
|
||||
|
||||
broker_transport_options = {
|
||||
"priority_steps": list(range(len(DanswerCeleryPriority))),
|
||||
"sep": CELERY_SEPARATOR,
|
||||
"queue_order_strategy": "priority",
|
||||
}
|
||||
|
||||
task_default_priority = DanswerCeleryPriority.MEDIUM
|
||||
task_acks_late = True
|
||||
@@ -10,6 +10,8 @@ are multiple connector / credential pairs that have indexed it
|
||||
connector / credential pair from the access list
|
||||
(6) delete all relevant entries from postgres
|
||||
"""
|
||||
import time
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.access.access import get_access_for_documents
|
||||
@@ -22,8 +24,10 @@ from danswer.db.document import delete_documents_complete__no_commit
|
||||
from danswer.db.document import get_document_connector_cnts
|
||||
from danswer.db.document import get_documents_for_connector_credential_pair
|
||||
from danswer.db.document import prepare_to_modify_documents
|
||||
from danswer.db.document_set import delete_document_set_cc_pair_relationship__no_commit
|
||||
from danswer.db.document_set import fetch_document_sets_for_documents
|
||||
from danswer.db.document_set import get_document_sets_by_ids
|
||||
from danswer.db.document_set import (
|
||||
mark_cc_pair__document_set_relationships_to_be_deleted__no_commit,
|
||||
)
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.index_attempt import delete_index_attempts
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
@@ -31,17 +35,13 @@ from danswer.document_index.interfaces import DocumentIndex
|
||||
from danswer.document_index.interfaces import UpdateRequest
|
||||
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import (
|
||||
fetch_versioned_implementation_with_fallback,
|
||||
)
|
||||
from danswer.utils.variable_functionality import noop_fallback
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_DELETION_BATCH_SIZE = 1000
|
||||
|
||||
|
||||
def delete_connector_credential_pair_batch(
|
||||
def _delete_connector_credential_pair_batch(
|
||||
document_ids: list[str],
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
@@ -78,37 +78,25 @@ def delete_connector_credential_pair_batch(
|
||||
document_ids_to_update = [
|
||||
document_id for document_id, cnt in document_connector_cnts if cnt > 1
|
||||
]
|
||||
|
||||
# maps document id to list of document set names
|
||||
new_doc_sets_for_documents: dict[str, set[str]] = {
|
||||
document_id_and_document_set_names_tuple[0]: set(
|
||||
document_id_and_document_set_names_tuple[1]
|
||||
)
|
||||
for document_id_and_document_set_names_tuple in fetch_document_sets_for_documents(
|
||||
db_session=db_session,
|
||||
document_ids=document_ids_to_update,
|
||||
)
|
||||
}
|
||||
|
||||
# determine future ACLs for documents in batch
|
||||
access_for_documents = get_access_for_documents(
|
||||
document_ids=document_ids_to_update,
|
||||
db_session=db_session,
|
||||
cc_pair_to_delete=ConnectorCredentialPairIdentifier(
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
),
|
||||
)
|
||||
|
||||
# update Vespa
|
||||
logger.debug(f"Updating documents: {document_ids_to_update}")
|
||||
update_requests = [
|
||||
UpdateRequest(
|
||||
document_ids=[document_id],
|
||||
access=access,
|
||||
document_sets=new_doc_sets_for_documents[document_id],
|
||||
)
|
||||
for document_id, access in access_for_documents.items()
|
||||
]
|
||||
logger.debug(f"Updating documents: {document_ids_to_update}")
|
||||
|
||||
document_index.update(update_requests=update_requests)
|
||||
|
||||
# clean up Postgres
|
||||
delete_document_by_connector_credential_pair__no_commit(
|
||||
db_session=db_session,
|
||||
document_ids=document_ids_to_update,
|
||||
@@ -120,6 +108,48 @@ def delete_connector_credential_pair_batch(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def cleanup_synced_entities(
|
||||
cc_pair: ConnectorCredentialPair, db_session: Session
|
||||
) -> None:
|
||||
"""Updates the document sets associated with the connector / credential pair,
|
||||
then relies on the document set sync script to kick off Celery jobs which will
|
||||
sync these updates to Vespa.
|
||||
|
||||
Waits until the document sets are synced before returning."""
|
||||
logger.info(f"Cleaning up Document Sets for CC Pair with ID: '{cc_pair.id}'")
|
||||
document_sets_ids_to_sync = list(
|
||||
mark_cc_pair__document_set_relationships_to_be_deleted__no_commit(
|
||||
cc_pair_id=cc_pair.id,
|
||||
db_session=db_session,
|
||||
)
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
# wait till all document sets are synced before continuing
|
||||
while True:
|
||||
all_synced = True
|
||||
document_sets = get_document_sets_by_ids(
|
||||
db_session=db_session, document_set_ids=document_sets_ids_to_sync
|
||||
)
|
||||
for document_set in document_sets:
|
||||
if not document_set.is_up_to_date:
|
||||
all_synced = False
|
||||
|
||||
if all_synced:
|
||||
break
|
||||
|
||||
# wait for 30 seconds before checking again
|
||||
db_session.commit() # end transaction
|
||||
logger.info(
|
||||
f"Document sets '{document_sets_ids_to_sync}' not synced yet, waiting 30s"
|
||||
)
|
||||
time.sleep(30)
|
||||
|
||||
logger.info(
|
||||
f"Finished cleaning up Document Sets for CC Pair with ID: '{cc_pair.id}'"
|
||||
)
|
||||
|
||||
|
||||
def delete_connector_credential_pair(
|
||||
db_session: Session,
|
||||
document_index: DocumentIndex,
|
||||
@@ -139,7 +169,7 @@ def delete_connector_credential_pair(
|
||||
if not documents:
|
||||
break
|
||||
|
||||
delete_connector_credential_pair_batch(
|
||||
_delete_connector_credential_pair_batch(
|
||||
document_ids=[document.id for document in documents],
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
@@ -147,32 +177,17 @@ def delete_connector_credential_pair(
|
||||
)
|
||||
num_docs_deleted += len(documents)
|
||||
|
||||
# clean up the rest of the related Postgres entities
|
||||
# index attempts
|
||||
delete_index_attempts(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair.id,
|
||||
)
|
||||
# Clean up document sets / access information from Postgres
|
||||
# and sync these updates to Vespa
|
||||
# TODO: add user group cleanup with `fetch_versioned_implementation`
|
||||
cleanup_synced_entities(cc_pair, db_session)
|
||||
|
||||
# document sets
|
||||
delete_document_set_cc_pair_relationship__no_commit(
|
||||
# clean up the rest of the related Postgres entities
|
||||
delete_index_attempts(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
|
||||
# user groups
|
||||
cleanup_user_groups = fetch_versioned_implementation_with_fallback(
|
||||
"danswer.db.user_group",
|
||||
"delete_user_group_cc_pair_relationship__no_commit",
|
||||
noop_fallback,
|
||||
)
|
||||
cleanup_user_groups(
|
||||
cc_pair_id=cc_pair.id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# finally, delete the cc-pair
|
||||
delete_connector_credential_pair__no_commit(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
@@ -184,11 +199,11 @@ def delete_connector_credential_pair(
|
||||
connector_id=connector_id,
|
||||
)
|
||||
if not connector or not len(connector.credentials):
|
||||
logger.info("Found no credentials left for connector, deleting connector")
|
||||
logger.debug("Found no credentials left for connector, deleting connector")
|
||||
db_session.delete(connector)
|
||||
db_session.commit()
|
||||
|
||||
logger.notice(
|
||||
logger.info(
|
||||
"Successfully deleted connector_credential_pair with connector_id:"
|
||||
f" '{connector_id}' and credential_id: '{credential_id}'. Deleted {num_docs_deleted} docs."
|
||||
)
|
||||
|
||||
@@ -41,12 +41,6 @@ def _initializer(
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
def _run_in_process(
|
||||
func: Callable, args: list | tuple, kwargs: dict[str, Any] | None = None
|
||||
) -> None:
|
||||
_initializer(func, args, kwargs)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SimpleJob:
|
||||
"""Drop in replacement for `dask.distributed.Future`"""
|
||||
@@ -111,15 +105,13 @@ class SimpleJobClient:
|
||||
"""NOTE: `pure` arg is needed so this can be a drop in replacement for Dask"""
|
||||
self._cleanup_completed_jobs()
|
||||
if len(self.jobs) >= self.n_workers:
|
||||
logger.debug(
|
||||
f"No available workers to run job. Currently running '{len(self.jobs)}' jobs, with a limit of '{self.n_workers}'."
|
||||
)
|
||||
logger.debug("No available workers to run job")
|
||||
return None
|
||||
|
||||
job_id = self.job_id_counter
|
||||
self.job_id_counter += 1
|
||||
|
||||
process = Process(target=_run_in_process, args=(func, args), daemon=True)
|
||||
process = Process(target=_initializer(func=func, args=args), daemon=True)
|
||||
job = SimpleJob(id=job_id, process=process)
|
||||
process.start()
|
||||
|
||||
|
||||
@@ -6,22 +6,27 @@ from datetime import timezone
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.connector_deletion import (
|
||||
_delete_connector_credential_pair_batch,
|
||||
)
|
||||
from danswer.background.indexing.checkpointing import get_time_windows_for_index_attempt
|
||||
from danswer.background.indexing.tracer import DanswerTracer
|
||||
from danswer.configs.app_configs import INDEXING_SIZE_WARNING_THRESHOLD
|
||||
from danswer.configs.app_configs import INDEXING_TRACER_INTERVAL
|
||||
from danswer.configs.app_configs import DISABLE_DOCUMENT_CLEANUP
|
||||
from danswer.configs.app_configs import POLL_CONNECTOR_OFFSET
|
||||
from danswer.connectors.connector_runner import ConnectorRunner
|
||||
from danswer.connectors.factory import instantiate_connector
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.models import IndexAttemptMetadata
|
||||
from danswer.connectors.models import InputType
|
||||
from danswer.db.connector import disable_connector
|
||||
from danswer.db.connector_credential_pair import get_last_successful_attempt_time
|
||||
from danswer.db.connector_credential_pair import update_connector_credential_pair
|
||||
from danswer.db.credentials import backend_update_credential_json
|
||||
from danswer.db.document import get_documents_for_connector_credential_pair
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.index_attempt import get_index_attempt
|
||||
from danswer.db.index_attempt import mark_attempt_failed
|
||||
from danswer.db.index_attempt import mark_attempt_in_progress
|
||||
from danswer.db.index_attempt import mark_attempt_partially_succeeded
|
||||
from danswer.db.index_attempt import mark_attempt_in_progress__no_commit
|
||||
from danswer.db.index_attempt import mark_attempt_succeeded
|
||||
from danswer.db.index_attempt import update_docs_indexed
|
||||
from danswer.db.models import IndexAttempt
|
||||
@@ -32,19 +37,16 @@ from danswer.indexing.embedder import DefaultIndexingEmbedder
|
||||
from danswer.indexing.indexing_pipeline import build_indexing_pipeline
|
||||
from danswer.utils.logger import IndexAttemptSingleton
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import global_version
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
INDEXING_TRACER_NUM_PRINT_ENTRIES = 5
|
||||
|
||||
|
||||
def _get_connector_runner(
|
||||
def _get_document_generator(
|
||||
db_session: Session,
|
||||
attempt: IndexAttempt,
|
||||
start_time: datetime,
|
||||
end_time: datetime,
|
||||
) -> ConnectorRunner:
|
||||
) -> tuple[GenerateDocumentsOutput, bool]:
|
||||
"""
|
||||
NOTE: `start_time` and `end_time` are only used for poll connectors
|
||||
|
||||
@@ -52,31 +54,47 @@ def _get_connector_runner(
|
||||
are the complete list of existing documents of the connector. If the task
|
||||
of type LOAD_STATE, the list will be considered complete and otherwise incomplete.
|
||||
"""
|
||||
task = attempt.connector_credential_pair.connector.input_type
|
||||
task = attempt.connector.input_type
|
||||
|
||||
try:
|
||||
runnable_connector = instantiate_connector(
|
||||
attempt.connector_credential_pair.connector.source,
|
||||
runnable_connector, new_credential_json = instantiate_connector(
|
||||
attempt.connector.source,
|
||||
task,
|
||||
attempt.connector_credential_pair.connector.connector_specific_config,
|
||||
attempt.connector_credential_pair.credential,
|
||||
db_session,
|
||||
attempt.connector.connector_specific_config,
|
||||
attempt.credential.credential_json,
|
||||
)
|
||||
if new_credential_json is not None:
|
||||
backend_update_credential_json(
|
||||
attempt.credential, new_credential_json, db_session
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Unable to instantiate connector due to {e}")
|
||||
# since we failed to even instantiate the connector, we pause the CCPair since
|
||||
# it will never succeed
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=attempt.connector_credential_pair.connector.id,
|
||||
credential_id=attempt.connector_credential_pair.credential.id,
|
||||
status=ConnectorCredentialPairStatus.PAUSED,
|
||||
)
|
||||
disable_connector(attempt.connector.id, db_session)
|
||||
raise e
|
||||
|
||||
return ConnectorRunner(
|
||||
connector=runnable_connector, time_range=(start_time, end_time)
|
||||
)
|
||||
if task == InputType.LOAD_STATE:
|
||||
assert isinstance(runnable_connector, LoadConnector)
|
||||
doc_batch_generator = runnable_connector.load_from_state()
|
||||
is_listing_complete = True
|
||||
elif task == InputType.POLL:
|
||||
assert isinstance(runnable_connector, PollConnector)
|
||||
if attempt.connector_id is None or attempt.credential_id is None:
|
||||
raise ValueError(
|
||||
f"Polling attempt {attempt.id} is missing connector_id or credential_id, "
|
||||
f"can't fetch time range."
|
||||
)
|
||||
|
||||
logger.info(f"Polling for updates between {start_time} and {end_time}")
|
||||
doc_batch_generator = runnable_connector.poll_source(
|
||||
start=start_time.timestamp(), end=end_time.timestamp()
|
||||
)
|
||||
is_listing_complete = False
|
||||
|
||||
else:
|
||||
# Event types cannot be handled by a background type
|
||||
raise RuntimeError(f"Invalid task type: {task}")
|
||||
|
||||
return doc_batch_generator, is_listing_complete
|
||||
|
||||
|
||||
def _run_indexing(
|
||||
@@ -90,62 +108,46 @@ def _run_indexing(
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
search_settings = index_attempt.search_settings
|
||||
index_name = search_settings.index_name
|
||||
db_embedding_model = index_attempt.embedding_model
|
||||
index_name = db_embedding_model.index_name
|
||||
|
||||
# Only update cc-pair status for primary index jobs
|
||||
# Secondary index syncs at the end when swapping
|
||||
is_primary = search_settings.status == IndexModelStatus.PRESENT
|
||||
is_primary = index_attempt.embedding_model.status == IndexModelStatus.PRESENT
|
||||
|
||||
# Indexing is only done into one index at a time
|
||||
document_index = get_default_document_index(
|
||||
primary_index_name=index_name, secondary_index_name=None
|
||||
)
|
||||
|
||||
embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
|
||||
search_settings=search_settings
|
||||
embedding_model = DefaultIndexingEmbedder(
|
||||
model_name=db_embedding_model.model_name,
|
||||
normalize=db_embedding_model.normalize,
|
||||
query_prefix=db_embedding_model.query_prefix,
|
||||
passage_prefix=db_embedding_model.passage_prefix,
|
||||
)
|
||||
|
||||
indexing_pipeline = build_indexing_pipeline(
|
||||
attempt_id=index_attempt.id,
|
||||
embedder=embedding_model,
|
||||
document_index=document_index,
|
||||
ignore_time_skip=index_attempt.from_beginning
|
||||
or (search_settings.status == IndexModelStatus.FUTURE),
|
||||
or (db_embedding_model.status == IndexModelStatus.FUTURE),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
db_cc_pair = index_attempt.connector_credential_pair
|
||||
db_connector = index_attempt.connector_credential_pair.connector
|
||||
db_credential = index_attempt.connector_credential_pair.credential
|
||||
earliest_index_time = (
|
||||
db_connector.indexing_start.timestamp() if db_connector.indexing_start else 0
|
||||
)
|
||||
|
||||
db_connector = index_attempt.connector
|
||||
db_credential = index_attempt.credential
|
||||
last_successful_index_time = (
|
||||
earliest_index_time
|
||||
0.0
|
||||
if index_attempt.from_beginning
|
||||
else get_last_successful_attempt_time(
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
earliest_index=earliest_index_time,
|
||||
search_settings=index_attempt.search_settings,
|
||||
embedding_model=index_attempt.embedding_model,
|
||||
db_session=db_session,
|
||||
)
|
||||
)
|
||||
|
||||
if INDEXING_TRACER_INTERVAL > 0:
|
||||
logger.debug(f"Memory tracer starting: interval={INDEXING_TRACER_INTERVAL}")
|
||||
tracer = DanswerTracer()
|
||||
tracer.start()
|
||||
tracer.snap()
|
||||
|
||||
index_attempt_md = IndexAttemptMetadata(
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
)
|
||||
|
||||
batch_num = 0
|
||||
net_doc_change = 0
|
||||
document_count = 0
|
||||
chunk_count = 0
|
||||
@@ -158,37 +160,29 @@ def _run_indexing(
|
||||
source_type=db_connector.source,
|
||||
)
|
||||
):
|
||||
window_start = max(
|
||||
window_start - timedelta(minutes=POLL_CONNECTOR_OFFSET),
|
||||
datetime(1970, 1, 1, tzinfo=timezone.utc),
|
||||
)
|
||||
|
||||
doc_batch_generator, is_listing_complete = _get_document_generator(
|
||||
db_session=db_session,
|
||||
attempt=index_attempt,
|
||||
start_time=window_start,
|
||||
end_time=window_end,
|
||||
)
|
||||
|
||||
try:
|
||||
window_start = max(
|
||||
window_start - timedelta(minutes=POLL_CONNECTOR_OFFSET),
|
||||
datetime(1970, 1, 1, tzinfo=timezone.utc),
|
||||
)
|
||||
|
||||
connector_runner = _get_connector_runner(
|
||||
db_session=db_session,
|
||||
attempt=index_attempt,
|
||||
start_time=window_start,
|
||||
end_time=window_end,
|
||||
)
|
||||
|
||||
all_connector_doc_ids: set[str] = set()
|
||||
|
||||
tracer_counter = 0
|
||||
if INDEXING_TRACER_INTERVAL > 0:
|
||||
tracer.snap()
|
||||
for doc_batch in connector_runner.run():
|
||||
for doc_batch in doc_batch_generator:
|
||||
# Check if connector is disabled mid run and stop if so unless it's the secondary
|
||||
# index being built. We want to populate it even for paused connectors
|
||||
# Often paused connectors are sources that aren't updated frequently but the
|
||||
# contents still need to be initially pulled.
|
||||
db_session.refresh(db_connector)
|
||||
if (
|
||||
(
|
||||
db_cc_pair.status == ConnectorCredentialPairStatus.PAUSED
|
||||
and search_settings.status != IndexModelStatus.FUTURE
|
||||
)
|
||||
# if it's deleting, we don't care if this is a secondary index
|
||||
or db_cc_pair.status == ConnectorCredentialPairStatus.DELETING
|
||||
db_connector.disabled
|
||||
and db_embedding_model.status != IndexModelStatus.FUTURE
|
||||
):
|
||||
# let the `except` block handle this
|
||||
raise RuntimeError("Connector was disabled mid run")
|
||||
@@ -198,30 +192,17 @@ def _run_indexing(
|
||||
# Likely due to user manually disabling it or model swap
|
||||
raise RuntimeError("Index Attempt was canceled")
|
||||
|
||||
batch_description = []
|
||||
for doc in doc_batch:
|
||||
batch_description.append(doc.to_short_descriptor())
|
||||
|
||||
doc_size = 0
|
||||
for section in doc.sections:
|
||||
doc_size += len(section.text)
|
||||
|
||||
if doc_size > INDEXING_SIZE_WARNING_THRESHOLD:
|
||||
logger.warning(
|
||||
f"Document size: doc='{doc.to_short_descriptor()}' "
|
||||
f"size={doc_size} "
|
||||
f"threshold={INDEXING_SIZE_WARNING_THRESHOLD}"
|
||||
)
|
||||
|
||||
logger.debug(f"Indexing batch of documents: {batch_description}")
|
||||
|
||||
index_attempt_md.batch_num = batch_num + 1 # use 1-index for this
|
||||
new_docs, total_batch_chunks = indexing_pipeline(
|
||||
document_batch=doc_batch,
|
||||
index_attempt_metadata=index_attempt_md,
|
||||
logger.debug(
|
||||
f"Indexing batch of documents: {[doc.to_short_descriptor() for doc in doc_batch]}"
|
||||
)
|
||||
|
||||
batch_num += 1
|
||||
new_docs, total_batch_chunks = indexing_pipeline(
|
||||
documents=doc_batch,
|
||||
index_attempt_metadata=IndexAttemptMetadata(
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
),
|
||||
)
|
||||
net_doc_change += new_docs
|
||||
chunk_count += total_batch_chunks
|
||||
document_count += len(doc_batch)
|
||||
@@ -243,16 +224,38 @@ def _run_indexing(
|
||||
docs_removed_from_index=0,
|
||||
)
|
||||
|
||||
tracer_counter += 1
|
||||
if (
|
||||
INDEXING_TRACER_INTERVAL > 0
|
||||
and tracer_counter % INDEXING_TRACER_INTERVAL == 0
|
||||
):
|
||||
logger.debug(
|
||||
f"Running trace comparison for batch {tracer_counter}. interval={INDEXING_TRACER_INTERVAL}"
|
||||
if is_listing_complete and not DISABLE_DOCUMENT_CLEANUP:
|
||||
# clean up all documents from the index that have not been returned from the connector
|
||||
all_indexed_document_ids = {
|
||||
d.id
|
||||
for d in get_documents_for_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
)
|
||||
tracer.snap()
|
||||
tracer.log_previous_diff(INDEXING_TRACER_NUM_PRINT_ENTRIES)
|
||||
}
|
||||
doc_ids_to_remove = list(
|
||||
all_indexed_document_ids - all_connector_doc_ids
|
||||
)
|
||||
logger.debug(
|
||||
f"Cleaning up {len(doc_ids_to_remove)} documents that are not contained in the newest connector state"
|
||||
)
|
||||
|
||||
# delete docs from cc-pair and receive the number of completely deleted docs in return
|
||||
_delete_connector_credential_pair_batch(
|
||||
document_ids=doc_ids_to_remove,
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
document_index=document_index,
|
||||
)
|
||||
|
||||
update_docs_indexed(
|
||||
db_session=db_session,
|
||||
index_attempt=index_attempt,
|
||||
total_docs_indexed=document_count,
|
||||
new_docs_indexed=net_doc_change,
|
||||
docs_removed_from_index=len(doc_ids_to_remove),
|
||||
)
|
||||
|
||||
run_end_dt = window_end
|
||||
if is_primary:
|
||||
@@ -260,11 +263,12 @@ def _run_indexing(
|
||||
db_session=db_session,
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
attempt_status=IndexingStatus.IN_PROGRESS,
|
||||
net_docs=net_doc_change,
|
||||
run_dt=run_end_dt,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
logger.info(
|
||||
f"Connector run ran into exception after elapsed time: {time.time() - start_time} seconds"
|
||||
)
|
||||
# Only mark the attempt as a complete failure if this is the first indexing window.
|
||||
@@ -276,7 +280,7 @@ def _run_indexing(
|
||||
# to give better clarity in the UI, as the next run will never happen.
|
||||
if (
|
||||
ind == 0
|
||||
or not db_cc_pair.status.is_active()
|
||||
or db_connector.disabled
|
||||
or index_attempt.status != IndexingStatus.IN_PROGRESS
|
||||
):
|
||||
mark_attempt_failed(
|
||||
@@ -288,74 +292,34 @@ def _run_indexing(
|
||||
if is_primary:
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
connector_id=index_attempt.connector.id,
|
||||
credential_id=index_attempt.credential.id,
|
||||
attempt_status=IndexingStatus.FAILED,
|
||||
net_docs=net_doc_change,
|
||||
)
|
||||
|
||||
if INDEXING_TRACER_INTERVAL > 0:
|
||||
tracer.stop()
|
||||
raise e
|
||||
|
||||
# break => similar to success case. As mentioned above, if the next run fails for the same
|
||||
# reason it will then be marked as a failure
|
||||
break
|
||||
|
||||
if INDEXING_TRACER_INTERVAL > 0:
|
||||
logger.debug(
|
||||
f"Running trace comparison between start and end of indexing. {tracer_counter} batches processed."
|
||||
)
|
||||
tracer.snap()
|
||||
tracer.log_first_diff(INDEXING_TRACER_NUM_PRINT_ENTRIES)
|
||||
tracer.stop()
|
||||
logger.debug("Memory tracer stopped.")
|
||||
|
||||
if (
|
||||
index_attempt_md.num_exceptions > 0
|
||||
and index_attempt_md.num_exceptions >= batch_num
|
||||
):
|
||||
mark_attempt_failed(
|
||||
index_attempt,
|
||||
db_session,
|
||||
failure_reason="All batches exceptioned.",
|
||||
)
|
||||
if is_primary:
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=index_attempt.connector_credential_pair.connector.id,
|
||||
credential_id=index_attempt.connector_credential_pair.credential.id,
|
||||
)
|
||||
raise Exception(
|
||||
f"Connector failed - All batches exceptioned: batches={batch_num}"
|
||||
)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
if index_attempt_md.num_exceptions == 0:
|
||||
mark_attempt_succeeded(index_attempt, db_session)
|
||||
logger.info(
|
||||
f"Connector succeeded: "
|
||||
f"docs={document_count} chunks={chunk_count} elapsed={elapsed_time:.2f}s"
|
||||
)
|
||||
else:
|
||||
mark_attempt_partially_succeeded(index_attempt, db_session)
|
||||
logger.info(
|
||||
f"Connector completed with some errors: "
|
||||
f"exceptions={index_attempt_md.num_exceptions} "
|
||||
f"batches={batch_num} "
|
||||
f"docs={document_count} "
|
||||
f"chunks={chunk_count} "
|
||||
f"elapsed={elapsed_time:.2f}s"
|
||||
)
|
||||
|
||||
mark_attempt_succeeded(index_attempt, db_session)
|
||||
if is_primary:
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
attempt_status=IndexingStatus.SUCCESS,
|
||||
run_dt=run_end_dt,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Indexed or refreshed {document_count} total documents for a total of {chunk_count} indexed chunks"
|
||||
)
|
||||
logger.info(
|
||||
f"Connector successfully finished, elapsed time: {time.time() - start_time} seconds"
|
||||
)
|
||||
|
||||
|
||||
def _prepare_index_attempt(db_session: Session, index_attempt_id: int) -> IndexAttempt:
|
||||
# make sure that the index attempt can't change in between checking the
|
||||
@@ -368,7 +332,6 @@ def _prepare_index_attempt(db_session: Session, index_attempt_id: int) -> IndexA
|
||||
db_session=db_session,
|
||||
index_attempt_id=index_attempt_id,
|
||||
)
|
||||
|
||||
if attempt is None:
|
||||
raise RuntimeError(f"Unable to find IndexAttempt for ID '{index_attempt_id}'")
|
||||
|
||||
@@ -379,27 +342,29 @@ def _prepare_index_attempt(db_session: Session, index_attempt_id: int) -> IndexA
|
||||
)
|
||||
|
||||
# only commit once, to make sure this all happens in a single transaction
|
||||
mark_attempt_in_progress(attempt, db_session)
|
||||
mark_attempt_in_progress__no_commit(attempt)
|
||||
is_primary = attempt.embedding_model.status == IndexModelStatus.PRESENT
|
||||
if is_primary:
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=attempt.connector.id,
|
||||
credential_id=attempt.credential.id,
|
||||
attempt_status=IndexingStatus.IN_PROGRESS,
|
||||
)
|
||||
else:
|
||||
db_session.commit()
|
||||
|
||||
return attempt
|
||||
|
||||
|
||||
def run_indexing_entrypoint(
|
||||
index_attempt_id: int, connector_credential_pair_id: int, is_ee: bool = False
|
||||
) -> None:
|
||||
def run_indexing_entrypoint(index_attempt_id: int) -> None:
|
||||
"""Entrypoint for indexing run when using dask distributed.
|
||||
Wraps the actual logic in a `try` block so that we can catch any exceptions
|
||||
and mark the attempt as failed."""
|
||||
|
||||
try:
|
||||
if is_ee:
|
||||
global_version.set_ee()
|
||||
|
||||
# set the indexing attempt ID so that all log messages from this process
|
||||
# will have it added as a prefix
|
||||
IndexAttemptSingleton.set_cc_and_index_id(
|
||||
index_attempt_id, connector_credential_pair_id
|
||||
)
|
||||
IndexAttemptSingleton.set_index_attempt_id(index_attempt_id)
|
||||
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
# make sure that it is valid to run this indexing attempt + mark it
|
||||
@@ -407,19 +372,17 @@ def run_indexing_entrypoint(
|
||||
attempt = _prepare_index_attempt(db_session, index_attempt_id)
|
||||
|
||||
logger.info(
|
||||
f"Indexing starting: "
|
||||
f"connector='{attempt.connector_credential_pair.connector.name}' "
|
||||
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
|
||||
f"credentials='{attempt.connector_credential_pair.connector_id}'"
|
||||
f"Running indexing attempt for connector: '{attempt.connector.name}', "
|
||||
f"with config: '{attempt.connector.connector_specific_config}', and "
|
||||
f"with credentials: '{attempt.credential_id}'"
|
||||
)
|
||||
|
||||
_run_indexing(db_session, attempt)
|
||||
|
||||
logger.info(
|
||||
f"Indexing finished: "
|
||||
f"connector='{attempt.connector_credential_pair.connector.name}' "
|
||||
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
|
||||
f"credentials='{attempt.connector_credential_pair.connector_id}'"
|
||||
f"Completed indexing attempt for connector: '{attempt.connector.name}', "
|
||||
f"with config: '{attempt.connector.connector_specific_config}', and "
|
||||
f"with credentials: '{attempt.credential_id}'"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Indexing job with ID '{index_attempt_id}' failed due to {e}")
|
||||
|
||||
@@ -1,77 +0,0 @@
|
||||
import tracemalloc
|
||||
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
DANSWER_TRACEMALLOC_FRAMES = 10
|
||||
|
||||
|
||||
class DanswerTracer:
|
||||
def __init__(self) -> None:
|
||||
self.snapshot_first: tracemalloc.Snapshot | None = None
|
||||
self.snapshot_prev: tracemalloc.Snapshot | None = None
|
||||
self.snapshot: tracemalloc.Snapshot | None = None
|
||||
|
||||
def start(self) -> None:
|
||||
tracemalloc.start(DANSWER_TRACEMALLOC_FRAMES)
|
||||
|
||||
def stop(self) -> None:
|
||||
tracemalloc.stop()
|
||||
|
||||
def snap(self) -> None:
|
||||
snapshot = tracemalloc.take_snapshot()
|
||||
# Filter out irrelevant frames (e.g., from tracemalloc itself or importlib)
|
||||
snapshot = snapshot.filter_traces(
|
||||
(
|
||||
tracemalloc.Filter(False, tracemalloc.__file__), # Exclude tracemalloc
|
||||
tracemalloc.Filter(
|
||||
False, "<frozen importlib._bootstrap>"
|
||||
), # Exclude importlib
|
||||
tracemalloc.Filter(
|
||||
False, "<frozen importlib._bootstrap_external>"
|
||||
), # Exclude external importlib
|
||||
)
|
||||
)
|
||||
|
||||
if not self.snapshot_first:
|
||||
self.snapshot_first = snapshot
|
||||
|
||||
if self.snapshot:
|
||||
self.snapshot_prev = self.snapshot
|
||||
|
||||
self.snapshot = snapshot
|
||||
|
||||
def log_snapshot(self, numEntries: int) -> None:
|
||||
if not self.snapshot:
|
||||
return
|
||||
|
||||
stats = self.snapshot.statistics("traceback")
|
||||
for s in stats[:numEntries]:
|
||||
logger.debug(f"Tracer snap: {s}")
|
||||
for line in s.traceback:
|
||||
logger.debug(f"* {line}")
|
||||
|
||||
@staticmethod
|
||||
def log_diff(
|
||||
snap_current: tracemalloc.Snapshot,
|
||||
snap_previous: tracemalloc.Snapshot,
|
||||
numEntries: int,
|
||||
) -> None:
|
||||
stats = snap_current.compare_to(snap_previous, "traceback")
|
||||
for s in stats[:numEntries]:
|
||||
logger.debug(f"Tracer diff: {s}")
|
||||
for line in s.traceback.format():
|
||||
logger.debug(f"* {line}")
|
||||
|
||||
def log_previous_diff(self, numEntries: int) -> None:
|
||||
if not self.snapshot or not self.snapshot_prev:
|
||||
return
|
||||
|
||||
DanswerTracer.log_diff(self.snapshot, self.snapshot_prev, numEntries)
|
||||
|
||||
def log_first_diff(self, numEntries: int) -> None:
|
||||
if not self.snapshot or not self.snapshot_first:
|
||||
return
|
||||
|
||||
DanswerTracer.log_diff(self.snapshot, self.snapshot_first, numEntries)
|
||||
@@ -22,15 +22,6 @@ def name_document_set_sync_task(document_set_id: int) -> str:
|
||||
return f"sync_doc_set_{document_set_id}"
|
||||
|
||||
|
||||
def name_cc_prune_task(
|
||||
connector_id: int | None = None, credential_id: int | None = None
|
||||
) -> str:
|
||||
task_name = f"prune_connector_credential_pair_{connector_id}_{credential_id}"
|
||||
if not connector_id or not credential_id:
|
||||
task_name = "prune_connector_credential_pair"
|
||||
return task_name
|
||||
|
||||
|
||||
T = TypeVar("T", bound=Callable)
|
||||
|
||||
|
||||
@@ -93,16 +84,9 @@ def build_apply_async_wrapper(build_name_fn: Callable[..., str]) -> Callable[[AA
|
||||
kwargs_for_build_name = kwargs or {}
|
||||
task_name = build_name_fn(*args_for_build_name, **kwargs_for_build_name)
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
# register_task must come before fn = apply_async or else the task
|
||||
# might run mark_task_start (and crash) before the task row exists
|
||||
db_task = register_task(task_name, db_session)
|
||||
|
||||
# mark the task as started
|
||||
task = fn(args, kwargs, *other_args, **other_kwargs)
|
||||
|
||||
# we update the celery task id for diagnostic purposes
|
||||
# but it isn't currently used by any code
|
||||
db_task.task_id = task.id
|
||||
db_session.commit()
|
||||
register_task(task.id, task_name, db_session)
|
||||
|
||||
return task
|
||||
|
||||
|
||||
@@ -16,33 +16,27 @@ from danswer.configs.app_configs import CLEANUP_INDEXING_JOBS_TIMEOUT
|
||||
from danswer.configs.app_configs import DASK_JOB_CLIENT_ENABLED
|
||||
from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
|
||||
from danswer.configs.app_configs import NUM_INDEXING_WORKERS
|
||||
from danswer.configs.app_configs import NUM_SECONDARY_INDEXING_WORKERS
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import POSTGRES_INDEXER_APP_NAME
|
||||
from danswer.db.connector import fetch_connectors
|
||||
from danswer.db.connector_credential_pair import fetch_connector_credential_pairs
|
||||
from danswer.db.connector_credential_pair import mark_all_in_progress_cc_pairs_failed
|
||||
from danswer.db.connector_credential_pair import update_connector_credential_pair
|
||||
from danswer.db.embedding_model import get_current_db_embedding_model
|
||||
from danswer.db.embedding_model import get_secondary_db_embedding_model
|
||||
from danswer.db.engine import get_db_current_time
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.engine import init_sqlalchemy_engine
|
||||
from danswer.db.index_attempt import create_index_attempt
|
||||
from danswer.db.index_attempt import get_index_attempt
|
||||
from danswer.db.index_attempt import get_inprogress_index_attempts
|
||||
from danswer.db.index_attempt import get_last_attempt_for_cc_pair
|
||||
from danswer.db.index_attempt import get_last_attempt
|
||||
from danswer.db.index_attempt import get_not_started_index_attempts
|
||||
from danswer.db.index_attempt import mark_attempt_failed
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.models import Connector
|
||||
from danswer.db.models import EmbeddingModel
|
||||
from danswer.db.models import IndexAttempt
|
||||
from danswer.db.models import IndexingStatus
|
||||
from danswer.db.models import IndexModelStatus
|
||||
from danswer.db.models import SearchSettings
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.db.search_settings import get_secondary_search_settings
|
||||
from danswer.db.swap_index import check_index_swap
|
||||
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
|
||||
from danswer.search.search_nlp_models import warm_up_encoders
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import global_version
|
||||
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
|
||||
from shared_configs.configs import LOG_LEVEL
|
||||
from shared_configs.configs import MODEL_SERVER_PORT
|
||||
@@ -59,68 +53,41 @@ _UNEXPECTED_STATE_FAILURE_REASON = (
|
||||
|
||||
|
||||
def _should_create_new_indexing(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
connector: Connector,
|
||||
last_index: IndexAttempt | None,
|
||||
search_settings_instance: SearchSettings,
|
||||
model: EmbeddingModel,
|
||||
secondary_index_building: bool,
|
||||
db_session: Session,
|
||||
) -> bool:
|
||||
connector = cc_pair.connector
|
||||
|
||||
# don't kick off indexing for `NOT_APPLICABLE` sources
|
||||
if connector.source == DocumentSource.NOT_APPLICABLE:
|
||||
return False
|
||||
|
||||
# User can still manually create single indexing attempts via the UI for the
|
||||
# currently in use index
|
||||
if DISABLE_INDEX_UPDATE_ON_SWAP:
|
||||
if (
|
||||
search_settings_instance.status == IndexModelStatus.PRESENT
|
||||
and secondary_index_building
|
||||
):
|
||||
if model.status == IndexModelStatus.PRESENT and secondary_index_building:
|
||||
return False
|
||||
|
||||
# When switching over models, always index at least once
|
||||
if search_settings_instance.status == IndexModelStatus.FUTURE:
|
||||
if last_index:
|
||||
# No new index if the last index attempt succeeded
|
||||
# Once is enough. The model will never be able to swap otherwise.
|
||||
if last_index.status == IndexingStatus.SUCCESS:
|
||||
return False
|
||||
|
||||
# No new index if the last index attempt is waiting to start
|
||||
if last_index.status == IndexingStatus.NOT_STARTED:
|
||||
return False
|
||||
|
||||
# No new index if the last index attempt is running
|
||||
if last_index.status == IndexingStatus.IN_PROGRESS:
|
||||
return False
|
||||
else:
|
||||
if connector.id == 0: # Ingestion API
|
||||
return False
|
||||
if model.status == IndexModelStatus.FUTURE and not last_index:
|
||||
if connector.id == 0: # Ingestion API
|
||||
return False
|
||||
return True
|
||||
|
||||
# If the connector is paused or is the ingestion API, don't index
|
||||
# NOTE: during an embedding model switch over, the following logic
|
||||
# is bypassed by the above check for a future model
|
||||
if not cc_pair.status.is_active() or connector.id == 0:
|
||||
# If the connector is disabled, don't index
|
||||
# NOTE: during an embedding model switch over, we ignore this
|
||||
# and index the disabled connectors as well (which is why this if
|
||||
# statement is below the first condition above)
|
||||
if connector.disabled:
|
||||
return False
|
||||
|
||||
if not last_index:
|
||||
return True
|
||||
|
||||
if connector.refresh_freq is None:
|
||||
return False
|
||||
if not last_index:
|
||||
return True
|
||||
|
||||
# Only one scheduled/ongoing job per connector at a time
|
||||
# this prevents cases where
|
||||
# (1) the "latest" index_attempt is scheduled so we show
|
||||
# that in the UI despite another index_attempt being in-progress
|
||||
# (2) multiple scheduled index_attempts at a time
|
||||
if (
|
||||
last_index.status == IndexingStatus.NOT_STARTED
|
||||
or last_index.status == IndexingStatus.IN_PROGRESS
|
||||
):
|
||||
# Only one scheduled job per connector at a time
|
||||
# Can schedule another one if the current one is already running however
|
||||
# Because the currently running one will not be until the latest time
|
||||
# Note, this last index is for the given embedding model
|
||||
if last_index.status == IndexingStatus.NOT_STARTED:
|
||||
return False
|
||||
|
||||
current_db_time = get_db_current_time(db_session)
|
||||
@@ -128,20 +95,41 @@ def _should_create_new_indexing(
|
||||
return time_since_index.total_seconds() >= connector.refresh_freq
|
||||
|
||||
|
||||
def _is_indexing_job_marked_as_finished(index_attempt: IndexAttempt | None) -> bool:
|
||||
if index_attempt is None:
|
||||
return False
|
||||
|
||||
return (
|
||||
index_attempt.status == IndexingStatus.FAILED
|
||||
or index_attempt.status == IndexingStatus.SUCCESS
|
||||
)
|
||||
|
||||
|
||||
def _mark_run_failed(
|
||||
db_session: Session, index_attempt: IndexAttempt, failure_reason: str
|
||||
) -> None:
|
||||
"""Marks the `index_attempt` row as failed + updates the `
|
||||
connector_credential_pair` to reflect that the run failed"""
|
||||
logger.warning(
|
||||
f"Marking in-progress attempt 'connector: {index_attempt.connector_credential_pair.connector_id}, "
|
||||
f"credential: {index_attempt.connector_credential_pair.credential_id}' as failed due to {failure_reason}"
|
||||
f"Marking in-progress attempt 'connector: {index_attempt.connector_id}, "
|
||||
f"credential: {index_attempt.credential_id}' as failed due to {failure_reason}"
|
||||
)
|
||||
mark_attempt_failed(
|
||||
index_attempt=index_attempt,
|
||||
db_session=db_session,
|
||||
failure_reason=failure_reason,
|
||||
)
|
||||
if (
|
||||
index_attempt.connector_id is not None
|
||||
and index_attempt.credential_id is not None
|
||||
and index_attempt.embedding_model.status == IndexModelStatus.PRESENT
|
||||
):
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=index_attempt.connector_id,
|
||||
credential_id=index_attempt.credential_id,
|
||||
attempt_status=IndexingStatus.FAILED,
|
||||
)
|
||||
|
||||
|
||||
"""Main funcs"""
|
||||
@@ -154,7 +142,7 @@ def create_indexing_jobs(existing_jobs: dict[int, Future | SimpleJob]) -> None:
|
||||
3. There is not already an ongoing indexing attempt for this pair
|
||||
"""
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
ongoing: set[tuple[int | None, int]] = set()
|
||||
ongoing: set[tuple[int | None, int | None, int]] = set()
|
||||
for attempt_id in existing_jobs:
|
||||
attempt = get_index_attempt(
|
||||
db_session=db_session, index_attempt_id=attempt_id
|
||||
@@ -167,43 +155,52 @@ def create_indexing_jobs(existing_jobs: dict[int, Future | SimpleJob]) -> None:
|
||||
continue
|
||||
ongoing.add(
|
||||
(
|
||||
attempt.connector_credential_pair_id,
|
||||
attempt.search_settings_id,
|
||||
attempt.connector_id,
|
||||
attempt.credential_id,
|
||||
attempt.embedding_model_id,
|
||||
)
|
||||
)
|
||||
|
||||
# Get the primary search settings
|
||||
primary_search_settings = get_current_search_settings(db_session)
|
||||
search_settings = [primary_search_settings]
|
||||
embedding_models = [get_current_db_embedding_model(db_session)]
|
||||
secondary_embedding_model = get_secondary_db_embedding_model(db_session)
|
||||
if secondary_embedding_model is not None:
|
||||
embedding_models.append(secondary_embedding_model)
|
||||
|
||||
# Check for secondary search settings
|
||||
secondary_search_settings = get_secondary_search_settings(db_session)
|
||||
if secondary_search_settings is not None:
|
||||
# If secondary settings exist, add them to the list
|
||||
search_settings.append(secondary_search_settings)
|
||||
all_connectors = fetch_connectors(db_session)
|
||||
for connector in all_connectors:
|
||||
for association in connector.credentials:
|
||||
for model in embedding_models:
|
||||
credential = association.credential
|
||||
|
||||
all_connector_credential_pairs = fetch_connector_credential_pairs(db_session)
|
||||
for cc_pair in all_connector_credential_pairs:
|
||||
for search_settings_instance in search_settings:
|
||||
# Check if there is an ongoing indexing attempt for this connector credential pair
|
||||
if (cc_pair.id, search_settings_instance.id) in ongoing:
|
||||
continue
|
||||
# Check if there is an ongoing indexing attempt for this connector + credential pair
|
||||
if (connector.id, credential.id, model.id) in ongoing:
|
||||
continue
|
||||
|
||||
last_attempt = get_last_attempt_for_cc_pair(
|
||||
cc_pair.id, search_settings_instance.id, db_session
|
||||
)
|
||||
if not _should_create_new_indexing(
|
||||
cc_pair=cc_pair,
|
||||
last_index=last_attempt,
|
||||
search_settings_instance=search_settings_instance,
|
||||
secondary_index_building=len(search_settings) > 1,
|
||||
db_session=db_session,
|
||||
):
|
||||
continue
|
||||
last_attempt = get_last_attempt(
|
||||
connector.id, credential.id, model.id, db_session
|
||||
)
|
||||
if not _should_create_new_indexing(
|
||||
connector=connector,
|
||||
last_index=last_attempt,
|
||||
model=model,
|
||||
secondary_index_building=len(embedding_models) > 1,
|
||||
db_session=db_session,
|
||||
):
|
||||
continue
|
||||
|
||||
create_index_attempt(
|
||||
cc_pair.id, search_settings_instance.id, db_session
|
||||
)
|
||||
create_index_attempt(
|
||||
connector.id, credential.id, model.id, db_session
|
||||
)
|
||||
|
||||
# CC-Pair will have the status that it should for the primary index
|
||||
# Will be re-sync-ed once the indices are swapped
|
||||
if model.status == IndexModelStatus.PRESENT:
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=connector.id,
|
||||
credential_id=credential.id,
|
||||
attempt_status=IndexingStatus.NOT_STARTED,
|
||||
)
|
||||
|
||||
|
||||
def cleanup_indexing_jobs(
|
||||
@@ -220,12 +217,10 @@ def cleanup_indexing_jobs(
|
||||
)
|
||||
|
||||
# do nothing for ongoing jobs that haven't been stopped
|
||||
if not job.done():
|
||||
if not index_attempt:
|
||||
continue
|
||||
|
||||
if not index_attempt.is_finished():
|
||||
continue
|
||||
if not job.done() and not _is_indexing_job_marked_as_finished(
|
||||
index_attempt
|
||||
):
|
||||
continue
|
||||
|
||||
if job.status == "error":
|
||||
logger.error(job.exception())
|
||||
@@ -297,28 +292,24 @@ def kickoff_indexing_jobs(
|
||||
# Don't include jobs waiting in the Dask queue that just haven't started running
|
||||
# Also (rarely) don't include for jobs that started but haven't updated the indexing tables yet
|
||||
with Session(engine) as db_session:
|
||||
# get_not_started_index_attempts orders its returned results from oldest to newest
|
||||
# we must process attempts in a FIFO manner to prevent connector starvation
|
||||
new_indexing_attempts = [
|
||||
(attempt, attempt.search_settings)
|
||||
(attempt, attempt.embedding_model)
|
||||
for attempt in get_not_started_index_attempts(db_session)
|
||||
if attempt.id not in existing_jobs
|
||||
]
|
||||
|
||||
logger.debug(f"Found {len(new_indexing_attempts)} new indexing task(s).")
|
||||
logger.info(f"Found {len(new_indexing_attempts)} new indexing tasks.")
|
||||
|
||||
if not new_indexing_attempts:
|
||||
return existing_jobs
|
||||
|
||||
indexing_attempt_count = 0
|
||||
|
||||
for attempt, search_settings in new_indexing_attempts:
|
||||
for attempt, embedding_model in new_indexing_attempts:
|
||||
use_secondary_index = (
|
||||
search_settings.status == IndexModelStatus.FUTURE
|
||||
if search_settings is not None
|
||||
embedding_model.status == IndexModelStatus.FUTURE
|
||||
if embedding_model is not None
|
||||
else False
|
||||
)
|
||||
if attempt.connector_credential_pair.connector is None:
|
||||
if attempt.connector is None:
|
||||
logger.warning(
|
||||
f"Skipping index attempt as Connector has been deleted: {attempt}"
|
||||
)
|
||||
@@ -327,7 +318,7 @@ def kickoff_indexing_jobs(
|
||||
attempt, db_session, failure_reason="Connector is null"
|
||||
)
|
||||
continue
|
||||
if attempt.connector_credential_pair.credential is None:
|
||||
if attempt.credential is None:
|
||||
logger.warning(
|
||||
f"Skipping index attempt as Credential has been deleted: {attempt}"
|
||||
)
|
||||
@@ -339,73 +330,39 @@ def kickoff_indexing_jobs(
|
||||
|
||||
if use_secondary_index:
|
||||
run = secondary_client.submit(
|
||||
run_indexing_entrypoint,
|
||||
attempt.id,
|
||||
attempt.connector_credential_pair_id,
|
||||
global_version.get_is_ee_version(),
|
||||
pure=False,
|
||||
run_indexing_entrypoint, attempt.id, pure=False
|
||||
)
|
||||
else:
|
||||
run = client.submit(
|
||||
run_indexing_entrypoint,
|
||||
attempt.id,
|
||||
attempt.connector_credential_pair_id,
|
||||
global_version.get_is_ee_version(),
|
||||
pure=False,
|
||||
)
|
||||
run = client.submit(run_indexing_entrypoint, attempt.id, pure=False)
|
||||
|
||||
if run:
|
||||
if indexing_attempt_count == 0:
|
||||
logger.info(
|
||||
f"Indexing dispatch starts: pending={len(new_indexing_attempts)}"
|
||||
)
|
||||
|
||||
indexing_attempt_count += 1
|
||||
secondary_str = " (secondary index)" if use_secondary_index else ""
|
||||
secondary_str = "(secondary index) " if use_secondary_index else ""
|
||||
logger.info(
|
||||
f"Indexing dispatched{secondary_str}: "
|
||||
f"attempt_id={attempt.id} "
|
||||
f"connector='{attempt.connector_credential_pair.connector.name}' "
|
||||
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
|
||||
f"credentials='{attempt.connector_credential_pair.credential_id}'"
|
||||
f"Kicked off {secondary_str}"
|
||||
f"indexing attempt for connector: '{attempt.connector.name}', "
|
||||
f"with config: '{attempt.connector.connector_specific_config}', and "
|
||||
f"with credentials: '{attempt.credential_id}'"
|
||||
)
|
||||
existing_jobs_copy[attempt.id] = run
|
||||
|
||||
if indexing_attempt_count > 0:
|
||||
logger.info(
|
||||
f"Indexing dispatch results: "
|
||||
f"initial_pending={len(new_indexing_attempts)} "
|
||||
f"started={indexing_attempt_count} "
|
||||
f"remaining={len(new_indexing_attempts) - indexing_attempt_count}"
|
||||
)
|
||||
|
||||
return existing_jobs_copy
|
||||
|
||||
|
||||
def update_loop(
|
||||
delay: int = 10,
|
||||
num_workers: int = NUM_INDEXING_WORKERS,
|
||||
num_secondary_workers: int = NUM_SECONDARY_INDEXING_WORKERS,
|
||||
) -> None:
|
||||
def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> None:
|
||||
engine = get_sqlalchemy_engine()
|
||||
with Session(engine) as db_session:
|
||||
check_index_swap(db_session=db_session)
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
db_embedding_model = get_current_db_embedding_model(db_session)
|
||||
|
||||
# So that the first time users aren't surprised by really slow speed of first
|
||||
# batch of documents indexed
|
||||
|
||||
if search_settings.provider_type is None:
|
||||
logger.notice("Running a first inference to warm up embedding model")
|
||||
embedding_model = EmbeddingModel.from_db_model(
|
||||
search_settings=search_settings,
|
||||
server_host=INDEXING_MODEL_SERVER_HOST,
|
||||
server_port=MODEL_SERVER_PORT,
|
||||
)
|
||||
|
||||
warm_up_bi_encoder(
|
||||
embedding_model=embedding_model,
|
||||
)
|
||||
# So that the first time users aren't surprised by really slow speed of first
|
||||
# batch of documents indexed
|
||||
logger.info("Running a first inference to warm up embedding model")
|
||||
warm_up_encoders(
|
||||
model_name=db_embedding_model.model_name,
|
||||
normalize=db_embedding_model.normalize,
|
||||
model_server_host=INDEXING_MODEL_SERVER_HOST,
|
||||
model_server_port=MODEL_SERVER_PORT,
|
||||
)
|
||||
|
||||
client_primary: Client | SimpleJobClient
|
||||
client_secondary: Client | SimpleJobClient
|
||||
@@ -420,7 +377,7 @@ def update_loop(
|
||||
silence_logs=logging.ERROR,
|
||||
)
|
||||
cluster_secondary = LocalCluster(
|
||||
n_workers=num_secondary_workers,
|
||||
n_workers=num_workers,
|
||||
threads_per_worker=1,
|
||||
silence_logs=logging.ERROR,
|
||||
)
|
||||
@@ -430,18 +387,23 @@ def update_loop(
|
||||
client_primary.register_worker_plugin(ResourceLogger())
|
||||
else:
|
||||
client_primary = SimpleJobClient(n_workers=num_workers)
|
||||
client_secondary = SimpleJobClient(n_workers=num_secondary_workers)
|
||||
client_secondary = SimpleJobClient(n_workers=num_workers)
|
||||
|
||||
existing_jobs: dict[int, Future | SimpleJob] = {}
|
||||
|
||||
with Session(engine) as db_session:
|
||||
# Previous version did not always clean up cc-pairs well leaving some connectors undeleteable
|
||||
# This ensures that bad states get cleaned up
|
||||
mark_all_in_progress_cc_pairs_failed(db_session)
|
||||
|
||||
while True:
|
||||
start = time.time()
|
||||
start_time_utc = datetime.utcfromtimestamp(start).strftime("%Y-%m-%d %H:%M:%S")
|
||||
logger.debug(f"Running update, current UTC time: {start_time_utc}")
|
||||
logger.info(f"Running update, current UTC time: {start_time_utc}")
|
||||
|
||||
if existing_jobs:
|
||||
# TODO: make this debug level once the "no jobs are being scheduled" issue is resolved
|
||||
logger.debug(
|
||||
logger.info(
|
||||
"Found existing indexing jobs: "
|
||||
f"{[(attempt_id, job.status) for attempt_id, job in existing_jobs.items()]}"
|
||||
)
|
||||
@@ -464,10 +426,7 @@ def update_loop(
|
||||
|
||||
|
||||
def update__main() -> None:
|
||||
set_is_ee_based_on_env_variable()
|
||||
init_sqlalchemy_engine(POSTGRES_INDEXER_APP_NAME)
|
||||
|
||||
logger.notice("Starting indexing service")
|
||||
logger.info("Starting Indexing Loop")
|
||||
update_loop()
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import re
|
||||
from collections.abc import Sequence
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -8,46 +9,53 @@ from danswer.chat.models import LlmDoc
|
||||
from danswer.db.chat import get_chat_messages_by_session
|
||||
from danswer.db.models import ChatMessage
|
||||
from danswer.llm.answering.models import PreviousMessage
|
||||
from danswer.search.models import InferenceChunk
|
||||
from danswer.search.models import InferenceSection
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def llm_doc_from_inference_section(inference_section: InferenceSection) -> LlmDoc:
|
||||
def llm_doc_from_inference_section(inf_chunk: InferenceSection) -> LlmDoc:
|
||||
return LlmDoc(
|
||||
document_id=inference_section.center_chunk.document_id,
|
||||
document_id=inf_chunk.document_id,
|
||||
# This one is using the combined content of all the chunks of the section
|
||||
# In default settings, this is the same as just the content of base chunk
|
||||
content=inference_section.combined_content,
|
||||
blurb=inference_section.center_chunk.blurb,
|
||||
semantic_identifier=inference_section.center_chunk.semantic_identifier,
|
||||
source_type=inference_section.center_chunk.source_type,
|
||||
metadata=inference_section.center_chunk.metadata,
|
||||
updated_at=inference_section.center_chunk.updated_at,
|
||||
link=inference_section.center_chunk.source_links[0]
|
||||
if inference_section.center_chunk.source_links
|
||||
else None,
|
||||
source_links=inference_section.center_chunk.source_links,
|
||||
content=inf_chunk.combined_content,
|
||||
blurb=inf_chunk.blurb,
|
||||
semantic_identifier=inf_chunk.semantic_identifier,
|
||||
source_type=inf_chunk.source_type,
|
||||
metadata=inf_chunk.metadata,
|
||||
updated_at=inf_chunk.updated_at,
|
||||
link=inf_chunk.source_links[0] if inf_chunk.source_links else None,
|
||||
source_links=inf_chunk.source_links,
|
||||
)
|
||||
|
||||
|
||||
def map_document_id_order(
|
||||
chunks: Sequence[InferenceChunk | LlmDoc], one_indexed: bool = True
|
||||
) -> dict[str, int]:
|
||||
order_mapping = {}
|
||||
current = 1 if one_indexed else 0
|
||||
for chunk in chunks:
|
||||
if chunk.document_id not in order_mapping:
|
||||
order_mapping[chunk.document_id] = current
|
||||
current += 1
|
||||
|
||||
return order_mapping
|
||||
|
||||
|
||||
def create_chat_chain(
|
||||
chat_session_id: int,
|
||||
db_session: Session,
|
||||
prefetch_tool_calls: bool = True,
|
||||
# Optional id at which we finish processing
|
||||
stop_at_message_id: int | None = None,
|
||||
) -> tuple[ChatMessage, list[ChatMessage]]:
|
||||
"""Build the linear chain of messages without including the root message"""
|
||||
mainline_messages: list[ChatMessage] = []
|
||||
|
||||
all_chat_messages = get_chat_messages_by_session(
|
||||
chat_session_id=chat_session_id,
|
||||
user_id=None,
|
||||
db_session=db_session,
|
||||
skip_permission_check=True,
|
||||
prefetch_tool_calls=prefetch_tool_calls,
|
||||
)
|
||||
id_to_msg = {msg.id: msg for msg in all_chat_messages}
|
||||
|
||||
@@ -63,12 +71,7 @@ def create_chat_chain(
|
||||
current_message: ChatMessage | None = root_message
|
||||
while current_message is not None:
|
||||
child_msg = current_message.latest_child_message
|
||||
|
||||
# Break if at the end of the chain
|
||||
# or have reached the `final_id` of the submitted message
|
||||
if not child_msg or (
|
||||
stop_at_message_id and current_message.id == stop_at_message_id
|
||||
):
|
||||
if not child_msg:
|
||||
break
|
||||
current_message = id_to_msg.get(child_msg)
|
||||
|
||||
|
||||
@@ -1,24 +0,0 @@
|
||||
input_prompts:
|
||||
- id: -5
|
||||
prompt: "Elaborate"
|
||||
content: "Elaborate on the above, give me a more in depth explanation."
|
||||
active: true
|
||||
is_public: true
|
||||
|
||||
- id: -4
|
||||
prompt: "Reword"
|
||||
content: "Help me rewrite the following politely and concisely for professional communication:\n"
|
||||
active: true
|
||||
is_public: true
|
||||
|
||||
- id: -3
|
||||
prompt: "Email"
|
||||
content: "Write a professional email for me including a subject line, signature, etc. Template the parts that need editing with [ ]. The email should cover the following points:\n"
|
||||
active: true
|
||||
is_public: true
|
||||
|
||||
- id: -2
|
||||
prompt: "Debug"
|
||||
content: "Provide step-by-step troubleshooting instructions for the following issue:\n"
|
||||
active: true
|
||||
is_public: true
|
||||
@@ -1,20 +1,18 @@
|
||||
from typing import cast
|
||||
|
||||
import yaml
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.chat_configs import INPUT_PROMPT_YAML
|
||||
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
|
||||
from danswer.configs.chat_configs import PERSONAS_YAML
|
||||
from danswer.configs.chat_configs import PROMPTS_YAML
|
||||
from danswer.db.chat import get_prompt_by_name
|
||||
from danswer.db.chat import upsert_persona
|
||||
from danswer.db.chat import upsert_prompt
|
||||
from danswer.db.document_set import get_or_create_document_set_by_name
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.input_prompt import insert_input_prompt_if_not_exists
|
||||
from danswer.db.models import DocumentSet as DocumentSetDBModel
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.models import Prompt as PromptDBModel
|
||||
from danswer.db.models import Tool as ToolDBModel
|
||||
from danswer.db.persona import get_prompt_by_name
|
||||
from danswer.db.persona import upsert_persona
|
||||
from danswer.db.persona import upsert_prompt
|
||||
from danswer.search.enums import RecencyBiasSetting
|
||||
|
||||
|
||||
@@ -52,7 +50,7 @@ def load_personas_from_yaml(
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
for persona in all_personas:
|
||||
doc_set_names = persona["document_sets"]
|
||||
doc_sets: list[DocumentSetDBModel] = [
|
||||
doc_sets: list[DocumentSetDBModel] | None = [
|
||||
get_or_create_document_set_by_name(db_session, name)
|
||||
for name in doc_set_names
|
||||
]
|
||||
@@ -60,51 +58,27 @@ def load_personas_from_yaml(
|
||||
# Assume if user hasn't set any document sets for the persona, the user may want
|
||||
# to later attach document sets to the persona manually, therefore, don't overwrite/reset
|
||||
# the document sets for the persona
|
||||
doc_set_ids: list[int] | None = None
|
||||
if doc_sets:
|
||||
doc_set_ids = [doc_set.id for doc_set in doc_sets]
|
||||
else:
|
||||
doc_set_ids = None
|
||||
if not doc_sets:
|
||||
doc_sets = None
|
||||
|
||||
prompt_ids: list[int] | None = None
|
||||
prompt_set_names = persona["prompts"]
|
||||
if prompt_set_names:
|
||||
prompts: list[PromptDBModel | None] = [
|
||||
if not prompt_set_names:
|
||||
prompts: list[PromptDBModel | None] | None = None
|
||||
else:
|
||||
prompts = [
|
||||
get_prompt_by_name(prompt_name, user=None, db_session=db_session)
|
||||
for prompt_name in prompt_set_names
|
||||
]
|
||||
if any([prompt is None for prompt in prompts]):
|
||||
raise ValueError("Invalid Persona configs, not all prompts exist")
|
||||
|
||||
if prompts:
|
||||
prompt_ids = [prompt.id for prompt in prompts if prompt is not None]
|
||||
if not prompts:
|
||||
prompts = None
|
||||
|
||||
p_id = persona.get("id")
|
||||
tool_ids = []
|
||||
if persona.get("image_generation"):
|
||||
image_gen_tool = (
|
||||
db_session.query(ToolDBModel)
|
||||
.filter(ToolDBModel.name == "ImageGenerationTool")
|
||||
.first()
|
||||
)
|
||||
if image_gen_tool:
|
||||
tool_ids.append(image_gen_tool.id)
|
||||
|
||||
llm_model_provider_override = persona.get("llm_model_provider_override")
|
||||
llm_model_version_override = persona.get("llm_model_version_override")
|
||||
|
||||
# Set specific overrides for image generation persona
|
||||
if persona.get("image_generation"):
|
||||
llm_model_version_override = "gpt-4o"
|
||||
|
||||
existing_persona = (
|
||||
db_session.query(Persona)
|
||||
.filter(Persona.name == persona["name"])
|
||||
.first()
|
||||
)
|
||||
|
||||
upsert_persona(
|
||||
user=None,
|
||||
# Negative to not conflict with existing personas
|
||||
persona_id=(-1 * p_id) if p_id is not None else None,
|
||||
name=persona["name"],
|
||||
description=persona["description"],
|
||||
@@ -114,52 +88,20 @@ def load_personas_from_yaml(
|
||||
llm_relevance_filter=persona.get("llm_relevance_filter"),
|
||||
starter_messages=persona.get("starter_messages"),
|
||||
llm_filter_extraction=persona.get("llm_filter_extraction"),
|
||||
icon_shape=persona.get("icon_shape"),
|
||||
icon_color=persona.get("icon_color"),
|
||||
llm_model_provider_override=llm_model_provider_override,
|
||||
llm_model_version_override=llm_model_version_override,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
recency_bias=RecencyBiasSetting(persona["recency_bias"]),
|
||||
prompt_ids=prompt_ids,
|
||||
document_set_ids=doc_set_ids,
|
||||
tool_ids=tool_ids,
|
||||
prompts=cast(list[PromptDBModel] | None, prompts),
|
||||
document_sets=doc_sets,
|
||||
default_persona=True,
|
||||
is_public=True,
|
||||
display_priority=existing_persona.display_priority
|
||||
if existing_persona is not None
|
||||
else persona.get("display_priority"),
|
||||
is_visible=existing_persona.is_visible
|
||||
if existing_persona is not None
|
||||
else persona.get("is_visible"),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
|
||||
def load_input_prompts_from_yaml(input_prompts_yaml: str = INPUT_PROMPT_YAML) -> None:
|
||||
with open(input_prompts_yaml, "r") as file:
|
||||
data = yaml.safe_load(file)
|
||||
|
||||
all_input_prompts = data.get("input_prompts", [])
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
for input_prompt in all_input_prompts:
|
||||
# If these prompts are deleted (which is a hard delete in the DB), on server startup
|
||||
# they will be recreated, but the user can always just deactivate them, just a light inconvenience
|
||||
insert_input_prompt_if_not_exists(
|
||||
user=None,
|
||||
input_prompt_id=input_prompt.get("id"),
|
||||
prompt=input_prompt["prompt"],
|
||||
content=input_prompt["content"],
|
||||
is_public=input_prompt["is_public"],
|
||||
active=input_prompt.get("active", True),
|
||||
db_session=db_session,
|
||||
commit=True,
|
||||
)
|
||||
|
||||
|
||||
def load_chat_yamls(
|
||||
prompt_yaml: str = PROMPTS_YAML,
|
||||
personas_yaml: str = PERSONAS_YAML,
|
||||
input_prompts_yaml: str = INPUT_PROMPT_YAML,
|
||||
) -> None:
|
||||
load_prompts_from_yaml(prompt_yaml)
|
||||
load_personas_from_yaml(personas_yaml)
|
||||
load_input_prompts_from_yaml(input_prompts_yaml)
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
@@ -10,7 +9,6 @@ from danswer.search.enums import QueryFlow
|
||||
from danswer.search.enums import SearchType
|
||||
from danswer.search.models import RetrievalDocs
|
||||
from danswer.search.models import SearchResponse
|
||||
from danswer.tools.custom.base_tool_types import ToolResultType
|
||||
|
||||
|
||||
class LlmDoc(BaseModel):
|
||||
@@ -36,51 +34,19 @@ class QADocsResponse(RetrievalDocs):
|
||||
applied_time_cutoff: datetime | None
|
||||
recency_bias_multiplier: float
|
||||
|
||||
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
|
||||
initial_dict = super().model_dump(mode="json", *args, **kwargs) # type: ignore
|
||||
def dict(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
|
||||
initial_dict = super().dict(*args, **kwargs) # type: ignore
|
||||
initial_dict["applied_time_cutoff"] = (
|
||||
self.applied_time_cutoff.isoformat() if self.applied_time_cutoff else None
|
||||
)
|
||||
|
||||
return initial_dict
|
||||
|
||||
|
||||
class StreamStopReason(Enum):
|
||||
CONTEXT_LENGTH = "context_length"
|
||||
CANCELLED = "cancelled"
|
||||
|
||||
|
||||
class StreamStopInfo(BaseModel):
|
||||
stop_reason: StreamStopReason
|
||||
|
||||
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
|
||||
data = super().model_dump(mode="json", *args, **kwargs) # type: ignore
|
||||
data["stop_reason"] = self.stop_reason.name
|
||||
return data
|
||||
|
||||
|
||||
# Second chunk of info for streaming QA
|
||||
class LLMRelevanceFilterResponse(BaseModel):
|
||||
relevant_chunk_indices: list[int]
|
||||
|
||||
|
||||
class RelevanceAnalysis(BaseModel):
|
||||
relevant: bool
|
||||
content: str | None = None
|
||||
|
||||
|
||||
class SectionRelevancePiece(RelevanceAnalysis):
|
||||
"""LLM analysis mapped to an Inference Section"""
|
||||
|
||||
document_id: str
|
||||
chunk_id: int # ID of the center chunk for a given inference section
|
||||
|
||||
|
||||
class DocumentRelevance(BaseModel):
|
||||
"""Contains all relevance information for a given search"""
|
||||
|
||||
relevance_summaries: dict[str, RelevanceAnalysis]
|
||||
|
||||
|
||||
class DanswerAnswerPiece(BaseModel):
|
||||
# A small piece of a complete answer. Used for streaming back answers.
|
||||
answer_piece: str | None # if None, specifies the end of an Answer
|
||||
@@ -93,14 +59,8 @@ class CitationInfo(BaseModel):
|
||||
document_id: str
|
||||
|
||||
|
||||
class MessageResponseIDInfo(BaseModel):
|
||||
user_message_id: int | None
|
||||
reserved_assistant_message_id: int
|
||||
|
||||
|
||||
class StreamingError(BaseModel):
|
||||
error: str
|
||||
stack_trace: str | None = None
|
||||
|
||||
|
||||
class DanswerQuote(BaseModel):
|
||||
@@ -146,20 +106,13 @@ class ImageGenerationDisplay(BaseModel):
|
||||
file_ids: list[str]
|
||||
|
||||
|
||||
class CustomToolResponse(BaseModel):
|
||||
response: ToolResultType
|
||||
tool_name: str
|
||||
|
||||
|
||||
AnswerQuestionPossibleReturn = (
|
||||
DanswerAnswerPiece
|
||||
| DanswerQuotes
|
||||
| CitationInfo
|
||||
| DanswerContexts
|
||||
| ImageGenerationDisplay
|
||||
| CustomToolResponse
|
||||
| StreamingError
|
||||
| StreamStopInfo
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ personas:
|
||||
# this is for DanswerBot to use when tagged in a non-configured channel
|
||||
# Careful setting specific IDs, this won't autoincrement the next ID value for postgres
|
||||
- id: 0
|
||||
name: "Knowledge"
|
||||
name: "Danswer"
|
||||
description: >
|
||||
Assistant with access to documents from your Connected Sources.
|
||||
# Default Prompt objects attached to the persona, see prompts.yaml
|
||||
@@ -17,7 +17,7 @@ personas:
|
||||
num_chunks: 10
|
||||
# Enable/Disable usage of the LLM chunk filter feature whereby each chunk is passed to the LLM to determine
|
||||
# if the chunk is useful or not towards the latest user query
|
||||
# This feature can be overriden for all personas via DISABLE_LLM_DOC_RELEVANCE env variable
|
||||
# This feature can be overriden for all personas via DISABLE_LLM_CHUNK_FILTER env variable
|
||||
llm_relevance_filter: true
|
||||
# Enable/Disable usage of the LLM to extract query time filters including source type and time range filters
|
||||
llm_filter_extraction: true
|
||||
@@ -37,15 +37,12 @@ personas:
|
||||
# - "Engineer Onboarding"
|
||||
# - "Benefits"
|
||||
document_sets: []
|
||||
icon_shape: 23013
|
||||
icon_color: "#6FB1FF"
|
||||
display_priority: 1
|
||||
is_visible: true
|
||||
|
||||
|
||||
- id: 1
|
||||
name: "General"
|
||||
name: "GPT"
|
||||
description: >
|
||||
Assistant with no access to documents. Chat with just the Large Language Model.
|
||||
Assistant with no access to documents. Chat with just the Language Model.
|
||||
prompts:
|
||||
- "OnlyLLM"
|
||||
num_chunks: 0
|
||||
@@ -53,10 +50,7 @@ personas:
|
||||
llm_filter_extraction: true
|
||||
recency_bias: "auto"
|
||||
document_sets: []
|
||||
icon_shape: 50910
|
||||
icon_color: "#FF6F6F"
|
||||
display_priority: 0
|
||||
is_visible: true
|
||||
|
||||
|
||||
- id: 2
|
||||
name: "Paraphrase"
|
||||
@@ -69,25 +63,3 @@ personas:
|
||||
llm_filter_extraction: true
|
||||
recency_bias: "auto"
|
||||
document_sets: []
|
||||
icon_shape: 45519
|
||||
icon_color: "#6FFF8D"
|
||||
display_priority: 2
|
||||
is_visible: false
|
||||
|
||||
|
||||
- id: 3
|
||||
name: "Art"
|
||||
description: >
|
||||
Assistant for generating images based on descriptions.
|
||||
prompts:
|
||||
- "ImageGeneration"
|
||||
num_chunks: 0
|
||||
llm_relevance_filter: false
|
||||
llm_filter_extraction: false
|
||||
recency_bias: "no_decay"
|
||||
document_sets: []
|
||||
icon_shape: 234124
|
||||
icon_color: "#9B59B6"
|
||||
image_generation: true
|
||||
display_priority: 3
|
||||
is_visible: true
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import traceback
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from functools import partial
|
||||
@@ -8,20 +7,15 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.chat.chat_utils import create_chat_chain
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import CustomToolResponse
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import ImageGenerationDisplay
|
||||
from danswer.chat.models import LlmDoc
|
||||
from danswer.chat.models import LLMRelevanceFilterResponse
|
||||
from danswer.chat.models import MessageResponseIDInfo
|
||||
from danswer.chat.models import QADocsResponse
|
||||
from danswer.chat.models import StreamingError
|
||||
from danswer.configs.chat_configs import BING_API_KEY
|
||||
from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
|
||||
from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
|
||||
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
|
||||
from danswer.db.chat import attach_files_to_chat_message
|
||||
from danswer.db.chat import create_db_search_doc
|
||||
from danswer.db.chat import create_new_chat_message
|
||||
from danswer.db.chat import get_chat_message
|
||||
@@ -29,16 +23,13 @@ from danswer.db.chat import get_chat_session_by_id
|
||||
from danswer.db.chat import get_db_search_doc_by_id
|
||||
from danswer.db.chat import get_doc_query_identifiers_from_model
|
||||
from danswer.db.chat import get_or_create_root_message
|
||||
from danswer.db.chat import reserve_message_id
|
||||
from danswer.db.chat import translate_db_message_to_chat_message_detail
|
||||
from danswer.db.chat import translate_db_search_doc_to_server_search_doc
|
||||
from danswer.db.embedding_model import get_current_db_embedding_model
|
||||
from danswer.db.engine import get_session_context_manager
|
||||
from danswer.db.llm import fetch_existing_llm_providers
|
||||
from danswer.db.models import SearchDoc as DbSearchDoc
|
||||
from danswer.db.models import ToolCall
|
||||
from danswer.db.models import User
|
||||
from danswer.db.persona import get_persona_by_id
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.file_store.models import ChatFileType
|
||||
from danswer.file_store.models import FileDescriptor
|
||||
@@ -51,47 +42,25 @@ from danswer.llm.answering.models import DocumentPruningConfig
|
||||
from danswer.llm.answering.models import PreviousMessage
|
||||
from danswer.llm.answering.models import PromptConfig
|
||||
from danswer.llm.exceptions import GenAIDisabledException
|
||||
from danswer.llm.factory import get_llms_for_persona
|
||||
from danswer.llm.factory import get_main_llm_from_tuple
|
||||
from danswer.llm.interfaces import LLMConfig
|
||||
from danswer.llm.utils import litellm_exception_to_error_msg
|
||||
from danswer.natural_language_processing.utils import get_tokenizer
|
||||
from danswer.search.enums import LLMEvaluationType
|
||||
from danswer.llm.factory import get_llm_for_persona
|
||||
from danswer.llm.utils import get_default_llm_tokenizer
|
||||
from danswer.search.enums import OptionalSearchSetting
|
||||
from danswer.search.enums import QueryFlow
|
||||
from danswer.search.enums import SearchType
|
||||
from danswer.search.models import InferenceSection
|
||||
from danswer.search.retrieval.search_runner import inference_sections_from_ids
|
||||
from danswer.search.retrieval.search_runner import inference_documents_from_ids
|
||||
from danswer.search.utils import chunks_or_sections_to_search_docs
|
||||
from danswer.search.utils import dedupe_documents
|
||||
from danswer.search.utils import drop_llm_indices
|
||||
from danswer.search.utils import relevant_sections_to_indices
|
||||
from danswer.server.query_and_chat.models import ChatMessageDetail
|
||||
from danswer.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from danswer.server.utils import get_json_line
|
||||
from danswer.tools.built_in_tools import get_built_in_tool_by_id
|
||||
from danswer.tools.custom.custom_tool import build_custom_tools_from_openapi_schema
|
||||
from danswer.tools.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID
|
||||
from danswer.tools.custom.custom_tool import CustomToolCallSummary
|
||||
from danswer.tools.factory import get_tool_cls
|
||||
from danswer.tools.force import ForceUseTool
|
||||
from danswer.tools.images.image_generation_tool import IMAGE_GENERATION_RESPONSE_ID
|
||||
from danswer.tools.images.image_generation_tool import ImageGenerationResponse
|
||||
from danswer.tools.images.image_generation_tool import ImageGenerationTool
|
||||
from danswer.tools.internet_search.internet_search_tool import (
|
||||
INTERNET_SEARCH_RESPONSE_ID,
|
||||
)
|
||||
from danswer.tools.internet_search.internet_search_tool import (
|
||||
internet_search_response_to_search_docs,
|
||||
)
|
||||
from danswer.tools.internet_search.internet_search_tool import InternetSearchResponse
|
||||
from danswer.tools.internet_search.internet_search_tool import InternetSearchTool
|
||||
from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID
|
||||
from danswer.tools.search.search_tool import SearchResponseSummary
|
||||
from danswer.tools.search.search_tool import SearchTool
|
||||
from danswer.tools.search.search_tool import SECTION_RELEVANCE_LIST_ID
|
||||
from danswer.tools.tool import Tool
|
||||
from danswer.tools.tool import ToolResponse
|
||||
from danswer.tools.tool_runner import ToolCallFinalResult
|
||||
from danswer.tools.utils import compute_all_tool_tokens
|
||||
from danswer.tools.utils import explicit_tool_calling_supported
|
||||
from danswer.utils.logger import setup_logger
|
||||
@@ -124,21 +93,14 @@ def _handle_search_tool_response_summary(
|
||||
packet: ToolResponse,
|
||||
db_session: Session,
|
||||
selected_search_docs: list[DbSearchDoc] | None,
|
||||
dedupe_docs: bool = False,
|
||||
) -> tuple[QADocsResponse, list[DbSearchDoc], list[int] | None]:
|
||||
) -> tuple[QADocsResponse, list[DbSearchDoc]]:
|
||||
response_sumary = cast(SearchResponseSummary, packet.response)
|
||||
|
||||
dropped_inds = None
|
||||
if not selected_search_docs:
|
||||
top_docs = chunks_or_sections_to_search_docs(response_sumary.top_sections)
|
||||
|
||||
deduped_docs = top_docs
|
||||
if dedupe_docs:
|
||||
deduped_docs, dropped_inds = dedupe_documents(top_docs)
|
||||
|
||||
reference_db_search_docs = [
|
||||
create_db_search_doc(server_search_doc=doc, db_session=db_session)
|
||||
for doc in deduped_docs
|
||||
create_db_search_doc(server_search_doc=top_doc, db_session=db_session)
|
||||
for top_doc in top_docs
|
||||
]
|
||||
else:
|
||||
reference_db_search_docs = selected_search_docs
|
||||
@@ -158,81 +120,35 @@ def _handle_search_tool_response_summary(
|
||||
recency_bias_multiplier=response_sumary.recency_bias_multiplier,
|
||||
),
|
||||
reference_db_search_docs,
|
||||
dropped_inds,
|
||||
)
|
||||
|
||||
|
||||
def _handle_internet_search_tool_response_summary(
|
||||
packet: ToolResponse,
|
||||
db_session: Session,
|
||||
) -> tuple[QADocsResponse, list[DbSearchDoc]]:
|
||||
internet_search_response = cast(InternetSearchResponse, packet.response)
|
||||
server_search_docs = internet_search_response_to_search_docs(
|
||||
internet_search_response
|
||||
)
|
||||
|
||||
reference_db_search_docs = [
|
||||
create_db_search_doc(server_search_doc=doc, db_session=db_session)
|
||||
for doc in server_search_docs
|
||||
]
|
||||
response_docs = [
|
||||
translate_db_search_doc_to_server_search_doc(db_search_doc)
|
||||
for db_search_doc in reference_db_search_docs
|
||||
]
|
||||
return (
|
||||
QADocsResponse(
|
||||
rephrased_query=internet_search_response.revised_query,
|
||||
top_documents=response_docs,
|
||||
predicted_flow=QueryFlow.QUESTION_ANSWER,
|
||||
predicted_search=SearchType.SEMANTIC,
|
||||
applied_source_filters=[],
|
||||
applied_time_cutoff=None,
|
||||
recency_bias_multiplier=1.0,
|
||||
),
|
||||
reference_db_search_docs,
|
||||
)
|
||||
|
||||
|
||||
def _get_force_search_settings(
|
||||
new_msg_req: CreateChatMessageRequest, tools: list[Tool]
|
||||
) -> ForceUseTool:
|
||||
internet_search_available = any(
|
||||
isinstance(tool, InternetSearchTool) for tool in tools
|
||||
)
|
||||
search_tool_available = any(isinstance(tool, SearchTool) for tool in tools)
|
||||
|
||||
if not internet_search_available and not search_tool_available:
|
||||
# Does not matter much which tool is set here as force is false and neither tool is available
|
||||
return ForceUseTool(force_use=False, tool_name=SearchTool._NAME)
|
||||
|
||||
tool_name = SearchTool._NAME if search_tool_available else InternetSearchTool._NAME
|
||||
# Currently, the internet search tool does not support query override
|
||||
args = (
|
||||
{"query": new_msg_req.query_override}
|
||||
if new_msg_req.query_override and tool_name == SearchTool._NAME
|
||||
else None
|
||||
)
|
||||
|
||||
if new_msg_req.file_descriptors:
|
||||
# If user has uploaded files they're using, don't run any of the search tools
|
||||
return ForceUseTool(force_use=False, tool_name=tool_name)
|
||||
|
||||
should_force_search = any(
|
||||
[
|
||||
def _check_should_force_search(
|
||||
new_msg_req: CreateChatMessageRequest,
|
||||
) -> ForceUseTool | None:
|
||||
if (
|
||||
new_msg_req.query_override
|
||||
or (
|
||||
new_msg_req.retrieval_options
|
||||
and new_msg_req.retrieval_options.run_search
|
||||
== OptionalSearchSetting.ALWAYS,
|
||||
new_msg_req.search_doc_ids,
|
||||
DISABLE_LLM_CHOOSE_SEARCH,
|
||||
]
|
||||
)
|
||||
and new_msg_req.retrieval_options.run_search == OptionalSearchSetting.ALWAYS
|
||||
)
|
||||
or new_msg_req.search_doc_ids
|
||||
):
|
||||
args = (
|
||||
{"query": new_msg_req.query_override}
|
||||
if new_msg_req.query_override
|
||||
else None
|
||||
)
|
||||
# if we are using selected docs, just put something here so the Tool doesn't need
|
||||
# to build its own args via an LLM call
|
||||
if new_msg_req.search_doc_ids:
|
||||
args = {"query": new_msg_req.message}
|
||||
|
||||
if should_force_search:
|
||||
# If we are using selected docs, just put something here so the Tool doesn't need to build its own args via an LLM call
|
||||
args = {"query": new_msg_req.message} if new_msg_req.search_doc_ids else args
|
||||
return ForceUseTool(force_use=True, tool_name=tool_name, args=args)
|
||||
|
||||
return ForceUseTool(force_use=False, tool_name=tool_name, args=args)
|
||||
return ForceUseTool(
|
||||
tool_name=SearchTool.name(),
|
||||
args=args,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
ChatPacket = (
|
||||
@@ -243,8 +159,6 @@ ChatPacket = (
|
||||
| DanswerAnswerPiece
|
||||
| CitationInfo
|
||||
| ImageGenerationDisplay
|
||||
| CustomToolResponse
|
||||
| MessageResponseIDInfo
|
||||
)
|
||||
ChatPacketStream = Iterator[ChatPacket]
|
||||
|
||||
@@ -260,21 +174,16 @@ def stream_chat_message_objects(
|
||||
max_document_percentage: float = CHAT_TARGET_CHUNK_PERCENTAGE,
|
||||
# if specified, uses the last user message and does not create a new user message based
|
||||
# on the `new_msg_req.message`. Currently, requires a state where the last message is a
|
||||
# user message (e.g. this can only be used for the chat-seeding flow).
|
||||
use_existing_user_message: bool = False,
|
||||
litellm_additional_headers: dict[str, str] | None = None,
|
||||
is_connected: Callable[[], bool] | None = None,
|
||||
) -> ChatPacketStream:
|
||||
"""Streams in order:
|
||||
1. [conditional] Retrieved documents if a search needs to be run
|
||||
2. [conditional] LLM selected chunk indices if LLM chunk filtering is turned on
|
||||
3. [always] A set of streamed LLM tokens or an error anywhere along the line if something fails
|
||||
4. [always] Details on the final AI response message that is created
|
||||
"""
|
||||
# Currently surrounding context is not supported for chat
|
||||
# Chat is already token heavy and harder for the model to process plus it would roll history over much faster
|
||||
new_msg_req.chunks_above = 0
|
||||
new_msg_req.chunks_below = 0
|
||||
|
||||
"""
|
||||
try:
|
||||
user_id = user.id if user is not None else None
|
||||
|
||||
@@ -289,18 +198,7 @@ def stream_chat_message_objects(
|
||||
parent_id = new_msg_req.parent_message_id
|
||||
reference_doc_ids = new_msg_req.search_doc_ids
|
||||
retrieval_options = new_msg_req.retrieval_options
|
||||
alternate_assistant_id = new_msg_req.alternate_assistant_id
|
||||
|
||||
# use alternate persona if alternative assistant id is passed in
|
||||
if alternate_assistant_id is not None:
|
||||
persona = get_persona_by_id(
|
||||
alternate_assistant_id,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
is_for_edit=False,
|
||||
)
|
||||
else:
|
||||
persona = chat_session.persona
|
||||
persona = chat_session.persona
|
||||
|
||||
prompt_id = new_msg_req.prompt_id
|
||||
if prompt_id is None and persona.prompts:
|
||||
@@ -312,28 +210,20 @@ def stream_chat_message_objects(
|
||||
)
|
||||
|
||||
try:
|
||||
llm, fast_llm = get_llms_for_persona(
|
||||
persona=persona,
|
||||
llm_override=new_msg_req.llm_override or chat_session.llm_override,
|
||||
additional_headers=litellm_additional_headers,
|
||||
llm = get_llm_for_persona(
|
||||
persona, new_msg_req.llm_override or chat_session.llm_override
|
||||
)
|
||||
except GenAIDisabledException:
|
||||
raise RuntimeError("LLM is disabled. Can't use chat flow without LLM.")
|
||||
|
||||
llm_provider = llm.config.model_provider
|
||||
llm_model_name = llm.config.model_name
|
||||
|
||||
llm_tokenizer = get_tokenizer(
|
||||
model_name=llm_model_name,
|
||||
provider_type=llm_provider,
|
||||
)
|
||||
llm_tokenizer = get_default_llm_tokenizer()
|
||||
llm_tokenizer_encode_func = cast(
|
||||
Callable[[str], list[int]], llm_tokenizer.encode
|
||||
)
|
||||
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
embedding_model = get_current_db_embedding_model(db_session)
|
||||
document_index = get_default_document_index(
|
||||
primary_index_name=search_settings.index_name, secondary_index_name=None
|
||||
primary_index_name=embedding_model.index_name, secondary_index_name=None
|
||||
)
|
||||
|
||||
# Every chat Session begins with an empty root message
|
||||
@@ -350,16 +240,7 @@ def stream_chat_message_objects(
|
||||
else:
|
||||
parent_message = root_message
|
||||
|
||||
user_message = None
|
||||
|
||||
if new_msg_req.regenerate:
|
||||
final_msg, history_msgs = create_chat_chain(
|
||||
stop_at_message_id=parent_id,
|
||||
chat_session_id=chat_session_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
elif not use_existing_user_message:
|
||||
if not use_existing_user_message:
|
||||
# Create new message at the right place in the tree and update the parent's child pointer
|
||||
# Don't commit yet until we verify the chat message chain
|
||||
user_message = create_new_chat_message(
|
||||
@@ -369,7 +250,10 @@ def stream_chat_message_objects(
|
||||
message=message_text,
|
||||
token_count=len(llm_tokenizer_encode_func(message_text)),
|
||||
message_type=MessageType.USER,
|
||||
files=None, # Need to attach later for optimization to only load files once in parallel
|
||||
files=[
|
||||
{"id": str(file_id), "type": ChatFileType.IMAGE}
|
||||
for file_id in new_msg_req.file_ids
|
||||
],
|
||||
db_session=db_session,
|
||||
commit=False,
|
||||
)
|
||||
@@ -384,8 +268,8 @@ def stream_chat_message_objects(
|
||||
"Be sure to update the chat pointers before calling this."
|
||||
)
|
||||
|
||||
# NOTE: do not commit user message - it will be committed when the
|
||||
# assistant message is successfully generated
|
||||
# Save now to save the latest chat message
|
||||
db_session.commit()
|
||||
else:
|
||||
# re-create linear history of messages
|
||||
final_msg, history_msgs = create_chat_chain(
|
||||
@@ -398,36 +282,14 @@ def stream_chat_message_objects(
|
||||
"when the last message is not a user message."
|
||||
)
|
||||
|
||||
# Disable Query Rephrasing for the first message
|
||||
# This leads to a better first response since the LLM rephrasing the question
|
||||
# leads to worst search quality
|
||||
if not history_msgs:
|
||||
new_msg_req.query_override = (
|
||||
new_msg_req.query_override or new_msg_req.message
|
||||
)
|
||||
|
||||
# load all files needed for this chat chain in memory
|
||||
files = load_all_chat_files(
|
||||
history_msgs, new_msg_req.file_descriptors, db_session
|
||||
)
|
||||
files = load_all_chat_files(history_msgs, new_msg_req.file_ids, db_session)
|
||||
latest_query_files = [
|
||||
file
|
||||
for file in files
|
||||
if file.file_id in [f["id"] for f in new_msg_req.file_descriptors]
|
||||
file for file in files if file.file_id in new_msg_req.file_ids
|
||||
]
|
||||
|
||||
if user_message:
|
||||
attach_files_to_chat_message(
|
||||
chat_message=user_message,
|
||||
files=[
|
||||
new_file.to_file_descriptor() for new_file in latest_query_files
|
||||
],
|
||||
db_session=db_session,
|
||||
commit=False,
|
||||
)
|
||||
|
||||
selected_db_search_docs = None
|
||||
selected_sections: list[InferenceSection] | None = None
|
||||
selected_llm_docs: list[LlmDoc] | None = None
|
||||
if reference_doc_ids:
|
||||
identifier_tuples = get_doc_query_identifiers_from_model(
|
||||
search_doc_ids=reference_doc_ids,
|
||||
@@ -437,8 +299,8 @@ def stream_chat_message_objects(
|
||||
)
|
||||
|
||||
# Generates full documents currently
|
||||
# May extend to use sections instead in the future
|
||||
selected_sections = inference_sections_from_ids(
|
||||
# May extend to include chunk ranges
|
||||
selected_llm_docs = inference_documents_from_ids(
|
||||
doc_identifiers=identifier_tuples,
|
||||
document_index=document_index,
|
||||
)
|
||||
@@ -465,23 +327,9 @@ def stream_chat_message_objects(
|
||||
else default_num_chunks
|
||||
),
|
||||
max_window_percentage=max_document_percentage,
|
||||
use_sections=new_msg_req.chunks_above > 0
|
||||
or new_msg_req.chunks_below > 0,
|
||||
)
|
||||
reserved_message_id = reserve_message_id(
|
||||
db_session=db_session,
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message=user_message.id
|
||||
if user_message is not None
|
||||
else parent_message.id,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
)
|
||||
yield MessageResponseIDInfo(
|
||||
user_message_id=user_message.id if user_message else None,
|
||||
reserved_assistant_message_id=reserved_message_id,
|
||||
)
|
||||
|
||||
overridden_model = (
|
||||
new_msg_req.llm_override.model_version if new_msg_req.llm_override else None
|
||||
)
|
||||
|
||||
# Cannot determine these without the LLM step or breaking out early
|
||||
partial_response = partial(
|
||||
@@ -489,134 +337,83 @@ def stream_chat_message_objects(
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message=final_msg,
|
||||
prompt_id=prompt_id,
|
||||
overridden_model=overridden_model,
|
||||
# message=,
|
||||
# rephrased_query=,
|
||||
# token_count=,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
alternate_assistant_id=new_msg_req.alternate_assistant_id,
|
||||
# error=,
|
||||
# reference_docs=,
|
||||
db_session=db_session,
|
||||
commit=False,
|
||||
commit=True,
|
||||
)
|
||||
|
||||
if not final_msg.prompt:
|
||||
raise RuntimeError("No Prompt found")
|
||||
|
||||
prompt_config = (
|
||||
PromptConfig.from_model(
|
||||
final_msg.prompt,
|
||||
prompt_override=(
|
||||
new_msg_req.prompt_override or chat_session.prompt_override
|
||||
),
|
||||
)
|
||||
if not persona
|
||||
else PromptConfig.from_model(persona.prompts[0])
|
||||
prompt_config = PromptConfig.from_model(
|
||||
final_msg.prompt,
|
||||
prompt_override=(
|
||||
new_msg_req.prompt_override or chat_session.prompt_override
|
||||
),
|
||||
)
|
||||
|
||||
# find out what tools to use
|
||||
search_tool: SearchTool | None = None
|
||||
tool_dict: dict[int, list[Tool]] = {} # tool_id to tool
|
||||
for db_tool_model in persona.tools:
|
||||
# handle in-code tools specially
|
||||
if db_tool_model.in_code_tool_id:
|
||||
tool_cls = get_built_in_tool_by_id(db_tool_model.id, db_session)
|
||||
if tool_cls.__name__ == SearchTool.__name__ and not latest_query_files:
|
||||
search_tool = SearchTool(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
persona=persona,
|
||||
retrieval_options=retrieval_options,
|
||||
prompt_config=prompt_config,
|
||||
llm=llm,
|
||||
fast_llm=fast_llm,
|
||||
pruning_config=document_pruning_config,
|
||||
selected_sections=selected_sections,
|
||||
chunks_above=new_msg_req.chunks_above,
|
||||
chunks_below=new_msg_req.chunks_below,
|
||||
full_doc=new_msg_req.full_doc,
|
||||
evaluation_type=LLMEvaluationType.BASIC
|
||||
if persona.llm_relevance_filter
|
||||
else LLMEvaluationType.SKIP,
|
||||
)
|
||||
tool_dict[db_tool_model.id] = [search_tool]
|
||||
elif tool_cls.__name__ == ImageGenerationTool.__name__:
|
||||
img_generation_llm_config: LLMConfig | None = None
|
||||
if (
|
||||
llm
|
||||
and llm.config.api_key
|
||||
and llm.config.model_provider == "openai"
|
||||
):
|
||||
img_generation_llm_config = llm.config
|
||||
else:
|
||||
llm_providers = fetch_existing_llm_providers(db_session)
|
||||
openai_provider = next(
|
||||
iter(
|
||||
[
|
||||
llm_provider
|
||||
for llm_provider in llm_providers
|
||||
if llm_provider.provider == "openai"
|
||||
]
|
||||
),
|
||||
None,
|
||||
)
|
||||
if not openai_provider or not openai_provider.api_key:
|
||||
raise ValueError(
|
||||
"Image generation tool requires an OpenAI API key"
|
||||
)
|
||||
img_generation_llm_config = LLMConfig(
|
||||
model_provider=openai_provider.provider,
|
||||
model_name=openai_provider.default_model_name,
|
||||
temperature=GEN_AI_TEMPERATURE,
|
||||
api_key=openai_provider.api_key,
|
||||
api_base=openai_provider.api_base,
|
||||
api_version=openai_provider.api_version,
|
||||
)
|
||||
tool_dict[db_tool_model.id] = [
|
||||
ImageGenerationTool(
|
||||
api_key=cast(str, img_generation_llm_config.api_key),
|
||||
api_base=img_generation_llm_config.api_base,
|
||||
api_version=img_generation_llm_config.api_version,
|
||||
additional_headers=litellm_additional_headers,
|
||||
)
|
||||
]
|
||||
elif tool_cls.__name__ == InternetSearchTool.__name__:
|
||||
bing_api_key = BING_API_KEY
|
||||
if not bing_api_key:
|
||||
raise ValueError(
|
||||
"Internet search tool requires a Bing API key, please contact your Danswer admin to get it added!"
|
||||
)
|
||||
tool_dict[db_tool_model.id] = [
|
||||
InternetSearchTool(api_key=bing_api_key)
|
||||
]
|
||||
|
||||
continue
|
||||
|
||||
# handle all custom tools
|
||||
if db_tool_model.openapi_schema:
|
||||
tool_dict[db_tool_model.id] = cast(
|
||||
list[Tool],
|
||||
build_custom_tools_from_openapi_schema(
|
||||
db_tool_model.openapi_schema
|
||||
),
|
||||
)
|
||||
|
||||
tools: list[Tool] = []
|
||||
for tool_list in tool_dict.values():
|
||||
tools.extend(tool_list)
|
||||
persona_tool_classes = [
|
||||
get_tool_cls(tool, db_session) for tool in persona.tools
|
||||
]
|
||||
|
||||
# factor in tool definition size when pruning
|
||||
document_pruning_config.tool_num_tokens = compute_all_tool_tokens(
|
||||
tools, llm_tokenizer
|
||||
persona_tool_classes
|
||||
)
|
||||
document_pruning_config.using_tool_message = explicit_tool_calling_supported(
|
||||
llm_provider, llm_model_name
|
||||
llm.config.model_provider, llm.config.model_name
|
||||
)
|
||||
|
||||
# NOTE: for now, only support SearchTool and ImageGenerationTool
|
||||
# in the future, will support arbitrary user-defined tools
|
||||
search_tool: SearchTool | None = None
|
||||
tools: list[Tool] = []
|
||||
for tool_cls in persona_tool_classes:
|
||||
if tool_cls.__name__ == SearchTool.__name__:
|
||||
search_tool = SearchTool(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
persona=persona,
|
||||
retrieval_options=retrieval_options,
|
||||
prompt_config=prompt_config,
|
||||
llm_config=llm.config,
|
||||
pruning_config=document_pruning_config,
|
||||
selected_docs=selected_llm_docs,
|
||||
chunks_above=new_msg_req.chunks_above,
|
||||
chunks_below=new_msg_req.chunks_below,
|
||||
full_doc=new_msg_req.full_doc,
|
||||
)
|
||||
tools.append(search_tool)
|
||||
elif tool_cls.__name__ == ImageGenerationTool.__name__:
|
||||
dalle_key = None
|
||||
if llm and llm.config.api_key and llm.config.model_provider == "openai":
|
||||
dalle_key = llm.config.api_key
|
||||
else:
|
||||
llm_providers = fetch_existing_llm_providers(db_session)
|
||||
openai_provider = next(
|
||||
iter(
|
||||
[
|
||||
llm_provider
|
||||
for llm_provider in llm_providers
|
||||
if llm_provider.provider == "openai"
|
||||
]
|
||||
),
|
||||
None,
|
||||
)
|
||||
if not openai_provider or not openai_provider.api_key:
|
||||
raise ValueError(
|
||||
"Image generation tool requires an OpenAI API key"
|
||||
)
|
||||
dalle_key = openai_provider.api_key
|
||||
tools.append(ImageGenerationTool(api_key=dalle_key))
|
||||
|
||||
# LLM prompt building, response capturing, etc.
|
||||
answer = Answer(
|
||||
is_connected=is_connected,
|
||||
question=final_msg.message,
|
||||
latest_query_files=latest_query_files,
|
||||
answer_style_config=AnswerStyleConfig(
|
||||
@@ -628,69 +425,34 @@ def stream_chat_message_objects(
|
||||
prompt_config=prompt_config,
|
||||
llm=(
|
||||
llm
|
||||
or get_main_llm_from_tuple(
|
||||
get_llms_for_persona(
|
||||
persona=persona,
|
||||
llm_override=(
|
||||
new_msg_req.llm_override or chat_session.llm_override
|
||||
),
|
||||
additional_headers=litellm_additional_headers,
|
||||
)
|
||||
or get_llm_for_persona(
|
||||
persona, new_msg_req.llm_override or chat_session.llm_override
|
||||
)
|
||||
),
|
||||
message_history=[
|
||||
PreviousMessage.from_chat_message(msg, files) for msg in history_msgs
|
||||
],
|
||||
tools=tools,
|
||||
force_use_tool=_get_force_search_settings(new_msg_req, tools),
|
||||
force_use_tool=_check_should_force_search(new_msg_req),
|
||||
)
|
||||
|
||||
reference_db_search_docs = None
|
||||
qa_docs_response = None
|
||||
ai_message_files = None # any files to associate with the AI message e.g. dall-e generated images
|
||||
dropped_indices = None
|
||||
tool_result = None
|
||||
|
||||
for packet in answer.processed_streamed_output:
|
||||
if isinstance(packet, ToolResponse):
|
||||
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
|
||||
(
|
||||
qa_docs_response,
|
||||
reference_db_search_docs,
|
||||
dropped_indices,
|
||||
) = _handle_search_tool_response_summary(
|
||||
packet=packet,
|
||||
db_session=db_session,
|
||||
selected_search_docs=selected_db_search_docs,
|
||||
# Deduping happens at the last step to avoid harming quality by dropping content early on
|
||||
dedupe_docs=retrieval_options.dedupe_docs
|
||||
if retrieval_options
|
||||
else False,
|
||||
packet, db_session, selected_db_search_docs
|
||||
)
|
||||
yield qa_docs_response
|
||||
elif packet.id == SECTION_RELEVANCE_LIST_ID:
|
||||
relevance_sections = packet.response
|
||||
|
||||
if reference_db_search_docs is not None:
|
||||
llm_indices = relevant_sections_to_indices(
|
||||
relevance_sections=relevance_sections,
|
||||
items=[
|
||||
translate_db_search_doc_to_server_search_doc(doc)
|
||||
for doc in reference_db_search_docs
|
||||
],
|
||||
)
|
||||
|
||||
if dropped_indices:
|
||||
llm_indices = drop_llm_indices(
|
||||
llm_indices=llm_indices,
|
||||
search_docs=reference_db_search_docs,
|
||||
dropped_indices=dropped_indices,
|
||||
)
|
||||
|
||||
yield LLMRelevanceFilterResponse(
|
||||
relevant_chunk_indices=llm_indices
|
||||
)
|
||||
|
||||
yield LLMRelevanceFilterResponse(
|
||||
relevant_chunk_indices=packet.response
|
||||
)
|
||||
elif packet.id == IMAGE_GENERATION_RESPONSE_ID:
|
||||
img_generation_response = cast(
|
||||
list[ImageGenerationResponse], packet.response
|
||||
@@ -706,39 +468,16 @@ def stream_chat_message_objects(
|
||||
yield ImageGenerationDisplay(
|
||||
file_ids=[str(file_id) for file_id in file_ids]
|
||||
)
|
||||
elif packet.id == INTERNET_SEARCH_RESPONSE_ID:
|
||||
(
|
||||
qa_docs_response,
|
||||
reference_db_search_docs,
|
||||
) = _handle_internet_search_tool_response_summary(
|
||||
packet=packet,
|
||||
db_session=db_session,
|
||||
)
|
||||
yield qa_docs_response
|
||||
elif packet.id == CUSTOM_TOOL_RESPONSE_ID:
|
||||
custom_tool_response = cast(CustomToolCallSummary, packet.response)
|
||||
yield CustomToolResponse(
|
||||
response=custom_tool_response.tool_result,
|
||||
tool_name=custom_tool_response.tool_name,
|
||||
)
|
||||
|
||||
else:
|
||||
if isinstance(packet, ToolCallFinalResult):
|
||||
tool_result = packet
|
||||
yield cast(ChatPacket, packet)
|
||||
logger.debug("Reached end of stream")
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
logger.exception(f"Failed to process chat message: {error_msg}")
|
||||
logger.exception(e)
|
||||
|
||||
stack_trace = traceback.format_exc()
|
||||
client_error_msg = litellm_exception_to_error_msg(e, llm)
|
||||
if llm.config.api_key and len(llm.config.api_key) > 2:
|
||||
error_msg = error_msg.replace(llm.config.api_key, "[REDACTED_API_KEY]")
|
||||
stack_trace = stack_trace.replace(llm.config.api_key, "[REDACTED_API_KEY]")
|
||||
|
||||
yield StreamingError(error=client_error_msg, stack_trace=stack_trace)
|
||||
db_session.rollback()
|
||||
# Frontend will erase whatever answer and show this instead
|
||||
# This will be the issue 99% of the time
|
||||
yield StreamingError(error="LLM failed to respond, have you set your API key?")
|
||||
return
|
||||
|
||||
# Post-LLM answer processing
|
||||
@@ -751,13 +490,7 @@ def stream_chat_message_objects(
|
||||
)
|
||||
|
||||
# Saving Gen AI answer and responding with message info
|
||||
tool_name_to_tool_id: dict[str, int] = {}
|
||||
for tool_id, tool_list in tool_dict.items():
|
||||
for tool in tool_list:
|
||||
tool_name_to_tool_id[tool.name] = tool_id
|
||||
|
||||
gen_ai_response_message = partial_response(
|
||||
reserved_message_id=reserved_message_id,
|
||||
message=answer.llm_answer,
|
||||
rephrased_query=(
|
||||
qa_docs_response.rephrased_query if qa_docs_response else None
|
||||
@@ -767,29 +500,15 @@ def stream_chat_message_objects(
|
||||
token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
|
||||
citations=db_citations,
|
||||
error=None,
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
tool_id=tool_name_to_tool_id[tool_result.tool_name],
|
||||
tool_name=tool_result.tool_name,
|
||||
tool_arguments=tool_result.tool_args,
|
||||
tool_result=tool_result.tool_result,
|
||||
)
|
||||
]
|
||||
if tool_result
|
||||
else [],
|
||||
)
|
||||
|
||||
logger.debug("Committing messages")
|
||||
db_session.commit() # actually save user / assistant message
|
||||
|
||||
msg_detail_response = translate_db_message_to_chat_message_detail(
|
||||
gen_ai_response_message
|
||||
)
|
||||
|
||||
yield msg_detail_response
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
logger.exception(error_msg)
|
||||
logger.exception(e)
|
||||
|
||||
# Frontend will erase whatever answer and show this instead
|
||||
yield StreamingError(error="Failed to parse LLM output")
|
||||
@@ -800,8 +519,6 @@ def stream_chat_message(
|
||||
new_msg_req: CreateChatMessageRequest,
|
||||
user: User | None,
|
||||
use_existing_user_message: bool = False,
|
||||
litellm_additional_headers: dict[str, str] | None = None,
|
||||
is_connected: Callable[[], bool] | None = None,
|
||||
) -> Iterator[str]:
|
||||
with get_session_context_manager() as db_session:
|
||||
objects = stream_chat_message_objects(
|
||||
@@ -809,8 +526,6 @@ def stream_chat_message(
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
use_existing_user_message=use_existing_user_message,
|
||||
litellm_additional_headers=litellm_additional_headers,
|
||||
is_connected=is_connected,
|
||||
)
|
||||
for obj in objects:
|
||||
yield get_json_line(obj.model_dump())
|
||||
yield get_json_line(obj.dict())
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user