Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-02-19 00:35:46 +00:00)

Compare commits: debug-test...freshdesk (71 commits)
Commit SHAs in this range:

- f0a21b74d4
- 5c521a7916
- 28e65669b4
- 493c3d7314
- b04e9e9b67
- 3755e575a5
- 63655cfbed
- 7f788e4b1e
- 1362d4b583
- 4f47004d47
- 3fdd233e84
- 0c54d9d57d
- c2088602e1
- b3c367d09c
- 457d32fef0
- af187c6cfe
- a0235b7b7b
- a30de693cb
- 07aeea69e7
- bd40328a73
- b8232e0681
- fffb9c155a
- f513c5bbed
- 9a4e51a18e
- 2f2fc08553
- c68c6fdc44
- 834c76e30a
- ec02665ffa
- 3fa1b18306
- c9bdf4c443
- e229d27734
- 140c5b3957
- 3e511497d2
- b0056907fb
- 728a41a35a
- ef8dda2d47
- 15283b3140
- e159b2e947
- 9155800fab
- a392ef0541
- 5679f0af61
- ff8db71cb5
- 1cff2b82fd
- 50dd3c8beb
- 66a459234d
- 19e57474dc
- f9638f2ea5
- fbf51b70d0
- b97cc01bb2
- 6d48fd5d99
- 1f61447b4b
- deee2b3513
- b73d66c84a
- c5a61f4820
- ea4a3cbf86
- 166514cedf
- be50ae1e71
- f89504ec53
- 6b3213b1e4
- 48577bf0e4
- c59d1ff0a5
- ba38dec592
- f5adc3063e
- 8cfe80c53a
- 487250320b
- c8d13922a9
- cb75449cec
- b66514cd21
- 77650c9ee3
- 316b6b99ea
- 34c2aa0860
@@ -7,16 +7,17 @@ on:

env:
  REGISTRY_IMAGE: danswer/danswer-backend
  LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}

jobs:
  build-and-push:
    # TODO: make this a matrix build like the web containers
    runs-on:
      group: amd64-image-builders
    # TODO: investigate a matrix build like the web container
    # See https://runs-on.com/runners/linux/
    runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]

    steps:
      - name: Checkout code
        uses: actions/checkout@v2
        uses: actions/checkout@v4

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3

@@ -31,7 +32,7 @@ jobs:
        run: |
          sudo apt-get update
          sudo apt-get install -y build-essential

      - name: Backend Image Docker Build and Push
        uses: docker/build-push-action@v5
        with:

@@ -41,12 +42,20 @@ jobs:
          push: true
          tags: |
            ${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
            ${{ env.REGISTRY_IMAGE }}:latest
            ${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
          build-args: |
            DANSWER_VERSION=${{ github.ref_name }}

      # trivy has their own rate limiting issues causing this action to flake
      # we worked around it by hardcoding to different db repos in env
      # can re-enable when they figure it out
      # https://github.com/aquasecurity/trivy/discussions/7538
      # https://github.com/aquasecurity/trivy-action/issues/389
      - name: Run Trivy vulnerability scanner
        uses: aquasecurity/trivy-action@master
        env:
          TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
          TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
        with:
          # To run locally: trivy image --severity HIGH,CRITICAL danswer/danswer-backend
          image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
@@ -5,14 +5,18 @@ on:
    tags:
      - '*'

env:
  REGISTRY_IMAGE: danswer/danswer-model-server
  LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}

jobs:
  build-and-push:
    runs-on:
      group: amd64-image-builders
    # See https://runs-on.com/runners/linux/
    runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]

    steps:
      - name: Checkout code
        uses: actions/checkout@v2
        uses: actions/checkout@v4

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3

@@ -31,13 +35,21 @@ jobs:
          platforms: linux/amd64,linux/arm64
          push: true
          tags: |
            danswer/danswer-model-server:${{ github.ref_name }}
            danswer/danswer-model-server:latest
            ${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
            ${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
          build-args: |
            DANSWER_VERSION=${{ github.ref_name }}

      # trivy has their own rate limiting issues causing this action to flake
      # we worked around it by hardcoding to different db repos in env
      # can re-enable when they figure it out
      # https://github.com/aquasecurity/trivy/discussions/7538
      # https://github.com/aquasecurity/trivy-action/issues/389
      - name: Run Trivy vulnerability scanner
        uses: aquasecurity/trivy-action@master
        env:
          TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
          TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
        with:
          image-ref: docker.io/danswer/danswer-model-server:${{ github.ref_name }}
          severity: 'CRITICAL,HIGH'
@@ -7,7 +7,8 @@ on:

env:
  REGISTRY_IMAGE: danswer/danswer-web-server
  LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}

jobs:
  build:
    runs-on:

@@ -35,7 +36,7 @@ jobs:
          images: ${{ env.REGISTRY_IMAGE }}
          tags: |
            type=raw,value=${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
            type=raw,value=${{ env.REGISTRY_IMAGE }}:latest
            type=raw,value=${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3

@@ -112,8 +113,16 @@ jobs:
        run: |
          docker buildx imagetools inspect ${{ env.REGISTRY_IMAGE }}:${{ steps.meta.outputs.version }}

      # trivy has their own rate limiting issues causing this action to flake
      # we worked around it by hardcoding to different db repos in env
      # can re-enable when they figure it out
      # https://github.com/aquasecurity/trivy/discussions/7538
      # https://github.com/aquasecurity/trivy-action/issues/389
      - name: Run Trivy vulnerability scanner
        uses: aquasecurity/trivy-action@master
        env:
          TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
          TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
        with:
          image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
          severity: 'CRITICAL,HIGH'
.github/workflows/docker-tag-latest.yml (vendored, 7 changed lines)

@@ -1,3 +1,6 @@
# This workflow is set up to be manually triggered via the GitHub Action tab.
# Given a version, it will tag those backend and webserver images as "latest".

name: Tag Latest Version

on:

@@ -9,7 +12,9 @@ on:

jobs:
  tag:
    runs-on: ubuntu-latest
    # See https://runs-on.com/runners/linux/
    # use a lower powered instance since this just does i/o to docker hub
    runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}"]
    steps:
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v1
@@ -1,19 +1,23 @@
name: Run Integration Tests
name: Run Integration Tests v2
concurrency:
  group: Run-Integration-Tests-${{ github.head_ref }}
  group: Run-Integration-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
  cancel-in-progress: true

on:
  merge_group:
  pull_request:
    branches: [ main ]
    branches:
      - main
      - 'release/**'

env:
  OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
  SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}

jobs:
  integration-tests:
    runs-on: Amd64
    # See https://runs-on.com/runners/linux/
    runs-on: [runs-on,runner=8cpu-linux-x64,ram=16,"run-id=${{ github.run_id }}"]
    steps:
      - name: Checkout code
        uses: actions/checkout@v4

@@ -27,25 +31,35 @@ jobs:
          username: ${{ secrets.DOCKER_USERNAME }}
          password: ${{ secrets.DOCKER_TOKEN }}

      # NOTE: we don't need to build the Web Docker image since it's not used
      # during the IT for now. We have a separate action to verify it builds
      # succesfully
      # tag every docker image with "test" so that we can spin up the correct set
      # of images during testing

      # We don't need to build the Web Docker image since it's not yet used
      # in the integration tests. We have a separate action to verify that it builds
      # successfully.
      - name: Pull Web Docker image
        run: |
          docker pull danswer/danswer-web-server:latest
          docker tag danswer/danswer-web-server:latest danswer/danswer-web-server:it
          docker tag danswer/danswer-web-server:latest danswer/danswer-web-server:test

      # we use the runs-on cache for docker builds
      # in conjunction with runs-on runners, it has better speed and unlimited caching
      # https://runs-on.com/caching/s3-cache-for-github-actions/
      # https://runs-on.com/caching/docker/
      # https://github.com/moby/buildkit#s3-cache-experimental

      # images are built and run locally for testing purposes. Not pushed.
      - name: Build Backend Docker image
        uses: ./.github/actions/custom-build-and-push
        with:
          context: ./backend
          file: ./backend/Dockerfile
          platforms: linux/amd64
          tags: danswer/danswer-backend:it
          cache-from: type=registry,ref=danswer/danswer-backend:it
          cache-to: |
            type=registry,ref=danswer/danswer-backend:it,mode=max
            type=inline
          tags: danswer/danswer-backend:test
          push: false
          load: true
          cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
          cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max

      - name: Build Model Server Docker image
        uses: ./.github/actions/custom-build-and-push

@@ -53,11 +67,11 @@ jobs:
          context: ./backend
          file: ./backend/Dockerfile.model_server
          platforms: linux/amd64
          tags: danswer/danswer-model-server:it
          cache-from: type=registry,ref=danswer/danswer-model-server:it
          cache-to: |
            type=registry,ref=danswer/danswer-model-server:it,mode=max
            type=inline
          tags: danswer/danswer-model-server:test
          push: false
          load: true
          cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
          cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max

      - name: Build integration test Docker image
        uses: ./.github/actions/custom-build-and-push

@@ -65,11 +79,11 @@ jobs:
          context: ./backend
          file: ./backend/tests/integration/Dockerfile
          platforms: linux/amd64
          tags: danswer/integration-test-runner:it
          cache-from: type=registry,ref=danswer/integration-test-runner:it
          cache-to: |
            type=registry,ref=danswer/integration-test-runner:it,mode=max
            type=inline
          tags: danswer/danswer-integration:test
          push: false
          load: true
          cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
          cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max

      - name: Start Docker containers
        run: |

@@ -78,7 +92,7 @@ jobs:
          AUTH_TYPE=basic \
          REQUIRE_EMAIL_VERIFICATION=false \
          DISABLE_TELEMETRY=true \
          IMAGE_TAG=it \
          IMAGE_TAG=test \
          docker compose -f docker-compose.dev.yml -p danswer-stack up -d
        id: start_docker

@@ -120,6 +134,7 @@ jobs:
        run: |
          echo "Running integration tests..."
          docker run --rm --network danswer-stack_default \
            --name test-runner \
            -e POSTGRES_HOST=relational_db \
            -e POSTGRES_USER=postgres \
            -e POSTGRES_PASSWORD=password \

@@ -128,7 +143,9 @@ jobs:
            -e REDIS_HOST=cache \
            -e API_SERVER_HOST=api_server \
            -e OPENAI_API_KEY=${OPENAI_API_KEY} \
            danswer/integration-test-runner:it
            -e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
            -e TEST_WEB_HOSTNAME=test-runner \
            danswer/danswer-integration:test
        continue-on-error: true
        id: run_tests

@@ -12,7 +12,8 @@ on:

jobs:
  lint-test:
    runs-on: Amd64
    # See https://runs-on.com/runners/linux/
    runs-on: [runs-on,runner=8cpu-linux-x64,hdd=256,"run-id=${{ github.run_id }}"]

    # fetch-depth 0 is required for helm/chart-testing-action
    steps:
.github/workflows/pr-python-checks.yml (vendored, 7 changed lines)

@@ -3,11 +3,14 @@ name: Python Checks
on:
  merge_group:
  pull_request:
    branches: [ main ]
    branches:
      - main
      - 'release/**'

jobs:
  mypy-check:
    runs-on: ubuntu-latest
    # See https://runs-on.com/runners/linux/
    runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]

    steps:
      - name: Checkout code
@@ -15,10 +15,14 @@ env:
  CONFLUENCE_TEST_PAGE_ID: ${{ secrets.CONFLUENCE_TEST_PAGE_ID }}
  CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
  CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
  # Jira
  JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
  JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}

jobs:
  connectors-check:
    runs-on: ubuntu-latest
    # See https://runs-on.com/runners/linux/
    runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]

    env:
      PYTHONPATH: ./backend
.github/workflows/pr-python-model-tests.yml (vendored, new file, 58 lines)

@@ -0,0 +1,58 @@
name: Connector Tests

on:
  schedule:
    # This cron expression runs the job daily at 16:00 UTC (9am PT)
    - cron: "0 16 * * *"

env:
  # Bedrock
  AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
  AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
  AWS_REGION_NAME: ${{ secrets.AWS_REGION_NAME }}

  # OpenAI
  OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}

jobs:
  connectors-check:
    # See https://runs-on.com/runners/linux/
    runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]

    env:
      PYTHONPATH: ./backend

    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: "3.11"
          cache: "pip"
          cache-dependency-path: |
            backend/requirements/default.txt
            backend/requirements/dev.txt

      - name: Install Dependencies
        run: |
          python -m pip install --upgrade pip
          pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
          pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt

      - name: Run Tests
        shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
        run: |
          py.test -o junit_family=xunit2 -xv --ff backend/tests/daily/llm
          py.test -o junit_family=xunit2 -xv --ff backend/tests/daily/embedding

      - name: Alert on Failure
        if: failure() && github.event_name == 'schedule'
        env:
          SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK }}
        run: |
          curl -X POST \
            -H 'Content-type: application/json' \
            --data '{"text":"Scheduled Model Tests failed! Check the run at: https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}"}' \
            $SLACK_WEBHOOK
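The failure alert above posts a JSON payload to a Slack incoming webhook via curl. A minimal Python equivalent, with a placeholder webhook URL and message, might look like this:

```python
# Hedged sketch: send the same kind of failure alert to a Slack incoming
# webhook from Python. The URL and message text below are placeholders.
import json
import urllib.request

webhook_url = "https://hooks.slack.com/services/XXX/YYY/ZZZ"  # placeholder
payload = {"text": "Scheduled Model Tests failed! Check the run at: <run-url>"}

req = urllib.request.Request(
    webhook_url,
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-type": "application/json"},
)
urllib.request.urlopen(req)  # raises on non-2xx responses
```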
.github/workflows/pr-python-tests.yml (vendored, 7 changed lines)

@@ -3,11 +3,14 @@ name: Python Unit Tests
on:
  merge_group:
  pull_request:
    branches: [ main ]
    branches:
      - main
      - 'release/**'

jobs:
  backend-check:
    runs-on: ubuntu-latest
    # See https://runs-on.com/runners/linux/
    runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]

    env:
      PYTHONPATH: ./backend
.github/workflows/pr-quality-checks.yml (vendored, 5 changed lines)

@@ -1,6 +1,6 @@
name: Quality Checks PR
concurrency:
  group: Quality-Checks-PR-${{ github.head_ref }}
  group: Quality-Checks-PR-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
  cancel-in-progress: true

on:

@@ -9,7 +9,8 @@ on:

jobs:
  quality-checks:
    runs-on: ubuntu-latest
    # See https://runs-on.com/runners/linux/
    runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
    steps:
      - uses: actions/checkout@v4
        with:
.github/workflows/tag-nightly.yml (vendored, new file, 54 lines)

@@ -0,0 +1,54 @@
name: Nightly Tag Push

on:
  schedule:
    - cron: '0 10 * * *' # Runs every day at 2 AM PST / 3 AM PDT / 10 AM UTC

permissions:
  contents: write # Allows pushing tags to the repository

jobs:
  create-and-push-tag:
    runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}"]

    steps:
      # actions using GITHUB_TOKEN cannot trigger another workflow, but we do want this to trigger docker pushes
      # see https://github.com/orgs/community/discussions/27028#discussioncomment-3254367 for the workaround we
      # implement here which needs an actual user's deploy key
      - name: Checkout code
        uses: actions/checkout@v4
        with:
          ssh-key: "${{ secrets.RKUO_DEPLOY_KEY }}"

      - name: Set up Git user
        run: |
          git config user.name "Richard Kuo [bot]"
          git config user.email "rkuo[bot]@danswer.ai"

      - name: Check for existing nightly tag
        id: check_tag
        run: |
          if git tag --points-at HEAD --list "nightly-latest*" | grep -q .; then
            echo "A tag starting with 'nightly-latest' already exists on HEAD."
            echo "tag_exists=true" >> $GITHUB_OUTPUT
          else
            echo "No tag starting with 'nightly-latest' exists on HEAD."
            echo "tag_exists=false" >> $GITHUB_OUTPUT
          fi

      # don't tag again if HEAD already has a nightly-latest tag on it
      - name: Create Nightly Tag
        if: steps.check_tag.outputs.tag_exists == 'false'
        env:
          DATE: ${{ github.run_id }}
        run: |
          TAG_NAME="nightly-latest-$(date +'%Y%m%d')"
          echo "Creating tag: $TAG_NAME"
          git tag $TAG_NAME

      - name: Push Tag
        if: steps.check_tag.outputs.tag_exists == 'false'
        run: |
          TAG_NAME="nightly-latest-$(date +'%Y%m%d')"
          git push origin $TAG_NAME
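For local experimentation, the HEAD-tag check above can be reproduced with the same git invocation driven from Python. This is a sketch, not part of the workflow, and it only prints what it would do:

```python
# Hedged sketch of the "Check for existing nightly tag" step: list tags that
# point at HEAD and match nightly-latest*, then propose a tag only if none exist.
import subprocess
from datetime import datetime, timezone

existing = subprocess.run(
    ["git", "tag", "--points-at", "HEAD", "--list", "nightly-latest*"],
    capture_output=True, text=True, check=True,
).stdout.strip()

if existing:
    print("HEAD already has a nightly-latest tag; skipping.")
else:
    tag_name = f"nightly-latest-{datetime.now(timezone.utc):%Y%m%d}"
    print(f"Would create and push tag: {tag_name}")
```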
.prettierignore (new file, 1 line)

@@ -0,0 +1 @@
backend/tests/integration/tests/pruning/website
@@ -22,7 +22,7 @@ Your input is vital to making sure that Danswer moves in the right direction.
Before starting on implementation, please raise a GitHub issue.

And always feel free to message us (Chris Weaver / Yuhong Sun) on
[Slack](https://join.slack.com/t/danswer/shared_invite/zt-2afut44lv-Rw3kSWu6_OmdAXRpCv80DQ) /
[Slack](https://join.slack.com/t/danswer/shared_invite/zt-2lcmqw703-071hBuZBfNEOGUsLa5PXvQ) /
[Discord](https://discord.gg/TDJ59cGV2X) directly about anything at all.
@@ -9,9 +9,9 @@ from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import create_async_engine
from celery.backends.database.session import ResultModelBase  # type: ignore
from sqlalchemy.schema import SchemaItem
from sqlalchemy.sql import text

# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
# Alembic Config object
config = context.config

# Interpret the config file for Python logging.

@@ -21,16 +21,26 @@ if config.config_file_name is not None and config.attributes.get(
):
    fileConfig(config.config_file_name)

# add your model's MetaData object here
# Add your model's MetaData object here
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
target_metadata = [Base.metadata, ResultModelBase.metadata]

# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.


def get_schema_options() -> tuple[str, bool]:
    x_args_raw = context.get_x_argument()
    x_args = {}
    for arg in x_args_raw:
        for pair in arg.split(","):
            if "=" in pair:
                key, value = pair.split("=", 1)
                x_args[key] = value

    schema_name = x_args.get("schema", "public")
    create_schema = x_args.get("create_schema", "true").lower() == "true"
    return schema_name, create_schema


EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}

@@ -54,17 +64,20 @@ def run_migrations_offline() -> None:
    and not an Engine, though an Engine is acceptable
    here as well. By skipping the Engine creation
    we don't even need a DBAPI to be available.

    Calls to context.execute() here emit the given string to the
    script output.

    """
    url = build_connection_string()
    schema, _ = get_schema_options()

    context.configure(
        url=url,
        target_metadata=target_metadata,  # type: ignore
        literal_binds=True,
        include_object=include_object,
        dialect_opts={"paramstyle": "named"},
        version_table_schema=schema,
        include_schemas=True,
    )

    with context.begin_transaction():

@@ -72,22 +85,28 @@ def run_migrations_offline() -> None:


def do_run_migrations(connection: Connection) -> None:
    schema, create_schema = get_schema_options()
    if create_schema:
        connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema}"'))
        connection.execute(text("COMMIT"))

    connection.execute(text(f'SET search_path TO "{schema}"'))

    context.configure(
        connection=connection,
        target_metadata=target_metadata,  # type: ignore
        include_object=include_object,
    )  # type: ignore
        version_table_schema=schema,
        include_schemas=True,
        compare_type=True,
        compare_server_default=True,
    )

    with context.begin_transaction():
        context.run_migrations()


async def run_async_migrations() -> None:
    """In this scenario we need to create an Engine
    and associate a connection with the context.

    """

    """Run migrations in 'online' mode."""
    connectable = create_async_engine(
        build_connection_string(),
        poolclass=pool.NullPool,

@@ -101,7 +120,6 @@ async def run_async_migrations() -> None:


def run_migrations_online() -> None:
    """Run migrations in 'online' mode."""

    asyncio.run(run_async_migrations())
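The get_schema_options() helper above reads alembic "-x" arguments (for example `alembic -x schema=tenant_a,create_schema=false upgrade head`). A standalone sketch of the same parsing, run on example values, shows how the defaults fall out:

```python
# Hedged sketch of the x-argument parsing done by get_schema_options(),
# applied to example values; "schema" defaults to "public" and
# "create_schema" defaults to true when the flags are absent.
def parse_x_args(raw_args: list[str]) -> tuple[str, bool]:
    x_args: dict[str, str] = {}
    for arg in raw_args:
        for pair in arg.split(","):
            if "=" in pair:
                key, value = pair.split("=", 1)
                x_args[key] = value
    schema = x_args.get("schema", "public")
    create_schema = x_args.get("create_schema", "true").lower() == "true"
    return schema, create_schema


print(parse_x_args(["schema=tenant_a,create_schema=false"]))  # ('tenant_a', False)
print(parse_x_args([]))                                       # ('public', True)
```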
@@ -0,0 +1,46 @@
"""fix_user__external_user_group_id_fk

Revision ID: 46b7a812670f
Revises: f32615f71aeb
Create Date: 2024-09-23 12:58:03.894038

"""
from alembic import op

# revision identifiers, used by Alembic.
revision = "46b7a812670f"
down_revision = "f32615f71aeb"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Drop the existing primary key
    op.drop_constraint(
        "user__external_user_group_id_pkey",
        "user__external_user_group_id",
        type_="primary",
    )

    # Add the new composite primary key
    op.create_primary_key(
        "user__external_user_group_id_pkey",
        "user__external_user_group_id",
        ["user_id", "external_user_group_id", "cc_pair_id"],
    )


def downgrade() -> None:
    # Drop the composite primary key
    op.drop_constraint(
        "user__external_user_group_id_pkey",
        "user__external_user_group_id",
        type_="primary",
    )
    # Delete all entries from the table
    op.execute("DELETE FROM user__external_user_group_id")

    # Recreate the original primary key on user_id
    op.create_primary_key(
        "user__external_user_group_id_pkey", "user__external_user_group_id", ["user_id"]
    )
@@ -9,7 +9,7 @@ import json
from typing import cast
from alembic import op
import sqlalchemy as sa
from danswer.dynamic_configs.factory import get_dynamic_config_store
from danswer.key_value_store.factory import get_kv_store

# revision identifiers, used by Alembic.
revision = "703313b75876"

@@ -54,9 +54,7 @@ def upgrade() -> None:
    )

    try:
        settings_json = cast(
            str, get_dynamic_config_store().load("token_budget_settings")
        )
        settings_json = cast(str, get_kv_store().load("token_budget_settings"))
        settings = json.loads(settings_json)

        is_enabled = settings.get("enable_token_budget", False)

@@ -71,7 +69,7 @@ def upgrade() -> None:
        )

        # Delete the dynamic config
        get_dynamic_config_store().delete("token_budget_settings")
        get_kv_store().delete("token_budget_settings")

    except Exception:
        # Ignore if the dynamic config is not found
@@ -1,20 +1,20 @@
from typing import cast

from danswer.configs.constants import KV_USER_STORE_KEY
from danswer.dynamic_configs.factory import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.dynamic_configs.interface import JSON_ro
from danswer.key_value_store.factory import get_kv_store
from danswer.key_value_store.interface import JSON_ro
from danswer.key_value_store.interface import KvKeyNotFoundError


def get_invited_users() -> list[str]:
    try:
        store = get_dynamic_config_store()
        store = get_kv_store()
        return cast(list, store.load(KV_USER_STORE_KEY))
    except ConfigNotFoundError:
    except KvKeyNotFoundError:
        return list()


def write_invited_users(emails: list[str]) -> int:
    store = get_dynamic_config_store()
    store = get_kv_store()
    store.store(KV_USER_STORE_KEY, cast(JSON_ro, emails))
    return len(emails)
@@ -4,29 +4,29 @@ from typing import cast

from danswer.auth.schemas import UserRole
from danswer.configs.constants import KV_NO_AUTH_USER_PREFERENCES_KEY
from danswer.dynamic_configs.store import ConfigNotFoundError
from danswer.dynamic_configs.store import DynamicConfigStore
from danswer.key_value_store.store import KeyValueStore
from danswer.key_value_store.store import KvKeyNotFoundError
from danswer.server.manage.models import UserInfo
from danswer.server.manage.models import UserPreferences


def set_no_auth_user_preferences(
    store: DynamicConfigStore, preferences: UserPreferences
    store: KeyValueStore, preferences: UserPreferences
) -> None:
    store.store(KV_NO_AUTH_USER_PREFERENCES_KEY, preferences.model_dump())


def load_no_auth_user_preferences(store: DynamicConfigStore) -> UserPreferences:
def load_no_auth_user_preferences(store: KeyValueStore) -> UserPreferences:
    try:
        preferences_data = cast(
            Mapping[str, Any], store.load(KV_NO_AUTH_USER_PREFERENCES_KEY)
        )
        return UserPreferences(**preferences_data)
    except ConfigNotFoundError:
    except KvKeyNotFoundError:
        return UserPreferences(chosen_assistants=None, default_model=None)


def fetch_no_auth_user(store: DynamicConfigStore) -> UserInfo:
def fetch_no_auth_user(store: KeyValueStore) -> UserInfo:
    return UserInfo(
        id="__no_auth_user__",
        email="anonymous@danswer.ai",
@@ -342,7 +342,6 @@ def get_database_strategy(
    strategy = DatabaseStrategy(
        access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS  # type: ignore
    )

    return strategy

File diff suppressed because it is too large
@@ -1,5 +1,6 @@
from datetime import datetime
from datetime import timezone
from typing import Any

from sqlalchemy.orm import Session

@@ -24,12 +25,12 @@ from danswer.db.models import TaskQueueState
from danswer.db.tasks import check_task_is_live_and_not_timed_out
from danswer.db.tasks import get_latest_task
from danswer.db.tasks import get_latest_task_by_type
from danswer.redis.redis_pool import RedisPool
from danswer.redis.redis_pool import get_redis_client
from danswer.server.documents.models import DeletionAttemptSnapshot
from danswer.utils.logger import setup_logger


logger = setup_logger()
redis_pool = RedisPool()


def _get_deletion_status(

@@ -46,7 +47,7 @@ def _get_deletion_status(

    rcd = RedisConnectorDeletion(cc_pair.id)

    r = redis_pool.get_client()
    r = get_redis_client()
    if not r.exists(rcd.fence_key):
        return None

@@ -69,6 +70,30 @@ def get_deletion_attempt_snapshot(
    )


def skip_cc_pair_pruning_by_task(
    pruning_task: TaskQueueState | None, db_session: Session
) -> bool:
    """task should be the latest prune task for this cc_pair"""
    if not ALLOW_SIMULTANEOUS_PRUNING:
        # if only one prune is allowed at any time, then check to see if any prune
        # is active
        pruning_type_task_name = name_cc_prune_task()
        last_pruning_type_task = get_latest_task_by_type(
            pruning_type_task_name, db_session
        )

        if last_pruning_type_task and check_task_is_live_and_not_timed_out(
            last_pruning_type_task, db_session
        ):
            return True

    if pruning_task and check_task_is_live_and_not_timed_out(pruning_task, db_session):
        # if the last task is live right now, we shouldn't start a new one
        return True

    return False


def should_prune_cc_pair(
    connector: Connector, credential: Credential, db_session: Session
) -> bool:

@@ -79,31 +104,26 @@ def should_prune_cc_pair(
        connector_id=connector.id, credential_id=credential.id
    )
    last_pruning_task = get_latest_task(pruning_task_name, db_session)

    if skip_cc_pair_pruning_by_task(last_pruning_task, db_session):
        return False

    current_db_time = get_db_current_time(db_session)

    if not last_pruning_task:
        # If the connector has never been pruned, then compare vs when the connector
        # was created
        time_since_initialization = current_db_time - connector.time_created
        if time_since_initialization.total_seconds() >= connector.prune_freq:
            return True
        return False

    if not ALLOW_SIMULTANEOUS_PRUNING:
        pruning_type_task_name = name_cc_prune_task()
        last_pruning_type_task = get_latest_task_by_type(
            pruning_type_task_name, db_session
        )

        if last_pruning_type_task and check_task_is_live_and_not_timed_out(
            last_pruning_type_task, db_session
        ):
            return False

    if check_task_is_live_and_not_timed_out(last_pruning_task, db_session):
        return False

    if not last_pruning_task.start_time:
        # if the last prune task hasn't started, we shouldn't start a new one
        return False

    # if the last prune task has a start time, then compare against it to determine
    # if we should start
    time_since_last_pruning = current_db_time - last_pruning_task.start_time
    return time_since_last_pruning.total_seconds() >= connector.prune_freq

@@ -141,3 +161,30 @@ def extract_ids_from_runnable_connector(runnable_connector: BaseConnector) -> se
        all_connector_doc_ids.update(doc_batch_processing_func(doc_batch))

    return all_connector_doc_ids


def celery_is_listening_to_queue(worker: Any, name: str) -> bool:
    """Checks to see if we're listening to the named queue"""

    # how to get a list of queues this worker is listening to
    # https://stackoverflow.com/questions/29790523/how-to-determine-which-queues-a-celery-worker-is-consuming-at-runtime
    queue_names = list(worker.app.amqp.queues.consume_from.keys())
    for queue_name in queue_names:
        if queue_name == name:
            return True

    return False


def celery_is_worker_primary(worker: Any) -> bool:
    """There are multiple approaches that could be taken, but the way we do it is to
    check the hostname set for the celery worker, either in celeryconfig.py or on the
    command line."""
    hostname = worker.hostname
    if hostname.startswith("light"):
        return False

    if hostname.startswith("heavy"):
        return False

    return True
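The scheduling decision in should_prune_cc_pair() reduces to an elapsed-time comparison against the connector's prune_freq (in seconds). A toy sketch with made-up values:

```python
# Hedged sketch of the pruning schedule check: prune once the time since the
# last started prune (or since connector creation) reaches prune_freq seconds.
from datetime import datetime, timedelta, timezone

prune_freq = 24 * 3600  # hypothetical: prune at most once a day
last_prune_started = datetime.now(timezone.utc) - timedelta(hours=30)

elapsed = datetime.now(timezone.utc) - last_prune_started
print(elapsed.total_seconds() >= prune_freq)  # True: 30h elapsed vs 24h frequency
```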
@@ -1,7 +1,9 @@
# docs: https://docs.celeryq.dev/en/stable/userguide/configuration.html
from danswer.configs.app_configs import CELERY_BROKER_POOL_LIMIT
from danswer.configs.app_configs import CELERY_RESULT_EXPIRES
from danswer.configs.app_configs import REDIS_DB_NUMBER_CELERY
from danswer.configs.app_configs import REDIS_DB_NUMBER_CELERY_RESULT_BACKEND
from danswer.configs.app_configs import REDIS_HEALTH_CHECK_INTERVAL
from danswer.configs.app_configs import REDIS_HOST
from danswer.configs.app_configs import REDIS_PASSWORD
from danswer.configs.app_configs import REDIS_PORT

@@ -9,6 +11,7 @@ from danswer.configs.app_configs import REDIS_SSL
from danswer.configs.app_configs import REDIS_SSL_CA_CERTS
from danswer.configs.app_configs import REDIS_SSL_CERT_REQS
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import REDIS_SOCKET_KEEPALIVE_OPTIONS

CELERY_SEPARATOR = ":"

@@ -36,12 +39,30 @@ result_backend = f"{REDIS_SCHEME}://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PO
# can stall other tasks.
worker_prefetch_multiplier = 4

broker_connection_retry_on_startup = True
broker_pool_limit = CELERY_BROKER_POOL_LIMIT

# redis broker settings
# https://docs.celeryq.dev/projects/kombu/en/stable/reference/kombu.transport.redis.html
broker_transport_options = {
    "priority_steps": list(range(len(DanswerCeleryPriority))),
    "sep": CELERY_SEPARATOR,
    "queue_order_strategy": "priority",
    "retry_on_timeout": True,
    "health_check_interval": REDIS_HEALTH_CHECK_INTERVAL,
    "socket_keepalive": True,
    "socket_keepalive_options": REDIS_SOCKET_KEEPALIVE_OPTIONS,
}

# redis backend settings
# https://docs.celeryq.dev/en/stable/userguide/configuration.html#redis-backend-settings

# there doesn't appear to be a way to set socket_keepalive_options on the redis result backend
redis_socket_keepalive = True
redis_retry_on_timeout = True
redis_backend_health_check_interval = REDIS_HEALTH_CHECK_INTERVAL


task_default_priority = DanswerCeleryPriority.MEDIUM
task_acks_late = True
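The broker_transport_options above enable priority queues on Celery's Redis transport. A minimal sketch, assuming a local Redis broker and toy priority levels, of how a producer would then submit work at an explicit priority:

```python
# Hedged sketch: with the redis transport, "priority_steps" plus
# queue_order_strategy="priority" split one queue into priority buckets;
# lower numbers are consumed first. Broker URL and step count are assumptions.
from celery import Celery

app = Celery("sketch", broker="redis://localhost:6379/0")
app.conf.broker_transport_options = {
    "priority_steps": list(range(3)),  # e.g. HIGH=0, MEDIUM=1, LOW=2
    "sep": ":",
    "queue_order_strategy": "priority",
}


@app.task
def ping() -> str:
    return "pong"


# submit at the highest priority (0); requires a running redis broker
ping.apply_async(priority=0)
```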
@@ -0,0 +1,132 @@
import redis
from celery import shared_task
from celery.exceptions import SoftTimeLimitExceeded
from celery.utils.log import get_task_logger
from redis import Redis
from sqlalchemy.orm import Session
from sqlalchemy.orm.exc import ObjectDeletedError

from danswer.background.celery.celery_app import celery_app
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerRedisLocks
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.enums import IndexingStatus
from danswer.db.index_attempt import get_last_attempt
from danswer.db.models import ConnectorCredentialPair
from danswer.db.search_settings import get_current_search_settings
from danswer.redis.redis_pool import get_redis_client


# use this within celery tasks to get celery task specific logging
task_logger = get_task_logger(__name__)


@shared_task(
    name="check_for_connector_deletion_task",
    soft_time_limit=JOB_TIMEOUT,
    trail=False,
)
def check_for_connector_deletion_task() -> None:
    r = get_redis_client()

    lock_beat = r.lock(
        DanswerRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
        timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
    )

    try:
        # these tasks should never overlap
        if not lock_beat.acquire(blocking=False):
            return

        with Session(get_sqlalchemy_engine()) as db_session:
            cc_pairs = get_connector_credential_pairs(db_session)
            for cc_pair in cc_pairs:
                try_generate_document_cc_pair_cleanup_tasks(
                    cc_pair, db_session, r, lock_beat
                )
    except SoftTimeLimitExceeded:
        task_logger.info(
            "Soft time limit exceeded, task is being terminated gracefully."
        )
    except Exception:
        task_logger.exception("Unexpected exception")
    finally:
        if lock_beat.owned():
            lock_beat.release()


def try_generate_document_cc_pair_cleanup_tasks(
    cc_pair: ConnectorCredentialPair,
    db_session: Session,
    r: Redis,
    lock_beat: redis.lock.Lock,
) -> int | None:
    """Returns an int if syncing is needed. The int represents the number of sync tasks generated.
    Note that syncing can still be required even if the number of sync tasks generated is zero.
    Returns None if no syncing is required.
    """

    lock_beat.reacquire()

    rcd = RedisConnectorDeletion(cc_pair.id)

    # don't generate sync tasks if tasks are still pending
    if r.exists(rcd.fence_key):
        return None

    # we need to refresh the state of the object inside the fence
    # to avoid a race condition with db.commit/fence deletion
    # at the end of this taskset
    try:
        db_session.refresh(cc_pair)
    except ObjectDeletedError:
        return None

    if cc_pair.status != ConnectorCredentialPairStatus.DELETING:
        return None

    search_settings = get_current_search_settings(db_session)

    last_indexing = get_last_attempt(
        connector_id=cc_pair.connector_id,
        credential_id=cc_pair.credential_id,
        search_settings_id=search_settings.id,
        db_session=db_session,
    )
    if last_indexing:
        if (
            last_indexing.status == IndexingStatus.IN_PROGRESS
            or last_indexing.status == IndexingStatus.NOT_STARTED
        ):
            return None

    # add tasks to celery and build up the task set to monitor in redis
    r.delete(rcd.taskset_key)

    # Add all documents that need to be updated into the queue
    task_logger.info(
        f"RedisConnectorDeletion.generate_tasks starting. cc_pair_id={cc_pair.id}"
    )
    tasks_generated = rcd.generate_tasks(celery_app, db_session, r, lock_beat)
    if tasks_generated is None:
        return None

    # Currently we are allowing the sync to proceed with 0 tasks.
    # It's possible for sets/groups to be generated initially with no entries
    # and they still need to be marked as up to date.
    # if tasks_generated == 0:
    #     return 0

    task_logger.info(
        f"RedisConnectorDeletion.generate_tasks finished. "
        f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}"
    )

    # set this only after all tasks have been added
    r.set(rcd.fence_key, tasks_generated)
    return tasks_generated
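check_for_connector_deletion_task() relies on a non-blocking redis lock that is periodically reacquired while the loop runs. A self-contained sketch of that beat-lock pattern, with hypothetical key names and assuming a reachable local Redis:

```python
# Hedged sketch of the beat-lock pattern: take the lock without blocking,
# extend its timeout while iterating, and release it only if still owned.
import time

import redis

r = redis.Redis()  # assumes a local redis instance
lock = r.lock("sketch:beat_lock", timeout=60)

if lock.acquire(blocking=False):
    try:
        for _ in range(5):
            lock.reacquire()  # refresh the TTL so the lock survives long work
            time.sleep(0.1)   # placeholder for one unit of work
    finally:
        if lock.owned():
            lock.release()
else:
    print("another beat is already running; skipping")
```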
backend/danswer/background/celery/tasks/periodic/tasks.py (new file, 140 lines)

@@ -0,0 +1,140 @@
#####
# Periodic Tasks
#####
import json
from typing import Any

from celery import shared_task
from celery.contrib.abortable import AbortableTask  # type: ignore
from celery.exceptions import TaskRevokedError
from celery.utils.log import get_task_logger
from sqlalchemy import inspect
from sqlalchemy import text
from sqlalchemy.orm import Session

from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import PostgresAdvisoryLocks
from danswer.db.engine import get_sqlalchemy_engine  # type: ignore

# use this within celery tasks to get celery task specific logging
task_logger = get_task_logger(__name__)


@shared_task(
    name="kombu_message_cleanup_task",
    soft_time_limit=JOB_TIMEOUT,
    bind=True,
    base=AbortableTask,
)
def kombu_message_cleanup_task(self: Any) -> int:
    """Runs periodically to clean up the kombu_message table"""

    # we will select messages older than this amount to clean up
    KOMBU_MESSAGE_CLEANUP_AGE = 7  # days
    KOMBU_MESSAGE_CLEANUP_PAGE_LIMIT = 1000

    ctx = {}
    ctx["last_processed_id"] = 0
    ctx["deleted"] = 0
    ctx["cleanup_age"] = KOMBU_MESSAGE_CLEANUP_AGE
    ctx["page_limit"] = KOMBU_MESSAGE_CLEANUP_PAGE_LIMIT
    with Session(get_sqlalchemy_engine()) as db_session:
        # Exit the task if we can't take the advisory lock
        result = db_session.execute(
            text("SELECT pg_try_advisory_lock(:id)"),
            {"id": PostgresAdvisoryLocks.KOMBU_MESSAGE_CLEANUP_LOCK_ID.value},
        ).scalar()
        if not result:
            return 0

        while True:
            if self.is_aborted():
                raise TaskRevokedError("kombu_message_cleanup_task was aborted.")

            b = kombu_message_cleanup_task_helper(ctx, db_session)
            if not b:
                break

            db_session.commit()

    if ctx["deleted"] > 0:
        task_logger.info(
            f"Deleted {ctx['deleted']} orphaned messages from kombu_message."
        )

    return ctx["deleted"]


def kombu_message_cleanup_task_helper(ctx: dict, db_session: Session) -> bool:
    """
    Helper function to clean up old messages from the `kombu_message` table that are no longer relevant.

    This function retrieves messages from the `kombu_message` table that are no longer visible and
    older than a specified interval. It checks if the corresponding task_id exists in the
    `celery_taskmeta` table. If the task_id does not exist, the message is deleted.

    Args:
        ctx (dict): A context dictionary containing configuration parameters such as:
            - 'cleanup_age' (int): The age in days after which messages are considered old.
            - 'page_limit' (int): The maximum number of messages to process in one batch.
            - 'last_processed_id' (int): The ID of the last processed message to handle pagination.
            - 'deleted' (int): A counter to track the number of deleted messages.
        db_session (Session): The SQLAlchemy database session for executing queries.

    Returns:
        bool: Returns True if there are more rows to process, False if not.
    """

    inspector = inspect(db_session.bind)
    if not inspector:
        return False

    # With the move to redis as celery's broker and backend, kombu tables may not even exist.
    # We can fail silently.
    if not inspector.has_table("kombu_message"):
        return False

    query = text(
        """
    SELECT id, timestamp, payload
    FROM kombu_message WHERE visible = 'false'
    AND timestamp < CURRENT_TIMESTAMP - INTERVAL :interval_days
    AND id > :last_processed_id
    ORDER BY id
    LIMIT :page_limit
    """
    )
    kombu_messages = db_session.execute(
        query,
        {
            "interval_days": f"{ctx['cleanup_age']} days",
            "page_limit": ctx["page_limit"],
            "last_processed_id": ctx["last_processed_id"],
        },
    ).fetchall()

    if len(kombu_messages) == 0:
        return False

    for msg in kombu_messages:
        payload = json.loads(msg[2])
        task_id = payload["headers"]["id"]

        # Check if task_id exists in celery_taskmeta
        task_exists = db_session.execute(
            text("SELECT 1 FROM celery_taskmeta WHERE task_id = :task_id"),
            {"task_id": task_id},
        ).fetchone()

        # If task_id does not exist, delete the message
        if not task_exists:
            result = db_session.execute(
                text("DELETE FROM kombu_message WHERE id = :message_id"),
                {"message_id": msg[0]},
            )
            if result.rowcount > 0:  # type: ignore
                ctx["deleted"] += 1

        ctx["last_processed_id"] = msg[0]

    return True
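kombu_message_cleanup_task() guards itself with pg_try_advisory_lock so only one worker performs the cleanup at a time. A minimal sketch of that guard, with a hypothetical lock id and placeholder connection string:

```python
# Hedged sketch of the advisory-lock guard: pg_try_advisory_lock returns true
# only for the first session that claims the id, so concurrent workers skip
# the work instead of blocking on it.
from sqlalchemy import create_engine, text
from sqlalchemy.orm import Session

LOCK_ID = 424242  # hypothetical advisory lock id
engine = create_engine("postgresql://user:pass@localhost/db")  # placeholder DSN

with Session(engine) as session:
    got_lock = session.execute(
        text("SELECT pg_try_advisory_lock(:id)"), {"id": LOCK_ID}
    ).scalar()
    if got_lock:
        try:
            pass  # run the cleanup work here
        finally:
            session.execute(text("SELECT pg_advisory_unlock(:id)"), {"id": LOCK_ID})
```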
backend/danswer/background/celery/tasks/pruning/tasks.py (new file, 120 lines)

@@ -0,0 +1,120 @@
from celery import shared_task
from celery.utils.log import get_task_logger
from sqlalchemy.orm import Session

from danswer.background.celery.celery_app import celery_app
from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector
from danswer.background.celery.celery_utils import should_prune_cc_pair
from danswer.background.connector_deletion import delete_connector_credential_pair_batch
from danswer.background.task_utils import build_celery_task_wrapper
from danswer.background.task_utils import name_cc_prune_task
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.connectors.factory import instantiate_connector
from danswer.connectors.models import InputType
from danswer.db.connector_credential_pair import get_connector_credential_pair
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.document import get_documents_for_connector_credential_pair
from danswer.db.engine import get_sqlalchemy_engine
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index


# use this within celery tasks to get celery task specific logging
task_logger = get_task_logger(__name__)


@shared_task(
    name="check_for_prune_task",
    soft_time_limit=JOB_TIMEOUT,
)
def check_for_prune_task() -> None:
    """Runs periodically to check if any prune tasks should be run and adds them
    to the queue"""

    with Session(get_sqlalchemy_engine()) as db_session:
        all_cc_pairs = get_connector_credential_pairs(db_session)

        for cc_pair in all_cc_pairs:
            if should_prune_cc_pair(
                connector=cc_pair.connector,
                credential=cc_pair.credential,
                db_session=db_session,
            ):
                task_logger.info(f"Pruning the {cc_pair.connector.name} connector")

                prune_documents_task.apply_async(
                    kwargs=dict(
                        connector_id=cc_pair.connector.id,
                        credential_id=cc_pair.credential.id,
                    )
                )


@build_celery_task_wrapper(name_cc_prune_task)
@celery_app.task(name="prune_documents_task", soft_time_limit=JOB_TIMEOUT)
def prune_documents_task(connector_id: int, credential_id: int) -> None:
    """connector pruning task. For a cc pair, this task pulls all document IDs from the source
    and compares those IDs to locally stored documents and deletes all locally stored IDs missing
    from the most recently pulled document ID list"""
    with Session(get_sqlalchemy_engine()) as db_session:
        try:
            cc_pair = get_connector_credential_pair(
                db_session=db_session,
                connector_id=connector_id,
                credential_id=credential_id,
            )

            if not cc_pair:
                task_logger.warning(
                    f"ccpair not found for {connector_id} {credential_id}"
                )
                return

            runnable_connector = instantiate_connector(
                db_session,
                cc_pair.connector.source,
                InputType.PRUNE,
                cc_pair.connector.connector_specific_config,
                cc_pair.credential,
            )

            all_connector_doc_ids: set[str] = extract_ids_from_runnable_connector(
                runnable_connector
            )

            all_indexed_document_ids = {
                doc.id
                for doc in get_documents_for_connector_credential_pair(
                    db_session=db_session,
                    connector_id=connector_id,
                    credential_id=credential_id,
                )
            }

            doc_ids_to_remove = list(all_indexed_document_ids - all_connector_doc_ids)

            curr_ind_name, sec_ind_name = get_both_index_names(db_session)
            document_index = get_default_document_index(
                primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
            )

            if len(doc_ids_to_remove) == 0:
                task_logger.info(
                    f"No docs to prune from {cc_pair.connector.source} connector"
                )
                return

            task_logger.info(
                f"pruning {len(doc_ids_to_remove)} doc(s) from {cc_pair.connector.source} connector"
            )
            delete_connector_credential_pair_batch(
                document_ids=doc_ids_to_remove,
                connector_id=connector_id,
                credential_id=credential_id,
                document_index=document_index,
            )
        except Exception as e:
            task_logger.exception(
                f"Failed to run pruning for connector id {connector_id}."
            )
            raise e
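The core of prune_documents_task() is a set difference between locally indexed document IDs and the IDs the connector currently returns. A toy sketch with made-up IDs:

```python
# Hedged sketch of the pruning comparison: anything indexed locally but no
# longer present at the source is selected for deletion.
indexed_ids = {"doc-1", "doc-2", "doc-3"}  # from the local database
source_ids = {"doc-1", "doc-3"}            # from the connector's latest pull

ids_to_prune = sorted(indexed_ids - source_ids)
print(ids_to_prune)  # ['doc-2']
```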
525
backend/danswer/background/celery/tasks/vespa/tasks.py
Normal file
525
backend/danswer/background/celery/tasks/vespa/tasks.py
Normal file
@@ -0,0 +1,525 @@
|
||||
import traceback
|
||||
from typing import cast
|
||||
|
||||
import redis
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from celery.utils.log import get_task_logger
|
||||
from redis import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.access.access import get_access_for_document
|
||||
from danswer.background.celery.celery_app import celery_app
|
||||
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
|
||||
from danswer.background.celery.celery_redis import RedisConnectorDeletion
|
||||
from danswer.background.celery.celery_redis import RedisDocumentSet
|
||||
from danswer.background.celery.celery_redis import RedisUserGroup
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.db.connector import fetch_connector_by_id
|
||||
from danswer.db.connector_credential_pair import add_deletion_failure_message
|
||||
from danswer.db.connector_credential_pair import (
|
||||
delete_connector_credential_pair__no_commit,
|
||||
)
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from danswer.db.document import count_documents_by_needs_sync
|
||||
from danswer.db.document import get_document
|
||||
from danswer.db.document import mark_document_as_synced
|
||||
from danswer.db.document_set import delete_document_set
|
||||
from danswer.db.document_set import delete_document_set_cc_pair_relationship__no_commit
|
||||
from danswer.db.document_set import fetch_document_sets
|
||||
from danswer.db.document_set import fetch_document_sets_for_document
|
||||
from danswer.db.document_set import get_document_set_by_id
|
||||
from danswer.db.document_set import mark_document_set_as_synced
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.index_attempt import delete_index_attempts
|
||||
from danswer.db.models import DocumentSet
|
||||
from danswer.db.models import UserGroup
|
||||
from danswer.document_index.document_index_utils import get_both_index_names
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.document_index.interfaces import UpdateRequest
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
from danswer.utils.variable_functionality import (
|
||||
fetch_versioned_implementation_with_fallback,
|
||||
)
|
||||
from danswer.utils.variable_functionality import noop_fallback
|
||||
|
||||
|
||||
# use this within celery tasks to get celery task specific logging
|
||||
task_logger = get_task_logger(__name__)
|
||||
|
||||
|
||||
# celery auto associates tasks created inside another task,
|
||||
# which bloats the result metadata considerably. trail=False prevents this.
|
||||
@shared_task(
|
||||
name="check_for_vespa_sync_task",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
trail=False,
|
||||
)
|
||||
def check_for_vespa_sync_task() -> None:
|
||||
"""Runs periodically to check if any document needs syncing.
|
||||
Generates sets of tasks for Celery if syncing is needed."""
|
||||
|
||||
r = get_redis_client()
|
||||
|
||||
lock_beat = r.lock(
|
||||
DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
try:
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return
|
||||
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
try_generate_stale_document_sync_tasks(db_session, r, lock_beat)
|
||||
|
||||
# check if any document sets are not synced
|
||||
document_set_info = fetch_document_sets(
|
||||
user_id=None, db_session=db_session, include_outdated=True
|
||||
)
|
||||
for document_set, _ in document_set_info:
|
||||
try_generate_document_set_sync_tasks(
|
||||
document_set, db_session, r, lock_beat
|
||||
)
|
||||
|
||||
# check if any user groups are not synced
|
||||
try:
|
||||
fetch_user_groups = fetch_versioned_implementation(
|
||||
"danswer.db.user_group", "fetch_user_groups"
|
||||
)
|
||||
|
||||
user_groups = fetch_user_groups(
|
||||
db_session=db_session, only_up_to_date=False
|
||||
)
|
||||
for usergroup in user_groups:
|
||||
try_generate_user_group_sync_tasks(
|
||||
usergroup, db_session, r, lock_beat
|
||||
)
|
||||
except ModuleNotFoundError:
|
||||
# This always raises ModuleNotFoundError on the MIT version, which is expected
|
||||
pass
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception("Unexpected exception")
|
||||
finally:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
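# Illustrative sketch (assuming only a locally reachable Redis; names below are
# made up) of the beat-lock pattern used above: acquire a named lock without
# blocking so overlapping beat runs simply skip, and release it only if this
# process still owns it.
import redis

def run_once_per_beat(r: redis.Redis, lock_name: str, timeout: int) -> None:
    lock = r.lock(lock_name, timeout=timeout)
    if not lock.acquire(blocking=False):
        return  # another beat run is already doing the work
    try:
        pass  # periodic work goes here
    finally:
        if lock.owned():
            lock.release()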
|
||||
|
||||
|
||||
def try_generate_stale_document_sync_tasks(
|
||||
db_session: Session, r: Redis, lock_beat: redis.lock.Lock
|
||||
) -> int | None:
|
||||
# the fence is up, do nothing
|
||||
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
|
||||
return None
|
||||
|
||||
r.delete(RedisConnectorCredentialPair.get_taskset_key()) # delete the taskset
|
||||
|
||||
# add tasks to celery and build up the task set to monitor in redis
|
||||
stale_doc_count = count_documents_by_needs_sync(db_session)
|
||||
if stale_doc_count == 0:
|
||||
return None
|
||||
|
||||
task_logger.info(
|
||||
f"Stale documents found (at least {stale_doc_count}). Generating sync tasks by cc pair."
|
||||
)
|
||||
|
||||
task_logger.info("RedisConnector.generate_tasks starting by cc_pair.")
|
||||
|
||||
# rkuo: we could technically sync all stale docs in one big pass.
|
||||
# but I feel it's more understandable to group the docs by cc_pair
|
||||
total_tasks_generated = 0
|
||||
cc_pairs = get_connector_credential_pairs(db_session)
|
||||
for cc_pair in cc_pairs:
|
||||
rc = RedisConnectorCredentialPair(cc_pair.id)
|
||||
tasks_generated = rc.generate_tasks(celery_app, db_session, r, lock_beat)
|
||||
|
||||
if tasks_generated is None:
|
||||
continue
|
||||
|
||||
if tasks_generated == 0:
|
||||
continue
|
||||
|
||||
task_logger.info(
|
||||
f"RedisConnector.generate_tasks finished for single cc_pair. "
|
||||
f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}"
|
||||
)
|
||||
|
||||
total_tasks_generated += tasks_generated
|
||||
|
||||
task_logger.info(
|
||||
f"RedisConnector.generate_tasks finished for all cc_pairs. total_tasks_generated={total_tasks_generated}"
|
||||
)
|
||||
|
||||
r.set(RedisConnectorCredentialPair.get_fence_key(), total_tasks_generated)
|
||||
return total_tasks_generated
|
||||
|
||||
|
||||
def try_generate_document_set_sync_tasks(
|
||||
document_set: DocumentSet, db_session: Session, r: Redis, lock_beat: redis.lock.Lock
|
||||
) -> int | None:
|
||||
lock_beat.reacquire()
|
||||
|
||||
rds = RedisDocumentSet(document_set.id)
|
||||
|
||||
# don't generate document set sync tasks if tasks are still pending
|
||||
if r.exists(rds.fence_key):
|
||||
return None
|
||||
|
||||
# don't generate sync tasks if we're up to date
|
||||
# race condition with the monitor/cleanup function if we use a cached result!
|
||||
db_session.refresh(document_set)
|
||||
if document_set.is_up_to_date:
|
||||
return None
|
||||
|
||||
# add tasks to celery and build up the task set to monitor in redis
|
||||
r.delete(rds.taskset_key)
|
||||
|
||||
task_logger.info(
|
||||
f"RedisDocumentSet.generate_tasks starting. document_set_id={document_set.id}"
|
||||
)
|
||||
|
||||
# Add all documents that need to be updated into the queue
|
||||
tasks_generated = rds.generate_tasks(celery_app, db_session, r, lock_beat)
|
||||
if tasks_generated is None:
|
||||
return None
|
||||
|
||||
# Currently we are allowing the sync to proceed with 0 tasks.
|
||||
# It's possible for sets/groups to be generated initially with no entries
|
||||
# and they still need to be marked as up to date.
|
||||
# if tasks_generated == 0:
|
||||
# return 0
|
||||
|
||||
task_logger.info(
|
||||
f"RedisDocumentSet.generate_tasks finished. "
|
||||
f"document_set_id={document_set.id} tasks_generated={tasks_generated}"
|
||||
)
|
||||
|
||||
# set this only after all tasks have been added
|
||||
r.set(rds.fence_key, tasks_generated)
|
||||
return tasks_generated
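# Minimal sketch of the fence/taskset pattern these helpers rely on. The key
# names here are invented for illustration (the real ones come from the
# Redis* helper classes), and the member removal is collapsed into one helper.
import redis

def generate(r: redis.Redis, task_ids: list[str]) -> None:
    r.delete("example:taskset")
    for task_id in task_ids:
        r.sadd("example:taskset", task_id)  # one member per queued task
    r.set("example:fence", len(task_ids))   # the fence goes up last

def on_task_done(r: redis.Redis, task_id: str) -> None:
    r.srem("example:taskset", task_id)      # completed task removes its member

def monitor(r: redis.Redis) -> None:
    if r.get("example:fence") is None:
        return                              # nothing in flight
    if r.scard("example:taskset") == 0:     # all members removed -> finished
        r.delete("example:taskset")
        r.delete("example:fence")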
|
||||
|
||||
|
||||
def try_generate_user_group_sync_tasks(
|
||||
usergroup: UserGroup, db_session: Session, r: Redis, lock_beat: redis.lock.Lock
|
||||
) -> int | None:
|
||||
lock_beat.reacquire()
|
||||
|
||||
rug = RedisUserGroup(usergroup.id)
|
||||
|
||||
# don't generate sync tasks if tasks are still pending
|
||||
if r.exists(rug.fence_key):
|
||||
return None
|
||||
|
||||
# race condition with the monitor/cleanup function if we use a cached result!
|
||||
db_session.refresh(usergroup)
|
||||
if usergroup.is_up_to_date:
|
||||
return None
|
||||
|
||||
# add tasks to celery and build up the task set to monitor in redis
|
||||
r.delete(rug.taskset_key)
|
||||
|
||||
# Add all documents that need to be updated into the queue
|
||||
task_logger.info(
|
||||
f"RedisUserGroup.generate_tasks starting. usergroup_id={usergroup.id}"
|
||||
)
|
||||
tasks_generated = rug.generate_tasks(celery_app, db_session, r, lock_beat)
|
||||
if tasks_generated is None:
|
||||
return None
|
||||
|
||||
# Currently we are allowing the sync to proceed with 0 tasks.
|
||||
# It's possible for sets/groups to be generated initially with no entries
|
||||
# and they still need to be marked as up to date.
|
||||
# if tasks_generated == 0:
|
||||
# return 0
|
||||
|
||||
task_logger.info(
|
||||
f"RedisUserGroup.generate_tasks finished. "
|
||||
f"usergroup_id={usergroup.id} tasks_generated={tasks_generated}"
|
||||
)
|
||||
|
||||
# set this only after all tasks have been added
|
||||
r.set(rug.fence_key, tasks_generated)
|
||||
return tasks_generated
|
||||
|
||||
|
||||
def monitor_connector_taskset(r: Redis) -> None:
|
||||
fence_value = r.get(RedisConnectorCredentialPair.get_fence_key())
|
||||
if fence_value is None:
|
||||
return
|
||||
|
||||
try:
|
||||
initial_count = int(cast(int, fence_value))
|
||||
except ValueError:
|
||||
task_logger.error("The value is not an integer.")
|
||||
return
|
||||
|
||||
count = r.scard(RedisConnectorCredentialPair.get_taskset_key())
|
||||
task_logger.info(
|
||||
f"Stale document sync progress: remaining={count} initial={initial_count}"
|
||||
)
|
||||
if count == 0:
|
||||
r.delete(RedisConnectorCredentialPair.get_taskset_key())
|
||||
r.delete(RedisConnectorCredentialPair.get_fence_key())
|
||||
task_logger.info(f"Successfully synced stale documents. count={initial_count}")
|
||||
|
||||
|
||||
def monitor_document_set_taskset(
|
||||
key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
document_set_id = RedisDocumentSet.get_id_from_fence_key(fence_key)
|
||||
if document_set_id is None:
|
||||
task_logger.warning("could not parse document set id from {key}")
|
||||
return
|
||||
|
||||
rds = RedisDocumentSet(document_set_id)
|
||||
|
||||
fence_value = r.get(rds.fence_key)
|
||||
if fence_value is None:
|
||||
return
|
||||
|
||||
try:
|
||||
initial_count = int(cast(int, fence_value))
|
||||
except ValueError:
|
||||
task_logger.error("The value is not an integer.")
|
||||
return
|
||||
|
||||
count = cast(int, r.scard(rds.taskset_key))
|
||||
task_logger.info(
|
||||
f"Document set sync progress: document_set_id={document_set_id} remaining={count} initial={initial_count}"
|
||||
)
|
||||
if count > 0:
|
||||
return
|
||||
|
||||
document_set = cast(
|
||||
DocumentSet,
|
||||
get_document_set_by_id(db_session=db_session, document_set_id=document_set_id),
|
||||
) # casting since we "know" a document set with this ID exists
|
||||
if document_set:
|
||||
if not document_set.connector_credential_pairs:
|
||||
# if there are no connectors, then delete the document set.
|
||||
delete_document_set(document_set_row=document_set, db_session=db_session)
|
||||
task_logger.info(
|
||||
f"Successfully deleted document set with ID: '{document_set_id}'!"
|
||||
)
|
||||
else:
|
||||
mark_document_set_as_synced(document_set_id, db_session)
|
||||
task_logger.info(
|
||||
f"Successfully synced document set with ID: '{document_set_id}'!"
|
||||
)
|
||||
|
||||
r.delete(rds.taskset_key)
|
||||
r.delete(rds.fence_key)
|
||||
|
||||
|
||||
def monitor_connector_deletion_taskset(key_bytes: bytes, r: Redis) -> None:
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
cc_pair_id = RedisConnectorDeletion.get_id_from_fence_key(fence_key)
|
||||
if cc_pair_id is None:
|
||||
task_logger.warning("could not parse document set id from {key}")
|
||||
return
|
||||
|
||||
rcd = RedisConnectorDeletion(cc_pair_id)
|
||||
|
||||
fence_value = r.get(rcd.fence_key)
|
||||
if fence_value is None:
|
||||
return
|
||||
|
||||
try:
|
||||
initial_count = int(cast(int, fence_value))
|
||||
except ValueError:
|
||||
task_logger.error("The value is not an integer.")
|
||||
return
|
||||
|
||||
count = cast(int, r.scard(rcd.taskset_key))
|
||||
task_logger.info(
|
||||
f"Connector deletion progress: cc_pair_id={cc_pair_id} remaining={count} initial={initial_count}"
|
||||
)
|
||||
if count > 0:
|
||||
return
|
||||
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
|
||||
if not cc_pair:
|
||||
return
|
||||
|
||||
try:
|
||||
# clean up the rest of the related Postgres entities
|
||||
# index attempts
|
||||
delete_index_attempts(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair.id,
|
||||
)
|
||||
|
||||
# document sets
|
||||
delete_document_set_cc_pair_relationship__no_commit(
|
||||
db_session=db_session,
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
)
|
||||
|
||||
# user groups
|
||||
cleanup_user_groups = fetch_versioned_implementation_with_fallback(
|
||||
"danswer.db.user_group",
|
||||
"delete_user_group_cc_pair_relationship__no_commit",
|
||||
noop_fallback,
|
||||
)
|
||||
cleanup_user_groups(
|
||||
cc_pair_id=cc_pair.id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# finally, delete the cc-pair
|
||||
delete_connector_credential_pair__no_commit(
|
||||
db_session=db_session,
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
)
|
||||
# if there are no credentials left, delete the connector
|
||||
connector = fetch_connector_by_id(
|
||||
db_session=db_session,
|
||||
connector_id=cc_pair.connector_id,
|
||||
)
|
||||
if not connector or not len(connector.credentials):
|
||||
task_logger.info(
|
||||
"Found no credentials left for connector, deleting connector"
|
||||
)
|
||||
db_session.delete(connector)
|
||||
db_session.commit()
|
||||
except Exception as e:
|
||||
stack_trace = traceback.format_exc()
|
||||
error_message = f"Error: {str(e)}\n\nStack Trace:\n{stack_trace}"
|
||||
add_deletion_failure_message(db_session, cc_pair.id, error_message)
|
||||
task_logger.exception(
|
||||
f"Failed to run connector_deletion. "
|
||||
f"connector_id={cc_pair.connector_id} credential_id={cc_pair.credential_id}"
|
||||
)
|
||||
raise e
|
||||
|
||||
task_logger.info(
|
||||
f"Successfully deleted connector_credential_pair with connector_id: '{cc_pair.connector_id}' "
|
||||
f"and credential_id: '{cc_pair.credential_id}'. "
|
||||
f"Deleted {initial_count} docs."
|
||||
)
|
||||
|
||||
r.delete(rcd.taskset_key)
|
||||
r.delete(rcd.fence_key)
|
||||
|
||||
|
||||
@shared_task(name="monitor_vespa_sync", soft_time_limit=300)
|
||||
def monitor_vespa_sync() -> None:
|
||||
"""This is a celery beat task that monitors and finalizes metadata sync tasksets.
|
||||
It scans for fence values and then gets the counts of any associated tasksets.
|
||||
If the count is 0, that means all tasks finished and we should clean up.
|
||||
|
||||
This task lock timeout is CELERY_METADATA_SYNC_BEAT_LOCK_TIMEOUT seconds, so don't
|
||||
do anything too expensive in this function!
|
||||
"""
|
||||
r = get_redis_client()
|
||||
|
||||
lock_beat = r.lock(
|
||||
DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
try:
|
||||
# prevent overlapping tasks
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return
|
||||
|
||||
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
|
||||
monitor_connector_taskset(r)
|
||||
|
||||
for key_bytes in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
|
||||
monitor_connector_deletion_taskset(key_bytes, r)
|
||||
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
for key_bytes in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
|
||||
monitor_document_set_taskset(key_bytes, r, db_session)
|
||||
|
||||
for key_bytes in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
|
||||
monitor_usergroup_taskset = (
|
||||
fetch_versioned_implementation_with_fallback(
|
||||
"danswer.background.celery.tasks.vespa.tasks",
|
||||
"monitor_usergroup_taskset",
|
||||
noop_fallback,
|
||||
)
|
||||
)
|
||||
monitor_usergroup_taskset(key_bytes, r, db_session)
|
||||
|
||||
# uncomment for debugging if needed
|
||||
# r_celery = celery_app.broker_connection().channel().client
|
||||
# length = celery_get_queue_length(DanswerCeleryQueues.VESPA_METADATA_SYNC, r_celery)
|
||||
# task_logger.warning(f"queue={DanswerCeleryQueues.VESPA_METADATA_SYNC} length={length}")
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
)
|
||||
finally:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
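# Small sketch of the key-scan step above, assuming fence keys have the shape
# "<prefix><id>" (the actual prefixes live on the Redis* helper classes).
# scan_iter avoids blocking Redis the way a bare KEYS call would.
import redis

def find_active_fences(r: redis.Redis, prefix: str) -> list[str]:
    return [key.decode("utf-8") for key in r.scan_iter(prefix + "*")]

# e.g. find_active_fences(get_redis_client(), RedisDocumentSet.FENCE_PREFIX)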
|
||||
|
||||
|
||||
@shared_task(
|
||||
name="vespa_metadata_sync_task",
|
||||
bind=True,
|
||||
soft_time_limit=45,
|
||||
time_limit=60,
|
||||
max_retries=3,
|
||||
)
|
||||
def vespa_metadata_sync_task(self: Task, document_id: str) -> bool:
|
||||
task_logger.info(f"document_id={document_id}")
|
||||
|
||||
try:
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
|
||||
document_index = get_default_document_index(
|
||||
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
|
||||
)
|
||||
|
||||
doc = get_document(document_id, db_session)
|
||||
if not doc:
|
||||
return False
|
||||
|
||||
# document set sync
|
||||
doc_sets = fetch_document_sets_for_document(document_id, db_session)
|
||||
update_doc_sets: set[str] = set(doc_sets)
|
||||
|
||||
# User group sync
|
||||
doc_access = get_access_for_document(
|
||||
document_id=document_id, db_session=db_session
|
||||
)
|
||||
update_request = UpdateRequest(
|
||||
document_ids=[document_id],
|
||||
document_sets=update_doc_sets,
|
||||
access=doc_access,
|
||||
boost=doc.boost,
|
||||
hidden=doc.hidden,
|
||||
)
|
||||
|
||||
# update Vespa
|
||||
document_index.update(update_requests=[update_request])
|
||||
|
||||
# update db last. Worst case = we crash right before this and
|
||||
# the sync might repeat again later
|
||||
mark_document_as_synced(document_id, db_session)
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(f"SoftTimeLimitExceeded exception. doc_id={document_id}")
|
||||
except Exception as e:
|
||||
task_logger.exception("Unexpected exception")
|
||||
|
||||
# Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
|
||||
countdown = 2 ** (self.request.retries + 4)
|
||||
self.retry(exc=e, countdown=countdown)
|
||||
|
||||
return True
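# Quick check of the retry schedule used above: with max_retries=3 the countdown
# grows as 2 ** (retries + 4), i.e. 16s, 32s, 64s before the task gives up.
for retries in range(3):
    print(f"retry #{retries + 1} scheduled in {2 ** (retries + 4)}s")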
|
||||
@@ -10,22 +10,38 @@ are multiple connector / credential pairs that have indexed it
|
||||
connector / credential pair from the access list
|
||||
(6) delete all relevant entries from postgres
|
||||
"""
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from celery.utils.log import get_task_logger
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.access.access import get_access_for_document
|
||||
from danswer.access.access import get_access_for_documents
|
||||
from danswer.db.document import delete_document_by_connector_credential_pair__no_commit
|
||||
from danswer.db.document import delete_documents_by_connector_credential_pair__no_commit
|
||||
from danswer.db.document import delete_documents_complete__no_commit
|
||||
from danswer.db.document import get_document
|
||||
from danswer.db.document import get_document_connector_count
|
||||
from danswer.db.document import get_document_connector_counts
|
||||
from danswer.db.document import mark_document_as_synced
|
||||
from danswer.db.document import prepare_to_modify_documents
|
||||
from danswer.db.document_set import fetch_document_sets_for_document
|
||||
from danswer.db.document_set import fetch_document_sets_for_documents
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.document_index.document_index_utils import get_both_index_names
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.document_index.interfaces import DocumentIndex
|
||||
from danswer.document_index.interfaces import UpdateRequest
|
||||
from danswer.document_index.interfaces import VespaDocumentFields
|
||||
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# use this within celery tasks to get celery task specific logging
|
||||
task_logger = get_task_logger(__name__)
|
||||
|
||||
_DELETION_BATCH_SIZE = 1000
|
||||
|
||||
|
||||
@@ -108,3 +124,88 @@ def delete_connector_credential_pair_batch(
|
||||
),
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@shared_task(
|
||||
name="document_by_cc_pair_cleanup_task",
|
||||
bind=True,
|
||||
soft_time_limit=45,
|
||||
time_limit=60,
|
||||
max_retries=3,
|
||||
)
|
||||
def document_by_cc_pair_cleanup_task(
|
||||
self: Task, document_id: str, connector_id: int, credential_id: int
|
||||
) -> bool:
|
||||
task_logger.info(f"document_id={document_id}")
|
||||
|
||||
try:
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
|
||||
document_index = get_default_document_index(
|
||||
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
|
||||
)
|
||||
|
||||
count = get_document_connector_count(db_session, document_id)
|
||||
if count == 1:
|
||||
# count == 1 means this is the only remaining cc_pair reference to the doc
|
||||
# delete it from vespa and the db
|
||||
document_index.delete_single(doc_id=document_id)
|
||||
delete_documents_complete__no_commit(
|
||||
db_session=db_session,
|
||||
document_ids=[document_id],
|
||||
)
|
||||
elif count > 1:
|
||||
# count > 1 means the document still has cc_pair references
|
||||
doc = get_document(document_id, db_session)
|
||||
if not doc:
|
||||
return False
|
||||
|
||||
# the below functions do not include cc_pairs being deleted.
|
||||
# i.e. they will correctly omit access for the current cc_pair
|
||||
doc_access = get_access_for_document(
|
||||
document_id=document_id, db_session=db_session
|
||||
)
|
||||
|
||||
doc_sets = fetch_document_sets_for_document(document_id, db_session)
|
||||
update_doc_sets: set[str] = set(doc_sets)
|
||||
|
||||
fields = VespaDocumentFields(
|
||||
document_sets=update_doc_sets,
|
||||
access=doc_access,
|
||||
boost=doc.boost,
|
||||
hidden=doc.hidden,
|
||||
)
|
||||
|
||||
# update Vespa. OK if doc doesn't exist. Raises exception otherwise.
|
||||
document_index.update_single(document_id, fields=fields)
|
||||
|
||||
# there are still other cc_pair references to the doc, so just resync to Vespa
|
||||
delete_document_by_connector_credential_pair__no_commit(
|
||||
db_session=db_session,
|
||||
document_id=document_id,
|
||||
connector_credential_pair_identifier=ConnectorCredentialPairIdentifier(
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
),
|
||||
)
|
||||
|
||||
mark_document_as_synced(document_id, db_session)
|
||||
else:
|
||||
pass
|
||||
|
||||
# update_docs_last_modified__no_commit(
|
||||
# db_session=db_session,
|
||||
# document_ids=[document_id],
|
||||
# )
|
||||
|
||||
db_session.commit()
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(f"SoftTimeLimitExceeded exception. doc_id={document_id}")
|
||||
except Exception as e:
|
||||
task_logger.exception("Unexpected exception")
|
||||
|
||||
# Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
|
||||
countdown = 2 ** (self.request.retries + 4)
|
||||
self.retry(exc=e, countdown=countdown)
|
||||
|
||||
return True
|
||||
|
||||
@@ -14,6 +14,7 @@ from danswer.configs.app_configs import POLL_CONNECTOR_OFFSET
|
||||
from danswer.connectors.connector_runner import ConnectorRunner
|
||||
from danswer.connectors.factory import instantiate_connector
|
||||
from danswer.connectors.models import IndexAttemptMetadata
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.connector_credential_pair import get_last_successful_attempt_time
|
||||
from danswer.db.connector_credential_pair import update_connector_credential_pair
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
@@ -29,6 +30,7 @@ from danswer.db.models import IndexingStatus
|
||||
from danswer.db.models import IndexModelStatus
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.indexing.embedder import DefaultIndexingEmbedder
|
||||
from danswer.indexing.indexing_heartbeat import IndexingHeartbeat
|
||||
from danswer.indexing.indexing_pipeline import build_indexing_pipeline
|
||||
from danswer.utils.logger import IndexAttemptSingleton
|
||||
from danswer.utils.logger import setup_logger
|
||||
@@ -48,7 +50,7 @@ def _get_connector_runner(
|
||||
"""
|
||||
NOTE: `start_time` and `end_time` are only used for poll connectors
|
||||
|
||||
Returns an interator of document batches and whether the returned documents
|
||||
Returns an iterator of document batches and whether the returned documents
|
||||
are the complete list of existing documents of the connector. If the task
|
||||
of type LOAD_STATE, the list will be considered complete and otherwise incomplete.
|
||||
"""
|
||||
@@ -66,12 +68,17 @@ def _get_connector_runner(
|
||||
logger.exception(f"Unable to instantiate connector due to {e}")
|
||||
# since we failed to even instantiate the connector, we pause the CCPair since
|
||||
# it will never succeed
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=attempt.connector_credential_pair.connector.id,
|
||||
credential_id=attempt.connector_credential_pair.credential.id,
|
||||
status=ConnectorCredentialPairStatus.PAUSED,
|
||||
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
attempt.connector_credential_pair.id, db_session
|
||||
)
|
||||
if cc_pair and cc_pair.status == ConnectorCredentialPairStatus.ACTIVE:
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=attempt.connector_credential_pair.connector.id,
|
||||
credential_id=attempt.connector_credential_pair.credential.id,
|
||||
status=ConnectorCredentialPairStatus.PAUSED,
|
||||
)
|
||||
raise e
|
||||
|
||||
return ConnectorRunner(
|
||||
@@ -103,15 +110,24 @@ def _run_indexing(
|
||||
)
|
||||
|
||||
embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
|
||||
search_settings=search_settings
|
||||
search_settings=search_settings,
|
||||
heartbeat=IndexingHeartbeat(
|
||||
index_attempt_id=index_attempt.id,
|
||||
db_session=db_session,
|
||||
# let the world know we're still making progress after
|
||||
# every 10 batches
|
||||
freq=10,
|
||||
),
|
||||
)
|
||||
|
||||
indexing_pipeline = build_indexing_pipeline(
|
||||
attempt_id=index_attempt.id,
|
||||
embedder=embedding_model,
|
||||
document_index=document_index,
|
||||
ignore_time_skip=index_attempt.from_beginning
|
||||
or (search_settings.status == IndexModelStatus.FUTURE),
|
||||
ignore_time_skip=(
|
||||
index_attempt.from_beginning
|
||||
or (search_settings.status == IndexModelStatus.FUTURE)
|
||||
),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
|
||||
@@ -96,14 +96,20 @@ def _should_create_new_indexing(
|
||||
if last_index.status == IndexingStatus.IN_PROGRESS:
|
||||
return False
|
||||
else:
|
||||
if connector.id == 0: # Ingestion API
|
||||
if (
|
||||
connector.id == 0 or connector.source == DocumentSource.INGESTION_API
|
||||
): # Ingestion API
|
||||
return False
|
||||
return True
|
||||
|
||||
# If the connector is paused or is the ingestion API, don't index
|
||||
# NOTE: during an embedding model switch over, the following logic
|
||||
# is bypassed by the above check for a future model
|
||||
if not cc_pair.status.is_active() or connector.id == 0:
|
||||
if (
|
||||
not cc_pair.status.is_active()
|
||||
or connector.id == 0
|
||||
or connector.source == DocumentSource.INGESTION_API
|
||||
):
|
||||
return False
|
||||
|
||||
if not last_index:
|
||||
@@ -416,6 +422,7 @@ def update_loop(
|
||||
warm_up_bi_encoder(
|
||||
embedding_model=embedding_model,
|
||||
)
|
||||
logger.notice("First inference complete.")
|
||||
|
||||
client_primary: Client | SimpleJobClient
|
||||
client_secondary: Client | SimpleJobClient
|
||||
@@ -444,6 +451,7 @@ def update_loop(
|
||||
|
||||
existing_jobs: dict[int, Future | SimpleJob] = {}
|
||||
|
||||
logger.notice("Startup complete. Waiting for indexing jobs...")
|
||||
while True:
|
||||
start = time.time()
|
||||
start_time_utc = datetime.utcfromtimestamp(start).strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
@@ -135,7 +135,7 @@ POSTGRES_PASSWORD = urllib.parse.quote_plus(
|
||||
os.environ.get("POSTGRES_PASSWORD") or "password"
|
||||
)
|
||||
POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost"
|
||||
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432"
|
||||
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5433"
|
||||
POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"
|
||||
|
||||
# defaults to False
|
||||
@@ -164,13 +164,29 @@ REDIS_DB_NUMBER_CELERY_RESULT_BACKEND = int(
|
||||
)
|
||||
REDIS_DB_NUMBER_CELERY = int(os.environ.get("REDIS_DB_NUMBER_CELERY", 15)) # broker
|
||||
|
||||
# will propagate to both our redis client as well as celery's redis client
|
||||
REDIS_HEALTH_CHECK_INTERVAL = int(os.environ.get("REDIS_HEALTH_CHECK_INTERVAL", 60))
|
||||
|
||||
# our redis client only, not celery's
|
||||
REDIS_POOL_MAX_CONNECTIONS = int(os.environ.get("REDIS_POOL_MAX_CONNECTIONS", 128))
|
||||
|
||||
# https://docs.celeryq.dev/en/stable/userguide/configuration.html#redis-backend-settings
|
||||
# should be one of "required", "optional", or "none"
|
||||
REDIS_SSL_CERT_REQS = os.getenv("REDIS_SSL_CERT_REQS", "none")
|
||||
REDIS_SSL_CA_CERTS = os.getenv("REDIS_SSL_CA_CERTS", "")
|
||||
REDIS_SSL_CA_CERTS = os.getenv("REDIS_SSL_CA_CERTS", None)
|
||||
|
||||
CELERY_RESULT_EXPIRES = int(os.environ.get("CELERY_RESULT_EXPIRES", 86400)) # seconds
|
||||
|
||||
# https://docs.celeryq.dev/en/stable/userguide/configuration.html#broker-pool-limit
|
||||
# Setting to None may help when there is a proxy in the way closing idle connections
|
||||
CELERY_BROKER_POOL_LIMIT_DEFAULT = 10
|
||||
try:
|
||||
CELERY_BROKER_POOL_LIMIT = int(
|
||||
os.environ.get("CELERY_BROKER_POOL_LIMIT", CELERY_BROKER_POOL_LIMIT_DEFAULT)
|
||||
)
|
||||
except ValueError:
|
||||
CELERY_BROKER_POOL_LIMIT = CELERY_BROKER_POOL_LIMIT_DEFAULT
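# Sketch of how a value like this is typically wired into Celery (the actual app
# setup is not part of this diff); broker_pool_limit is a standard Celery
# setting, and None disables connection pooling entirely.
from celery import Celery

example_app = Celery("danswer_example")
example_app.conf.broker_pool_limit = CELERY_BROKER_POOL_LIMIT  # or None behind some proxies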
|
||||
|
||||
#####
|
||||
# Connector Configs
|
||||
#####
|
||||
@@ -247,6 +263,10 @@ JIRA_CONNECTOR_LABELS_TO_SKIP = [
|
||||
for ignored_tag in os.environ.get("JIRA_CONNECTOR_LABELS_TO_SKIP", "").split(",")
|
||||
if ignored_tag
|
||||
]
|
||||
# Maximum size for Jira tickets in bytes (default: 100KB)
|
||||
JIRA_CONNECTOR_MAX_TICKET_SIZE = int(
|
||||
os.environ.get("JIRA_CONNECTOR_MAX_TICKET_SIZE", 100 * 1024)
|
||||
)
|
||||
|
||||
GONG_CONNECTOR_START_TIME = os.environ.get("GONG_CONNECTOR_START_TIME")
|
||||
|
||||
@@ -270,7 +290,7 @@ ALLOW_SIMULTANEOUS_PRUNING = (
|
||||
os.environ.get("ALLOW_SIMULTANEOUS_PRUNING", "").lower() == "true"
|
||||
)
|
||||
|
||||
# This is the maxiumum rate at which documents are queried for a pruning job. 0 disables the limitation.
|
||||
# This is the maximum rate at which documents are queried for a pruning job. 0 disables the limitation.
|
||||
MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE = int(
|
||||
os.environ.get("MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE", 0)
|
||||
)
|
||||
@@ -334,12 +354,10 @@ INDEXING_TRACER_INTERVAL = int(os.environ.get("INDEXING_TRACER_INTERVAL", 0))
|
||||
# exception without aborting the attempt.
|
||||
INDEXING_EXCEPTION_LIMIT = int(os.environ.get("INDEXING_EXCEPTION_LIMIT", 0))
|
||||
|
||||
|
||||
#####
|
||||
# Miscellaneous
|
||||
#####
|
||||
# File based Key Value store no longer used
|
||||
DYNAMIC_CONFIG_STORE = "PostgresBackedDynamicConfigStore"
|
||||
|
||||
JOB_TIMEOUT = 60 * 60 * 6 # 6 hours default
|
||||
# used to allow the background indexing jobs to use a different embedding
|
||||
# model server than the API server
|
||||
@@ -388,3 +406,7 @@ CUSTOM_ANSWER_VALIDITY_CONDITIONS = json.loads(
|
||||
ENTERPRISE_EDITION_ENABLED = (
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() == "true"
|
||||
)
|
||||
|
||||
|
||||
MULTI_TENANT = os.environ.get("MULTI_TENANT", "").lower() == "true"
|
||||
SECRET_JWT_KEY = os.environ.get("SECRET_JWT_KEY", "")
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import platform
|
||||
import socket
|
||||
from enum import auto
|
||||
from enum import Enum
|
||||
|
||||
@@ -34,9 +36,12 @@ POSTGRES_WEB_APP_NAME = "web"
|
||||
POSTGRES_INDEXER_APP_NAME = "indexer"
|
||||
POSTGRES_CELERY_APP_NAME = "celery"
|
||||
POSTGRES_CELERY_BEAT_APP_NAME = "celery_beat"
|
||||
POSTGRES_CELERY_WORKER_APP_NAME = "celery_worker"
|
||||
POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME = "celery_worker_primary"
|
||||
POSTGRES_CELERY_WORKER_LIGHT_APP_NAME = "celery_worker_light"
|
||||
POSTGRES_CELERY_WORKER_HEAVY_APP_NAME = "celery_worker_heavy"
|
||||
POSTGRES_PERMISSIONS_APP_NAME = "permissions"
|
||||
POSTGRES_UNKNOWN_APP_NAME = "unknown"
|
||||
POSTGRES_DEFAULT_SCHEMA = "public"
|
||||
|
||||
# API Keys
|
||||
DANSWER_API_KEY_PREFIX = "API_KEY__"
|
||||
@@ -46,6 +51,7 @@ UNNAMED_KEY_PLACEHOLDER = "Unnamed"
|
||||
# Key-Value store keys
|
||||
KV_REINDEX_KEY = "needs_reindexing"
|
||||
KV_SEARCH_SETTINGS = "search_settings"
|
||||
KV_UNSTRUCTURED_API_KEY = "unstructured_api_key"
|
||||
KV_USER_STORE_KEY = "INVITED_USERS"
|
||||
KV_NO_AUTH_USER_PREFERENCES_KEY = "no_auth_user_preferences"
|
||||
KV_CRED_KEY = "credential_id_{}"
|
||||
@@ -62,11 +68,13 @@ KV_ENTERPRISE_SETTINGS_KEY = "danswer_enterprise_settings"
|
||||
KV_CUSTOM_ANALYTICS_SCRIPT_KEY = "__custom_analytics_script__"
|
||||
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT = 60
|
||||
CELERY_PRIMARY_WORKER_LOCK_TIMEOUT = 120
|
||||
|
||||
|
||||
class DocumentSource(str, Enum):
|
||||
# Special case, document passed in via Danswer APIs without specifying a source type
|
||||
INGESTION_API = "ingestion_api"
|
||||
FRESHDESK = "freshdesk"
|
||||
SLACK = "slack"
|
||||
WEB = "web"
|
||||
GOOGLE_DRIVE = "google_drive"
|
||||
@@ -104,6 +112,7 @@ class DocumentSource(str, Enum):
|
||||
R2 = "r2"
|
||||
GOOGLE_CLOUD_STORAGE = "google_cloud_storage"
|
||||
OCI_STORAGE = "oci_storage"
|
||||
XENFORO = "xenforo"
|
||||
NOT_APPLICABLE = "not_applicable"
|
||||
|
||||
|
||||
@@ -186,6 +195,7 @@ class DanswerCeleryQueues:
|
||||
|
||||
|
||||
class DanswerRedisLocks:
|
||||
PRIMARY_WORKER = "da_lock:primary_worker"
|
||||
CHECK_VESPA_SYNC_BEAT_LOCK = "da_lock:check_vespa_sync_beat"
|
||||
MONITOR_VESPA_SYNC_BEAT_LOCK = "da_lock:monitor_vespa_sync_beat"
|
||||
CHECK_CONNECTOR_DELETION_BEAT_LOCK = "da_lock:check_connector_deletion_beat"
|
||||
@@ -198,3 +208,13 @@ class DanswerCeleryPriority(int, Enum):
|
||||
MEDIUM = auto()
|
||||
LOW = auto()
|
||||
LOWEST = auto()
|
||||
|
||||
|
||||
REDIS_SOCKET_KEEPALIVE_OPTIONS = {}
|
||||
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPINTVL] = 15
|
||||
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPCNT] = 3
|
||||
|
||||
if platform.system() == "Darwin":
|
||||
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPALIVE] = 60 # type: ignore
|
||||
else:
|
||||
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPIDLE] = 60 # type: ignore
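# Sketch of how these keepalive options are usually handed to redis-py
# (presumably done by the client factory in danswer.redis.redis_pool; the
# host/port below are placeholders).
import redis

def make_keepalive_client(host: str = "localhost", port: int = 6379) -> redis.Redis:
    return redis.Redis(
        host=host,
        port=port,
        socket_keepalive=True,
        socket_keepalive_options=REDIS_SOCKET_KEEPALIVE_OPTIONS,
        health_check_interval=60,
    )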
|
||||
|
||||
@@ -194,8 +194,8 @@ class BlobStorageConnector(LoadConnector, PollConnector):
|
||||
|
||||
try:
|
||||
text = extract_file_text(
|
||||
name,
|
||||
BytesIO(downloaded_file),
|
||||
file_name=name,
|
||||
break_on_unprocessable=False,
|
||||
)
|
||||
batch.append(
|
||||
|
||||
backend/danswer/connectors/confluence/confluence_utils.py (new file, 32 lines)
@@ -0,0 +1,32 @@
|
||||
import bs4
|
||||
|
||||
|
||||
def build_confluence_document_id(base_url: str, content_url: str) -> str:
|
||||
"""For confluence, the document id is the page url for a page based document
|
||||
or the attachment download url for an attachment based document
|
||||
|
||||
Args:
|
||||
base_url (str): The base url of the Confluence instance
|
||||
content_url (str): The url of the page or attachment download url
|
||||
|
||||
Returns:
|
||||
str: The document id
|
||||
"""
|
||||
return f"{base_url}{content_url}"
|
||||
|
||||
|
||||
def get_used_attachments(text: str) -> list[str]:
|
||||
"""Parse a Confluence html page to generate a list of current
|
||||
attachment in used
|
||||
|
||||
Args:
|
||||
text (str): The page content
|
||||
|
||||
Returns:
|
||||
list[str]: List of filenames currently in use by the page text
|
||||
"""
|
||||
files_in_used = []
|
||||
soup = bs4.BeautifulSoup(text, "html.parser")
|
||||
for attachment in soup.findAll("ri:attachment"):
|
||||
files_in_used.append(attachment.attrs["ri:filename"])
|
||||
return files_in_used
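# Tiny usage example for the two helpers above (URL and filename are made up):
# a document id is simply base_url + content_url, and attachments are discovered
# from <ri:attachment> tags in the storage-format HTML.
page_id = build_confluence_document_id(
    "https://wiki.example.com", "/pages/viewpage.action?pageId=123"
)
attachments = get_used_attachments(
    '<p>See <ri:attachment ri:filename="spec.pdf" /></p>'
)
# page_id -> "https://wiki.example.com/pages/viewpage.action?pageId=123"
# attachments -> ["spec.pdf"]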
|
||||
@@ -22,6 +22,10 @@ from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING
|
||||
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.confluence.confluence_utils import (
|
||||
build_confluence_document_id,
|
||||
)
|
||||
from danswer.connectors.confluence.confluence_utils import get_used_attachments
|
||||
from danswer.connectors.confluence.rate_limit_handler import (
|
||||
make_confluence_call_handle_rate_limit,
|
||||
)
|
||||
@@ -105,24 +109,6 @@ def parse_html_page(text: str, confluence_client: Confluence) -> str:
|
||||
return format_document_soup(soup)
|
||||
|
||||
|
||||
def get_used_attachments(text: str, confluence_client: Confluence) -> list[str]:
|
||||
"""Parse a Confluence html page to generate a list of current
|
||||
attachment in used
|
||||
|
||||
Args:
|
||||
text (str): The page content
|
||||
confluence_client (Confluence): Confluence client
|
||||
|
||||
Returns:
|
||||
list[str]: List of filename currently in used
|
||||
"""
|
||||
files_in_used = []
|
||||
soup = bs4.BeautifulSoup(text, "html.parser")
|
||||
for attachment in soup.findAll("ri:attachment"):
|
||||
files_in_used.append(attachment.attrs["ri:filename"])
|
||||
return files_in_used
|
||||
|
||||
|
||||
def _comment_dfs(
|
||||
comments_str: str,
|
||||
comment_pages: Collection[dict[str, Any]],
|
||||
@@ -533,7 +519,9 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
return None
|
||||
|
||||
extracted_text = extract_file_text(
|
||||
attachment["title"], io.BytesIO(response.content), False
|
||||
io.BytesIO(response.content),
|
||||
file_name=attachment["title"],
|
||||
break_on_unprocessable=False,
|
||||
)
|
||||
if len(extracted_text) > CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD:
|
||||
logger.warning(
|
||||
@@ -624,19 +612,22 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
page_html = (
|
||||
page["body"].get("storage", page["body"].get("view", {})).get("value")
|
||||
)
|
||||
page_url = self.wiki_base + page["_links"]["webui"]
|
||||
# The url and the id are the same
|
||||
page_url = build_confluence_document_id(
|
||||
self.wiki_base, page["_links"]["webui"]
|
||||
)
|
||||
if not page_html:
|
||||
logger.debug("Page is empty, skipping: %s", page_url)
|
||||
continue
|
||||
page_text = parse_html_page(page_html, self.confluence_client)
|
||||
|
||||
files_in_used = get_used_attachments(page_html, self.confluence_client)
|
||||
files_in_used = get_used_attachments(page_html)
|
||||
attachment_text, unused_page_attachments = self._fetch_attachments(
|
||||
self.confluence_client, page_id, files_in_used
|
||||
)
|
||||
unused_attachments.extend(unused_page_attachments)
|
||||
|
||||
page_text += attachment_text
|
||||
page_text += "\n" + attachment_text if attachment_text else ""
|
||||
comments_text = self._fetch_comments(self.confluence_client, page_id)
|
||||
page_text += comments_text
|
||||
doc_metadata: dict[str, str | list[str]] = {"Wiki Space Name": self.space}
|
||||
@@ -683,8 +674,9 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
if time_filter and not time_filter(last_updated):
|
||||
continue
|
||||
|
||||
attachment_url = self._attachment_to_download_link(
|
||||
self.confluence_client, attachment
|
||||
# The url and the id are the same
|
||||
attachment_url = build_confluence_document_id(
|
||||
self.wiki_base, attachment["_links"]["download"]
|
||||
)
|
||||
attachment_content = self._attachment_to_content(
|
||||
self.confluence_client, attachment
|
||||
|
||||
@@ -50,6 +50,12 @@ def make_confluence_call_handle_rate_limit(confluence_call: F) -> F:
|
||||
pass
|
||||
|
||||
if retry_after is not None:
|
||||
if retry_after > 600:
|
||||
logger.warning(
|
||||
f"Clamping retry_after from {retry_after} to {max_delay} seconds..."
|
||||
)
|
||||
retry_after = max_delay
|
||||
|
||||
logger.warning(
|
||||
f"Rate limit hit. Retrying after {retry_after} seconds..."
|
||||
)
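# The clamp above, in isolation (assuming max_delay is the 600-second cap
# implied by the comparison); min() is an equivalent one-liner.
def clamp_retry_after(retry_after: int, max_delay: int = 600) -> int:
    return min(retry_after, max_delay)

# clamp_retry_after(1800) -> 600, clamp_retry_after(30) -> 30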
|
||||
|
||||
@@ -9,6 +9,7 @@ from jira.resources import Issue
|
||||
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.app_configs import JIRA_CONNECTOR_LABELS_TO_SKIP
|
||||
from danswer.configs.app_configs import JIRA_CONNECTOR_MAX_TICKET_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
@@ -134,10 +135,18 @@ def fetch_jira_issues_batch(
|
||||
else extract_text_from_adf(jira.raw["fields"]["description"])
|
||||
)
|
||||
comments = _get_comment_strs(jira, comment_email_blacklist)
|
||||
semantic_rep = f"{description}\n" + "\n".join(
|
||||
ticket_content = f"{description}\n" + "\n".join(
|
||||
[f"Comment: {comment}" for comment in comments if comment]
|
||||
)
|
||||
|
||||
# Check ticket size
|
||||
if len(ticket_content.encode("utf-8")) > JIRA_CONNECTOR_MAX_TICKET_SIZE:
|
||||
logger.info(
|
||||
f"Skipping {jira.key} because it exceeds the maximum size of "
|
||||
f"{JIRA_CONNECTOR_MAX_TICKET_SIZE} bytes."
|
||||
)
|
||||
continue
|
||||
|
||||
page_url = f"{jira_client.client_info()}/browse/{jira.key}"
|
||||
|
||||
people = set()
|
||||
@@ -180,7 +189,7 @@ def fetch_jira_issues_batch(
|
||||
doc_batch.append(
|
||||
Document(
|
||||
id=page_url,
|
||||
sections=[Section(link=page_url, text=semantic_rep)],
|
||||
sections=[Section(link=page_url, text=ticket_content)],
|
||||
source=DocumentSource.JIRA,
|
||||
semantic_identifier=jira.fields.summary,
|
||||
doc_updated_at=time_str_to_utc(jira.fields.updated),
|
||||
@@ -236,10 +245,12 @@ class JiraConnector(LoadConnector, PollConnector):
|
||||
if self.jira_client is None:
|
||||
raise ConnectorMissingCredentialError("Jira")
|
||||
|
||||
# Quote the project name to handle reserved words
|
||||
quoted_project = f'"{self.jira_project}"'
|
||||
start_ind = 0
|
||||
while True:
|
||||
doc_batch, fetched_batch_size = fetch_jira_issues_batch(
|
||||
jql=f"project = {self.jira_project}",
|
||||
jql=f"project = {quoted_project}",
|
||||
start_index=start_ind,
|
||||
jira_client=self.jira_client,
|
||||
batch_size=self.batch_size,
|
||||
@@ -267,8 +278,10 @@ class JiraConnector(LoadConnector, PollConnector):
|
||||
"%Y-%m-%d %H:%M"
|
||||
)
|
||||
|
||||
# Quote the project name to handle reserved words
|
||||
quoted_project = f'"{self.jira_project}"'
|
||||
jql = (
|
||||
f"project = {self.jira_project} AND "
|
||||
f"project = {quoted_project} AND "
|
||||
f"updated >= '{start_date_str}' AND "
|
||||
f"updated <= '{end_date_str}'"
|
||||
)
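# Why the project name is quoted: an unquoted name that collides with a JQL
# reserved word or contains spaces breaks the query. The project name below is
# made up for illustration.
project = "Customer Order"
unquoted_jql = f"project = {project}"   # invalid JQL: bare space in the name
quoted_jql = f'project = "{project}"'   # 'project = "Customer Order"'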
|
||||
|
||||
@@ -97,8 +97,8 @@ class DropboxConnector(LoadConnector, PollConnector):
|
||||
link = self._get_shared_link(entry.path_display)
|
||||
try:
|
||||
text = extract_file_text(
|
||||
entry.name,
|
||||
BytesIO(downloaded_file),
|
||||
file_name=entry.name,
|
||||
break_on_unprocessable=False,
|
||||
)
|
||||
batch.append(
|
||||
|
||||
@@ -15,6 +15,7 @@ from danswer.connectors.discourse.connector import DiscourseConnector
|
||||
from danswer.connectors.document360.connector import Document360Connector
|
||||
from danswer.connectors.dropbox.connector import DropboxConnector
|
||||
from danswer.connectors.file.connector import LocalFileConnector
|
||||
from danswer.connectors.freshdesk.connector import FreshdeskConnector
|
||||
from danswer.connectors.github.connector import GithubConnector
|
||||
from danswer.connectors.gitlab.connector import GitlabConnector
|
||||
from danswer.connectors.gmail.connector import GmailConnector
|
||||
@@ -42,6 +43,7 @@ from danswer.connectors.slack.load_connector import SlackLoadConnector
|
||||
from danswer.connectors.teams.connector import TeamsConnector
|
||||
from danswer.connectors.web.connector import WebConnector
|
||||
from danswer.connectors.wikipedia.connector import WikipediaConnector
|
||||
from danswer.connectors.xenforo.connector import XenforoConnector
|
||||
from danswer.connectors.zendesk.connector import ZendeskConnector
|
||||
from danswer.connectors.zulip.connector import ZulipConnector
|
||||
from danswer.db.credentials import backend_update_credential_json
|
||||
@@ -57,11 +59,13 @@ def identify_connector_class(
|
||||
input_type: InputType | None = None,
|
||||
) -> Type[BaseConnector]:
|
||||
connector_map = {
|
||||
DocumentSource.FRESHDESK: FreshdeskConnector,
|
||||
DocumentSource.WEB: WebConnector,
|
||||
DocumentSource.FILE: LocalFileConnector,
|
||||
DocumentSource.SLACK: {
|
||||
InputType.LOAD_STATE: SlackLoadConnector,
|
||||
InputType.POLL: SlackPollConnector,
|
||||
InputType.PRUNE: SlackPollConnector,
|
||||
},
|
||||
DocumentSource.GITHUB: GithubConnector,
|
||||
DocumentSource.GMAIL: GmailConnector,
|
||||
@@ -97,6 +101,7 @@ def identify_connector_class(
|
||||
DocumentSource.R2: BlobStorageConnector,
|
||||
DocumentSource.GOOGLE_CLOUD_STORAGE: BlobStorageConnector,
|
||||
DocumentSource.OCI_STORAGE: BlobStorageConnector,
|
||||
DocumentSource.XENFORO: XenforoConnector,
|
||||
}
|
||||
connector_by_source = connector_map.get(source, {})
|
||||
|
||||
|
||||
@@ -74,13 +74,14 @@ def _process_file(
|
||||
)
|
||||
|
||||
# Using the PDF reader function directly to pass in password cleanly
|
||||
elif extension == ".pdf":
|
||||
elif extension == ".pdf" and pdf_pass is not None:
|
||||
file_content_raw, file_metadata = read_pdf_file(file=file, pdf_pass=pdf_pass)
|
||||
|
||||
else:
|
||||
file_content_raw = extract_file_text(
|
||||
file_name=file_name,
|
||||
file=file,
|
||||
file_name=file_name,
|
||||
break_on_unprocessable=True,
|
||||
)
|
||||
|
||||
all_metadata = {**metadata, **file_metadata} if metadata else file_metadata
|
||||
|
||||
backend/danswer/connectors/freshdesk/connector.py (new file, 160 lines)
@@ -0,0 +1,160 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
from bs4 import BeautifulSoup  # used to strip HTML from ticket descriptions
|
||||
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from danswer.connectors.models import ConnectorMissingCredentialError
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class FreshdeskConnector(PollConnector):
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None = None,
|
||||
domain: str | None = None,
|
||||
password: str | None = None,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
) -> None:
|
||||
self.api_key = api_key
|
||||
self.domain = domain
|
||||
self.password = password
|
||||
self.batch_size = batch_size
|
||||
|
||||
def ticket_link(self, tid: int) -> str:
|
||||
return f"https://{self.domain}.freshdesk.com/helpdesk/tickets/{tid}"
|
||||
|
||||
def build_doc_sections_from_ticket(self, ticket: dict) -> List[Section]:
|
||||
# Serialize the ticket's string-valued fields into a single JSON section
|
||||
return [
|
||||
Section(
|
||||
link=self.ticket_link(int(ticket["id"])),
|
||||
text=json.dumps(
|
||||
{
|
||||
key: value
|
||||
for key, value in ticket.items()
|
||||
if isinstance(value, str)
|
||||
},
|
||||
default=str,
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
def strip_html_tags(self, html: str) -> str:
|
||||
soup = BeautifulSoup(html, "html.parser")
|
||||
return soup.get_text()
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> Optional[dict[str, Any]]:
|
||||
logger.info("Loading credentials")
|
||||
self.api_key = credentials.get("freshdesk_api_key")
|
||||
self.domain = credentials.get("freshdesk_domain")
|
||||
self.password = credentials.get("freshdesk_password")
|
||||
return None
|
||||
|
||||
def _process_tickets(
|
||||
self, start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None
|
||||
) -> GenerateDocumentsOutput:
|
||||
assert self.api_key is not None
|
||||
assert self.domain is not None
|
||||
assert self.password is not None
|
||||
|
||||
logger.info("Processing tickets")
|
||||
if any(cred is None for cred in (self.api_key, self.domain, self.password)):
|
||||
raise ConnectorMissingCredentialError("freshdesk")
|
||||
|
||||
freshdesk_url = (
|
||||
f"https://{self.domain}.freshdesk.com/api/v2/tickets?include=description"
|
||||
)
|
||||
response = requests.get(freshdesk_url, auth=(self.api_key, self.password))
|
||||
response.raise_for_status() # raises exception when not a 2xx response
|
||||
|
||||
if response.status_code != 204:
|
||||
tickets = json.loads(response.content)
|
||||
logger.info(f"Fetched {len(tickets)} tickets from Freshdesk API")
|
||||
doc_batch: List[Document] = []
|
||||
|
||||
for ticket in tickets:
|
||||
# Convert the "created_at", "updated_at", and "due_by" values to ISO 8601 strings
|
||||
for date_field in ["created_at", "updated_at", "due_by"]:
|
||||
ticket[date_field] = datetime.fromisoformat(
|
||||
ticket[date_field]
|
||||
).strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
# Convert all other values to strings
|
||||
ticket = {
|
||||
key: str(value) if not isinstance(value, str) else value
|
||||
for key, value in ticket.items()
|
||||
}
|
||||
|
||||
# Checking for overdue tickets
|
||||
today = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
ticket["overdue"] = "true" if today > ticket["due_by"] else "false"
|
||||
|
||||
# Mapping the status field values
|
||||
status_mapping = {2: "open", 3: "pending", 4: "resolved", 5: "closed"}
|
||||
ticket["status"] = status_mapping.get(
|
||||
ticket["status"], str(ticket["status"])
|
||||
)
|
||||
|
||||
# Stripping HTML tags from the description field
|
||||
ticket["description"] = self.strip_html_tags(ticket["description"])
|
||||
|
||||
# Remove extra white spaces from the description field
|
||||
ticket["description"] = " ".join(ticket["description"].split())
|
||||
|
||||
# Build the document sections from the ticket fields
|
||||
sections = self.build_doc_sections_from_ticket(ticket)
|
||||
|
||||
created_at = datetime.fromisoformat(ticket["created_at"])
|
||||
|
||||
doc = Document(
|
||||
id=ticket["id"],
|
||||
sections=sections,
|
||||
source=DocumentSource.FRESHDESK,
|
||||
semantic_identifier=ticket["subject"],
|
||||
metadata={
|
||||
key: value
|
||||
for key, value in ticket.items()
|
||||
if isinstance(value, str)
|
||||
and key not in ["description", "description_text"]
|
||||
},
|
||||
doc_updated_at=created_at,
|
||||
)
|
||||
|
||||
doc_batch.append(doc)
|
||||
|
||||
if len(doc_batch) >= self.batch_size:
|
||||
yield doc_batch
|
||||
doc_batch = []
|
||||
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None
|
||||
) -> GenerateDocumentsOutput:
|
||||
yield from self._process_tickets(start, end)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
|
||||
connector = FreshdeskConnector(
|
||||
api_key=os.environ.get("FRESHDESK_API_KEY"),
|
||||
domain=os.environ.get("FRESHDESK_DOMAIN"),
|
||||
password=os.environ.get("FRESHDESK_PASSWORD"),
|
||||
)
|
||||
for doc in connector.poll_source(start=None, end=None):
|
||||
print(doc)
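# The connector can also be driven through load_credentials(); the keys below
# mirror the ones read in load_credentials above, and the values are
# placeholders. poll_source yields batches (lists) of Documents.
connector = FreshdeskConnector()
connector.load_credentials(
    {
        "freshdesk_api_key": "YOUR_API_KEY",
        "freshdesk_domain": "yourcompany",  # -> https://yourcompany.freshdesk.com
        "freshdesk_password": "YOUR_PASSWORD",
    }
)
for batch in connector.poll_source(start=None, end=None):
    print(len(batch), "tickets in batch")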
|
||||
@@ -25,7 +25,7 @@ from danswer.connectors.gmail.constants import (
|
||||
from danswer.connectors.gmail.constants import SCOPES
|
||||
from danswer.db.credentials import update_credential_json
|
||||
from danswer.db.models import User
|
||||
from danswer.dynamic_configs.factory import get_dynamic_config_store
|
||||
from danswer.key_value_store.factory import get_kv_store
|
||||
from danswer.server.documents.models import CredentialBase
|
||||
from danswer.server.documents.models import GoogleAppCredentials
|
||||
from danswer.server.documents.models import GoogleServiceAccountKey
|
||||
@@ -72,7 +72,7 @@ def get_gmail_creds_for_service_account(
|
||||
|
||||
|
||||
def verify_csrf(credential_id: int, state: str) -> None:
|
||||
csrf = get_dynamic_config_store().load(KV_CRED_KEY.format(str(credential_id)))
|
||||
csrf = get_kv_store().load(KV_CRED_KEY.format(str(credential_id)))
|
||||
if csrf != state:
|
||||
raise PermissionError(
|
||||
"State from Gmail Connector callback does not match expected"
|
||||
@@ -80,7 +80,7 @@ def verify_csrf(credential_id: int, state: str) -> None:
|
||||
|
||||
|
||||
def get_gmail_auth_url(credential_id: int) -> str:
|
||||
creds_str = str(get_dynamic_config_store().load(KV_GMAIL_CRED_KEY))
|
||||
creds_str = str(get_kv_store().load(KV_GMAIL_CRED_KEY))
|
||||
credential_json = json.loads(creds_str)
|
||||
flow = InstalledAppFlow.from_client_config(
|
||||
credential_json,
|
||||
@@ -92,14 +92,14 @@ def get_gmail_auth_url(credential_id: int) -> str:
|
||||
parsed_url = cast(ParseResult, urlparse(auth_url))
|
||||
params = parse_qs(parsed_url.query)
|
||||
|
||||
get_dynamic_config_store().store(
|
||||
get_kv_store().store(
|
||||
KV_CRED_KEY.format(credential_id), params.get("state", [None])[0], encrypt=True
|
||||
) # type: ignore
|
||||
return str(auth_url)
|
||||
|
||||
|
||||
def get_auth_url(credential_id: int) -> str:
|
||||
creds_str = str(get_dynamic_config_store().load(KV_GMAIL_CRED_KEY))
|
||||
creds_str = str(get_kv_store().load(KV_GMAIL_CRED_KEY))
|
||||
credential_json = json.loads(creds_str)
|
||||
flow = InstalledAppFlow.from_client_config(
|
||||
credential_json,
|
||||
@@ -111,7 +111,7 @@ def get_auth_url(credential_id: int) -> str:
|
||||
parsed_url = cast(ParseResult, urlparse(auth_url))
|
||||
params = parse_qs(parsed_url.query)
|
||||
|
||||
get_dynamic_config_store().store(
|
||||
get_kv_store().store(
|
||||
KV_CRED_KEY.format(credential_id), params.get("state", [None])[0], encrypt=True
|
||||
) # type: ignore
|
||||
return str(auth_url)
|
||||
@@ -158,42 +158,40 @@ def build_service_account_creds(
|
||||
|
||||
|
||||
def get_google_app_gmail_cred() -> GoogleAppCredentials:
|
||||
creds_str = str(get_dynamic_config_store().load(KV_GMAIL_CRED_KEY))
|
||||
creds_str = str(get_kv_store().load(KV_GMAIL_CRED_KEY))
|
||||
return GoogleAppCredentials(**json.loads(creds_str))
|
||||
|
||||
|
||||
def upsert_google_app_gmail_cred(app_credentials: GoogleAppCredentials) -> None:
|
||||
get_dynamic_config_store().store(
|
||||
KV_GMAIL_CRED_KEY, app_credentials.json(), encrypt=True
|
||||
)
|
||||
get_kv_store().store(KV_GMAIL_CRED_KEY, app_credentials.json(), encrypt=True)
|
||||
|
||||
|
||||
def delete_google_app_gmail_cred() -> None:
|
||||
get_dynamic_config_store().delete(KV_GMAIL_CRED_KEY)
|
||||
get_kv_store().delete(KV_GMAIL_CRED_KEY)
|
||||
|
||||
|
||||
def get_gmail_service_account_key() -> GoogleServiceAccountKey:
|
||||
creds_str = str(get_dynamic_config_store().load(KV_GMAIL_SERVICE_ACCOUNT_KEY))
|
||||
creds_str = str(get_kv_store().load(KV_GMAIL_SERVICE_ACCOUNT_KEY))
|
||||
return GoogleServiceAccountKey(**json.loads(creds_str))
|
||||
|
||||
|
||||
def upsert_gmail_service_account_key(
|
||||
service_account_key: GoogleServiceAccountKey,
|
||||
) -> None:
|
||||
get_dynamic_config_store().store(
|
||||
get_kv_store().store(
|
||||
KV_GMAIL_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True
|
||||
)
|
||||
|
||||
|
||||
def upsert_service_account_key(service_account_key: GoogleServiceAccountKey) -> None:
|
||||
get_dynamic_config_store().store(
|
||||
get_kv_store().store(
|
||||
KV_GMAIL_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True
|
||||
)
|
||||
|
||||
|
||||
def delete_gmail_service_account_key() -> None:
|
||||
get_dynamic_config_store().delete(KV_GMAIL_SERVICE_ACCOUNT_KEY)
|
||||
get_kv_store().delete(KV_GMAIL_SERVICE_ACCOUNT_KEY)
|
||||
|
||||
|
||||
def delete_service_account_key() -> None:
|
||||
get_dynamic_config_store().delete(KV_GMAIL_SERVICE_ACCOUNT_KEY)
|
||||
get_kv_store().delete(KV_GMAIL_SERVICE_ACCOUNT_KEY)
|
||||
|
||||
@@ -36,6 +36,8 @@ from danswer.connectors.models import Section
|
||||
from danswer.file_processing.extract_file_text import docx_to_text
|
||||
from danswer.file_processing.extract_file_text import pptx_to_text
|
||||
from danswer.file_processing.extract_file_text import read_pdf_file
|
||||
from danswer.file_processing.unstructured import get_unstructured_api_key
|
||||
from danswer.file_processing.unstructured import unstructured_to_text
|
||||
from danswer.utils.batching import batch_generator
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
@@ -327,16 +329,24 @@ def extract_text(file: dict[str, str], service: discovery.Resource) -> str:
|
||||
GDriveMimeType.MARKDOWN.value,
|
||||
]:
|
||||
return service.files().get_media(fileId=file["id"]).execute().decode("utf-8")
|
||||
elif mime_type == GDriveMimeType.WORD_DOC.value:
|
||||
if mime_type in [
|
||||
GDriveMimeType.WORD_DOC.value,
|
||||
GDriveMimeType.POWERPOINT.value,
|
||||
GDriveMimeType.PDF.value,
|
||||
]:
|
||||
response = service.files().get_media(fileId=file["id"]).execute()
|
||||
return docx_to_text(file=io.BytesIO(response))
|
||||
elif mime_type == GDriveMimeType.PDF.value:
|
||||
response = service.files().get_media(fileId=file["id"]).execute()
|
||||
text, _ = read_pdf_file(file=io.BytesIO(response))
|
||||
return text
|
||||
elif mime_type == GDriveMimeType.POWERPOINT.value:
|
||||
response = service.files().get_media(fileId=file["id"]).execute()
|
||||
return pptx_to_text(file=io.BytesIO(response))
|
||||
if get_unstructured_api_key():
|
||||
return unstructured_to_text(
|
||||
file=io.BytesIO(response), file_name=file.get("name", file["id"])
|
||||
)
|
||||
|
||||
if mime_type == GDriveMimeType.WORD_DOC.value:
|
||||
return docx_to_text(file=io.BytesIO(response))
|
||||
elif mime_type == GDriveMimeType.PDF.value:
|
||||
text, _ = read_pdf_file(file=io.BytesIO(response))
|
||||
return text
|
||||
elif mime_type == GDriveMimeType.POWERPOINT.value:
|
||||
return pptx_to_text(file=io.BytesIO(response))
|
||||
|
||||
return UNSUPPORTED_FILE_TYPE_CONTENT
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ from danswer.connectors.google_drive.constants import FETCH_GROUPS_SCOPES
|
||||
from danswer.connectors.google_drive.constants import FETCH_PERMISSIONS_SCOPES
|
||||
from danswer.db.credentials import update_credential_json
|
||||
from danswer.db.models import User
|
||||
from danswer.dynamic_configs.factory import get_dynamic_config_store
|
||||
from danswer.key_value_store.factory import get_kv_store
|
||||
from danswer.server.documents.models import CredentialBase
|
||||
from danswer.server.documents.models import GoogleAppCredentials
|
||||
from danswer.server.documents.models import GoogleServiceAccountKey
|
||||
@@ -134,7 +134,7 @@ def get_google_drive_creds(
|
||||
|
||||
|
||||
def verify_csrf(credential_id: int, state: str) -> None:
|
||||
csrf = get_dynamic_config_store().load(KV_CRED_KEY.format(str(credential_id)))
|
||||
csrf = get_kv_store().load(KV_CRED_KEY.format(str(credential_id)))
|
||||
if csrf != state:
|
||||
raise PermissionError(
|
||||
"State from Google Drive Connector callback does not match expected"
|
||||
@@ -142,7 +142,7 @@ def verify_csrf(credential_id: int, state: str) -> None:
|
||||
|
||||
|
||||
def get_auth_url(credential_id: int) -> str:
|
||||
creds_str = str(get_dynamic_config_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
|
||||
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
|
||||
credential_json = json.loads(creds_str)
|
||||
flow = InstalledAppFlow.from_client_config(
|
||||
credential_json,
|
||||
@@ -154,7 +154,7 @@ def get_auth_url(credential_id: int) -> str:
|
||||
parsed_url = cast(ParseResult, urlparse(auth_url))
|
||||
params = parse_qs(parsed_url.query)
|
||||
|
||||
get_dynamic_config_store().store(
|
||||
get_kv_store().store(
|
||||
KV_CRED_KEY.format(credential_id), params.get("state", [None])[0], encrypt=True
|
||||
) # type: ignore
|
||||
return str(auth_url)
|
||||
@@ -202,32 +202,28 @@ def build_service_account_creds(
|
||||
|
||||
|
||||
def get_google_app_cred() -> GoogleAppCredentials:
|
||||
creds_str = str(get_dynamic_config_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
|
||||
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
|
||||
return GoogleAppCredentials(**json.loads(creds_str))
|
||||
|
||||
|
||||
def upsert_google_app_cred(app_credentials: GoogleAppCredentials) -> None:
|
||||
get_dynamic_config_store().store(
|
||||
KV_GOOGLE_DRIVE_CRED_KEY, app_credentials.json(), encrypt=True
|
||||
)
|
||||
get_kv_store().store(KV_GOOGLE_DRIVE_CRED_KEY, app_credentials.json(), encrypt=True)
|
||||
|
||||
|
||||
def delete_google_app_cred() -> None:
|
||||
get_dynamic_config_store().delete(KV_GOOGLE_DRIVE_CRED_KEY)
|
||||
get_kv_store().delete(KV_GOOGLE_DRIVE_CRED_KEY)
|
||||
|
||||
|
||||
def get_service_account_key() -> GoogleServiceAccountKey:
|
||||
creds_str = str(
|
||||
get_dynamic_config_store().load(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY)
|
||||
)
|
||||
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY))
|
||||
return GoogleServiceAccountKey(**json.loads(creds_str))
|
||||
|
||||
|
||||
def upsert_service_account_key(service_account_key: GoogleServiceAccountKey) -> None:
|
||||
get_dynamic_config_store().store(
|
||||
get_kv_store().store(
|
||||
KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True
|
||||
)
|
||||
|
||||
|
||||
def delete_service_account_key() -> None:
|
||||
get_dynamic_config_store().delete(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY)
|
||||
get_kv_store().delete(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY)
|
||||
|
||||
@@ -40,8 +40,8 @@ def _convert_driveitem_to_document(
    driveitem: DriveItem,
) -> Document:
    file_text = extract_file_text(
        file_name=driveitem.name,
        file=io.BytesIO(driveitem.get_content().execute_query().value),
        file_name=driveitem.name,
        break_on_unprocessable=False,
    )
@@ -8,13 +8,12 @@ from typing import cast
|
||||
|
||||
from slack_sdk import WebClient
|
||||
from slack_sdk.errors import SlackApiError
|
||||
from slack_sdk.web import SlackResponse
|
||||
|
||||
from danswer.configs.app_configs import ENABLE_EXPENSIVE_EXPERT_CALLS
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import IdConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from danswer.connectors.models import BasicExpertInfo
|
||||
@@ -23,9 +22,8 @@ from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.connectors.slack.utils import expert_info_from_slack_id
|
||||
from danswer.connectors.slack.utils import get_message_link
|
||||
from danswer.connectors.slack.utils import make_slack_api_call_logged
|
||||
from danswer.connectors.slack.utils import make_slack_api_call_paginated
|
||||
from danswer.connectors.slack.utils import make_slack_api_rate_limited
|
||||
from danswer.connectors.slack.utils import make_paginated_slack_api_call_w_retries
|
||||
from danswer.connectors.slack.utils import make_slack_api_call_w_retries
|
||||
from danswer.connectors.slack.utils import SlackTextCleaner
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
@@ -38,47 +36,18 @@ MessageType = dict[str, Any]
|
||||
# list of messages in a thread
|
||||
ThreadType = list[MessageType]
|
||||
|
||||
basic_retry_wrapper = retry_builder()
|
||||
|
||||
|
||||
def _make_paginated_slack_api_call(
|
||||
call: Callable[..., SlackResponse], **kwargs: Any
|
||||
) -> Generator[dict[str, Any], None, None]:
|
||||
return make_slack_api_call_paginated(
|
||||
basic_retry_wrapper(
|
||||
make_slack_api_rate_limited(make_slack_api_call_logged(call))
|
||||
)
|
||||
)(**kwargs)
|
||||
|
||||
|
||||
def _make_slack_api_call(
|
||||
call: Callable[..., SlackResponse], **kwargs: Any
|
||||
) -> SlackResponse:
|
||||
return basic_retry_wrapper(
|
||||
make_slack_api_rate_limited(make_slack_api_call_logged(call))
|
||||
)(**kwargs)
|
||||
|
||||
|
||||
def get_channel_info(client: WebClient, channel_id: str) -> ChannelType:
|
||||
"""Get information about a channel. Needed to convert channel ID to channel name"""
|
||||
return _make_slack_api_call(client.conversations_info, channel=channel_id)[0][
|
||||
"channel"
|
||||
]
|
||||
|
||||
|
||||
def _get_channels(
|
||||
def _collect_paginated_channels(
|
||||
client: WebClient,
|
||||
exclude_archived: bool,
|
||||
get_private: bool,
|
||||
channel_types: list[str],
|
||||
) -> list[ChannelType]:
|
||||
channels: list[dict[str, Any]] = []
|
||||
for result in _make_paginated_slack_api_call(
|
||||
for result in make_paginated_slack_api_call_w_retries(
|
||||
client.conversations_list,
|
||||
exclude_archived=exclude_archived,
|
||||
# also get private channels the bot is added to
|
||||
types=["public_channel", "private_channel"]
|
||||
if get_private
|
||||
else ["public_channel"],
|
||||
types=channel_types,
|
||||
):
|
||||
channels.extend(result["channels"])
|
||||
|
||||
@@ -88,19 +57,38 @@ def _get_channels(
|
||||
def get_channels(
|
||||
client: WebClient,
|
||||
exclude_archived: bool = True,
|
||||
get_public: bool = True,
|
||||
get_private: bool = True,
|
||||
) -> list[ChannelType]:
|
||||
"""Get all channels in the workspace"""
|
||||
channels: list[dict[str, Any]] = []
|
||||
channel_types = []
|
||||
if get_public:
|
||||
channel_types.append("public_channel")
|
||||
if get_private:
|
||||
channel_types.append("private_channel")
|
||||
# try getting private channels as well at first
|
||||
try:
|
||||
return _get_channels(
|
||||
client=client, exclude_archived=exclude_archived, get_private=True
|
||||
channels = _collect_paginated_channels(
|
||||
client=client,
|
||||
exclude_archived=exclude_archived,
|
||||
channel_types=channel_types,
|
||||
)
|
||||
except SlackApiError as e:
|
||||
logger.info(f"Unable to fetch private channels due to - {e}")
|
||||
logger.info("trying again without private channels")
|
||||
if get_public:
|
||||
channel_types = ["public_channel"]
|
||||
else:
|
||||
logger.warning("No channels to fetch")
|
||||
return []
|
||||
channels = _collect_paginated_channels(
|
||||
client=client,
|
||||
exclude_archived=exclude_archived,
|
||||
channel_types=channel_types,
|
||||
)
|
||||
|
||||
return _get_channels(
|
||||
client=client, exclude_archived=exclude_archived, get_private=False
|
||||
)
|
||||
return channels
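A short usage sketch for the refactored channel listing; the bot token is a placeholder, and only get_channels and WebClient come from this file.

    from slack_sdk import WebClient

    client = WebClient(token="xoxb-...")  # placeholder token
    # Private channels are attempted first and silently dropped if the bot
    # lacks the scope, thanks to the SlackApiError fallback above.
    channels = get_channels(
        client, exclude_archived=True, get_public=True, get_private=True
    )
    print([channel["name"] for channel in channels])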
def get_channel_messages(
|
||||
@@ -112,14 +100,14 @@ def get_channel_messages(
|
||||
"""Get all messages in a channel"""
|
||||
# join so that the bot can access messages
|
||||
if not channel["is_member"]:
|
||||
_make_slack_api_call(
|
||||
make_slack_api_call_w_retries(
|
||||
client.conversations_join,
|
||||
channel=channel["id"],
|
||||
is_private=channel["is_private"],
|
||||
)
|
||||
logger.info(f"Successfully joined '{channel['name']}'")
|
||||
|
||||
for result in _make_paginated_slack_api_call(
|
||||
for result in make_paginated_slack_api_call_w_retries(
|
||||
client.conversations_history,
|
||||
channel=channel["id"],
|
||||
oldest=oldest,
|
||||
@@ -131,7 +119,7 @@ def get_channel_messages(
|
||||
def get_thread(client: WebClient, channel_id: str, thread_id: str) -> ThreadType:
|
||||
"""Get all messages in a thread"""
|
||||
threads: list[MessageType] = []
|
||||
for result in _make_paginated_slack_api_call(
|
||||
for result in make_paginated_slack_api_call_w_retries(
|
||||
client.conversations_replies, channel=channel_id, ts=thread_id
|
||||
):
|
||||
threads.extend(result["messages"])
|
||||
@@ -217,12 +205,17 @@ _DISALLOWED_MSG_SUBTYPES = {
|
||||
"group_leave",
|
||||
"group_archive",
|
||||
"group_unarchive",
|
||||
"channel_leave",
|
||||
"channel_name",
|
||||
"channel_join",
|
||||
}
|
||||
|
||||
|
||||
def _default_msg_filter(message: MessageType) -> bool:
|
||||
def default_msg_filter(message: MessageType) -> bool:
|
||||
# Don't keep messages from bots
|
||||
if message.get("bot_id") or message.get("app_id"):
|
||||
if message.get("bot_profile", {}).get("name") == "DanswerConnector":
|
||||
return False
|
||||
return True
|
||||
|
||||
# Uninformative
|
||||
@@ -266,14 +259,14 @@ def filter_channels(
|
||||
]
|
||||
|
||||
|
||||
def get_all_docs(
|
||||
def _get_all_docs(
|
||||
client: WebClient,
|
||||
workspace: str,
|
||||
channels: list[str] | None = None,
|
||||
channel_name_regex_enabled: bool = False,
|
||||
oldest: str | None = None,
|
||||
latest: str | None = None,
|
||||
msg_filter_func: Callable[[MessageType], bool] = _default_msg_filter,
|
||||
msg_filter_func: Callable[[MessageType], bool] = default_msg_filter,
|
||||
) -> Generator[Document, None, None]:
|
||||
"""Get all documents in the workspace, channel by channel"""
|
||||
slack_cleaner = SlackTextCleaner(client=client)
|
||||
@@ -328,7 +321,44 @@ def get_all_docs(
|
||||
)
|
||||
|
||||
|
||||
class SlackPollConnector(PollConnector):
|
||||
def _get_all_doc_ids(
|
||||
client: WebClient,
|
||||
channels: list[str] | None = None,
|
||||
channel_name_regex_enabled: bool = False,
|
||||
msg_filter_func: Callable[[MessageType], bool] = default_msg_filter,
|
||||
) -> set[str]:
|
||||
"""
|
||||
Get all document ids in the workspace, channel by channel
|
||||
This is pretty identical to get_all_docs, but it returns a set of ids instead of documents
|
||||
This makes it an order of magnitude faster than get_all_docs
|
||||
"""
|
||||
|
||||
all_channels = get_channels(client)
|
||||
filtered_channels = filter_channels(
|
||||
all_channels, channels, channel_name_regex_enabled
|
||||
)
|
||||
|
||||
all_doc_ids = set()
|
||||
for channel in filtered_channels:
|
||||
channel_message_batches = get_channel_messages(
|
||||
client=client,
|
||||
channel=channel,
|
||||
)
|
||||
|
||||
for message_batch in channel_message_batches:
|
||||
for message in message_batch:
|
||||
if msg_filter_func(message):
|
||||
continue
|
||||
|
||||
# The document id is the channel id and the ts of the first message in the thread
|
||||
# Since we already have the first message of the thread, we don't have to
|
||||
# fetch the thread for id retrieval, saving time and API calls
|
||||
all_doc_ids.add(f"{channel['id']}__{message['ts']}")
|
||||
|
||||
return all_doc_ids
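Because each id is simply the channel id plus the thread timestamp, the returned set can be diffed against ids already held in the document index to find stale documents without refetching any threads. A hedged sketch; load_ids_from_index and the document_index handle are assumptions, not part of this diff.

    # Hypothetical pruning sketch built on _get_all_doc_ids / retrieve_all_source_ids.
    existing_ids: set[str] = load_ids_from_index()        # assumed helper
    live_ids = _get_all_doc_ids(client=client, channels=None)
    for stale_id in existing_ids - live_ids:
        document_index.delete_single(stale_id)            # assumed index handle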
class SlackPollConnector(PollConnector, IdConnector):
|
||||
def __init__(
|
||||
self,
|
||||
workspace: str,
|
||||
@@ -349,6 +379,16 @@ class SlackPollConnector(PollConnector):
|
||||
self.client = WebClient(token=bot_token)
|
||||
return None
|
||||
|
||||
def retrieve_all_source_ids(self) -> set[str]:
|
||||
if self.client is None:
|
||||
raise ConnectorMissingCredentialError("Slack")
|
||||
|
||||
return _get_all_doc_ids(
|
||||
client=self.client,
|
||||
channels=self.channels,
|
||||
channel_name_regex_enabled=self.channel_regex_enabled,
|
||||
)
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
@@ -356,7 +396,7 @@ class SlackPollConnector(PollConnector):
|
||||
raise ConnectorMissingCredentialError("Slack")
|
||||
|
||||
documents: list[Document] = []
|
||||
for document in get_all_docs(
|
||||
for document in _get_all_docs(
|
||||
client=self.client,
|
||||
workspace=self.workspace,
|
||||
channels=self.channels,
|
||||
|
||||
@@ -10,11 +10,13 @@ from slack_sdk import WebClient
|
||||
from slack_sdk.errors import SlackApiError
|
||||
from slack_sdk.web import SlackResponse
|
||||
|
||||
from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
|
||||
from danswer.connectors.models import BasicExpertInfo
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
basic_retry_wrapper = retry_builder()
|
||||
# number of messages we request per page when fetching paginated slack messages
|
||||
_SLACK_LIMIT = 900
|
||||
|
||||
@@ -34,7 +36,7 @@ def get_message_link(
|
||||
)
|
||||
|
||||
|
||||
def make_slack_api_call_logged(
|
||||
def _make_slack_api_call_logged(
|
||||
call: Callable[..., SlackResponse],
|
||||
) -> Callable[..., SlackResponse]:
|
||||
@wraps(call)
|
||||
@@ -47,7 +49,7 @@ def make_slack_api_call_logged(
|
||||
return logged_call
|
||||
|
||||
|
||||
def make_slack_api_call_paginated(
|
||||
def _make_slack_api_call_paginated(
|
||||
call: Callable[..., SlackResponse],
|
||||
) -> Callable[..., Generator[dict[str, Any], None, None]]:
|
||||
"""Wraps calls to slack API so that they automatically handle pagination"""
|
||||
@@ -116,6 +118,24 @@ def make_slack_api_rate_limited(
    return rate_limited_call


def make_slack_api_call_w_retries(
    call: Callable[..., SlackResponse], **kwargs: Any
) -> SlackResponse:
    return basic_retry_wrapper(
        make_slack_api_rate_limited(_make_slack_api_call_logged(call))
    )(**kwargs)


def make_paginated_slack_api_call_w_retries(
    call: Callable[..., SlackResponse], **kwargs: Any
) -> Generator[dict[str, Any], None, None]:
    return _make_slack_api_call_paginated(
        basic_retry_wrapper(
            make_slack_api_rate_limited(_make_slack_api_call_logged(call))
        )
    )(**kwargs)
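These two wrappers are what the connector module above now imports in place of its private helpers. A call site looks roughly like this; the channel id is a placeholder and process() stands in for real handling.

    # client is an already-constructed slack_sdk WebClient.
    # Single call with retries and rate limiting applied:
    info = make_slack_api_call_w_retries(
        client.conversations_info, channel="C0123456789"
    )

    # Paginated call; each yielded dict is one page of the API response:
    for page in make_paginated_slack_api_call_w_retries(
        client.conversations_history, channel="C0123456789"
    ):
        process(page["messages"])  # process() is a placeholder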
def expert_info_from_slack_id(
|
||||
user_id: str | None,
|
||||
client: WebClient,
|
||||
|
||||
backend/danswer/connectors/xenforo/connector.py (new file, 244 lines)
@@ -0,0 +1,244 @@
|
||||
"""
|
||||
This is the XenforoConnector class. It is used to connect to a Xenforo forum and load or update documents from the forum.
|
||||
|
||||
To use this class, you need to provide the URL of the Xenforo forum board you want to connect to when creating an instance
|
||||
of the class. The URL should be a string that starts with 'http://' or 'https://', followed by the domain name of the
|
||||
forum, followed by the board name. For example:
|
||||
|
||||
base_url = 'https://www.example.com/forum/boards/some-topic/'
|
||||
|
||||
The `load_from_state` method is used to load documents from the forum. It takes an optional `state` parameter, which
|
||||
can be used to specify a state from which to start loading documents.
|
||||
"""
|
||||
import re
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import pytz
|
||||
import requests
|
||||
from bs4 import BeautifulSoup
|
||||
from bs4 import Tag
|
||||
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.cross_connector_utils.miscellaneous_utils import datetime_to_utc
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.models import BasicExpertInfo
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def get_title(soup: BeautifulSoup) -> str:
|
||||
el = soup.find("h1", "p-title-value")
|
||||
if not el:
|
||||
return ""
|
||||
title = el.text
|
||||
for char in (";", ":", "!", "*", "/", "\\", "?", '"', "<", ">", "|"):
|
||||
title = title.replace(char, "_")
|
||||
return title
|
||||
|
||||
|
||||
def get_pages(soup: BeautifulSoup, url: str) -> list[str]:
|
||||
page_tags = soup.select("li.pageNav-page")
|
||||
page_numbers = []
|
||||
for button in page_tags:
|
||||
if re.match(r"^\d+$", button.text):
|
||||
page_numbers.append(button.text)
|
||||
|
||||
max_pages = int(max(page_numbers, key=int)) if page_numbers else 1
|
||||
|
||||
all_pages = []
|
||||
for x in range(1, int(max_pages) + 1):
|
||||
all_pages.append(f"{url}page-{x}")
|
||||
return all_pages
|
||||
|
||||
|
||||
def parse_post_date(post_element: BeautifulSoup) -> datetime:
|
||||
el = post_element.find("time")
|
||||
if not isinstance(el, Tag) or "datetime" not in el.attrs:
|
||||
return datetime.utcfromtimestamp(0).replace(tzinfo=timezone.utc)
|
||||
|
||||
date_value = el["datetime"]
|
||||
|
||||
# Ensure date_value is a string (if it's a list, take the first element)
|
||||
if isinstance(date_value, list):
|
||||
date_value = date_value[0]
|
||||
|
||||
post_date = datetime.strptime(date_value, "%Y-%m-%dT%H:%M:%S%z")
|
||||
return datetime_to_utc(post_date)
|
||||
|
||||
|
||||
def scrape_page_posts(
|
||||
soup: BeautifulSoup,
|
||||
page_index: int,
|
||||
url: str,
|
||||
initial_run: bool,
|
||||
start_time: datetime,
|
||||
) -> list:
|
||||
title = get_title(soup)
|
||||
|
||||
documents = []
|
||||
for post in soup.find_all("div", class_="message-inner"):
|
||||
post_date = parse_post_date(post)
|
||||
if initial_run or post_date > start_time:
|
||||
el = post.find("div", class_="bbWrapper")
|
||||
if not el:
|
||||
continue
|
||||
post_text = el.get_text(strip=True) + "\n"
|
||||
author_tag = post.find("a", class_="username")
|
||||
if author_tag is None:
|
||||
author_tag = post.find("span", class_="username")
|
||||
author = author_tag.get_text(strip=True) if author_tag else "Deleted author"
|
||||
formatted_time = post_date.strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
# TODO: if a caller calls this for each page of a thread, it may see the
|
||||
# same post multiple times if there is a sticky post
|
||||
# that appears on each page of a thread.
|
||||
# it's important to generate unique doc id's, so page index is part of the
|
||||
# id. We may want to de-dupe this stuff inside the indexing service.
|
||||
document = Document(
|
||||
id=f"{DocumentSource.XENFORO.value}_{title}_{page_index}_{formatted_time}",
|
||||
sections=[Section(link=url, text=post_text)],
|
||||
title=title,
|
||||
source=DocumentSource.XENFORO,
|
||||
semantic_identifier=title,
|
||||
primary_owners=[BasicExpertInfo(display_name=author)],
|
||||
metadata={
|
||||
"type": "post",
|
||||
"author": author,
|
||||
"time": formatted_time,
|
||||
},
|
||||
doc_updated_at=post_date,
|
||||
)
|
||||
|
||||
documents.append(document)
|
||||
return documents
|
||||
|
||||
|
||||
class XenforoConnector(LoadConnector):
|
||||
# Class variable to track if the connector has been run before
|
||||
has_been_run_before = False
|
||||
|
||||
def __init__(self, base_url: str) -> None:
|
||||
self.base_url = base_url
|
||||
self.initial_run = not XenforoConnector.has_been_run_before
|
||||
self.start = datetime.utcnow().replace(tzinfo=pytz.utc) - timedelta(days=1)
|
||||
self.cookies: dict[str, str] = {}
|
||||
# mimic user browser to avoid being blocked by the website (see: https://www.useragents.me/)
|
||||
self.headers = {
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
|
||||
"AppleWebKit/537.36 (KHTML, like Gecko) "
|
||||
"Chrome/121.0.0.0 Safari/537.36"
|
||||
}
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
if credentials:
|
||||
logger.warning("Unexpected credentials provided for Xenforo Connector")
|
||||
return None
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
# Standardize URL to always end in /.
|
||||
if self.base_url[-1] != "/":
|
||||
self.base_url += "/"
|
||||
|
||||
# Remove all extra parameters from the end such as page, post.
|
||||
matches = ("threads/", "boards/", "forums/")
|
||||
for each in matches:
|
||||
if each in self.base_url:
|
||||
try:
|
||||
self.base_url = self.base_url[
|
||||
0 : self.base_url.index(
|
||||
"/", self.base_url.index(each) + len(each)
|
||||
)
|
||||
+ 1
|
||||
]
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
doc_batch: list[Document] = []
|
||||
all_threads = []
|
||||
|
||||
# If the URL contains "boards/" or "forums/", find all threads.
|
||||
if "boards/" in self.base_url or "forums/" in self.base_url:
|
||||
pages = get_pages(self.requestsite(self.base_url), self.base_url)
|
||||
|
||||
# Get all pages on thread_list_page
|
||||
for pre_count, thread_list_page in enumerate(pages, start=1):
|
||||
logger.info(
|
||||
f"Getting pages from thread_list_page.. Current: {pre_count}/{len(pages)}\r"
|
||||
)
|
||||
all_threads += self.get_threads(thread_list_page)
|
||||
# If the URL contains "threads/", add the thread to the list.
|
||||
elif "threads/" in self.base_url:
|
||||
all_threads.append(self.base_url)
|
||||
|
||||
# Process all threads
|
||||
for thread_count, thread_url in enumerate(all_threads, start=1):
|
||||
soup = self.requestsite(thread_url)
|
||||
if soup is None:
|
||||
logger.error(f"Failed to load page: {self.base_url}")
|
||||
continue
|
||||
pages = get_pages(soup, thread_url)
|
||||
# Getting all pages for all threads
|
||||
for page_index, page in enumerate(pages, start=1):
|
||||
logger.info(
|
||||
f"Progress: Page {page_index}/{len(pages)} - Thread {thread_count}/{len(all_threads)}\r"
|
||||
)
|
||||
soup_page = self.requestsite(page)
|
||||
doc_batch.extend(
|
||||
scrape_page_posts(
|
||||
soup_page, page_index, thread_url, self.initial_run, self.start
|
||||
)
|
||||
)
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
|
||||
# Mark the initial run finished after all threads and pages have been processed
|
||||
XenforoConnector.has_been_run_before = True
|
||||
|
||||
def get_threads(self, url: str) -> list[str]:
|
||||
soup = self.requestsite(url)
|
||||
thread_tags = soup.find_all(class_="structItem-title")
|
||||
base_url = "{uri.scheme}://{uri.netloc}".format(uri=urlparse(url))
|
||||
threads = []
|
||||
for x in thread_tags:
|
||||
y = x.find_all(href=True)
|
||||
for element in y:
|
||||
link = element["href"]
|
||||
if "threads/" in link:
|
||||
stripped = link[0 : link.rfind("/") + 1]
|
||||
if base_url + stripped not in threads:
|
||||
threads.append(base_url + stripped)
|
||||
return threads
|
||||
|
||||
def requestsite(self, url: str) -> BeautifulSoup:
|
||||
try:
|
||||
response = requests.get(
|
||||
url, cookies=self.cookies, headers=self.headers, timeout=10
|
||||
)
|
||||
if response.status_code != 200:
|
||||
logger.error(
|
||||
f"<{url}> Request Error: {response.status_code} - {response.reason}"
|
||||
)
|
||||
return BeautifulSoup(response.text, "html.parser")
|
||||
except TimeoutError:
|
||||
logger.error("Timed out Error.")
|
||||
except Exception as e:
|
||||
logger.error(f"Error on {url}")
|
||||
logger.exception(e)
|
||||
return BeautifulSoup("", "html.parser")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
connector = XenforoConnector(
|
||||
# base_url="https://cassiopaea.org/forum/threads/how-to-change-your-emotional-state.41381/"
|
||||
base_url="https://xenforo.com/community/threads/whats-new-with-enhanced-search-resource-manager-and-media-gallery-in-xenforo-2-3.220935/"
|
||||
)
|
||||
document_batches = connector.load_from_state()
|
||||
print(next(document_batches))
|
||||
@@ -160,7 +160,7 @@ def handle_regular_answer(
|
||||
detail="Slack bot does not support persona config",
|
||||
)
|
||||
|
||||
elif new_message_request.persona_id:
|
||||
elif new_message_request.persona_id is not None:
|
||||
persona = cast(
|
||||
Persona,
|
||||
fetch_persona_by_id(
|
||||
|
||||
@@ -49,7 +49,7 @@ from danswer.danswerbot.slack.utils import rephrase_slack_message
|
||||
from danswer.danswerbot.slack.utils import respond_in_thread
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.dynamic_configs.interface import ConfigNotFoundError
|
||||
from danswer.key_value_store.interface import KvKeyNotFoundError
|
||||
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
|
||||
from danswer.one_shot_answer.models import ThreadMessage
|
||||
@@ -522,7 +522,7 @@ if __name__ == "__main__":
|
||||
|
||||
# Let the handlers run in the background + re-check for token updates every 60 seconds
|
||||
Event().wait(timeout=60)
|
||||
except ConfigNotFoundError:
|
||||
except KvKeyNotFoundError:
|
||||
# try again every 30 seconds. This is needed since the user may add tokens
|
||||
# via the UI at any point in the program's lifecycle - if we just allow it to
|
||||
# fail, then the user will need to restart the containers after adding tokens
|
||||
|
||||
@@ -2,7 +2,7 @@ import os
|
||||
from typing import cast
|
||||
|
||||
from danswer.configs.constants import KV_SLACK_BOT_TOKENS_CONFIG_KEY
|
||||
from danswer.dynamic_configs.factory import get_dynamic_config_store
|
||||
from danswer.key_value_store.factory import get_kv_store
|
||||
from danswer.server.manage.models import SlackBotTokens
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ def fetch_tokens() -> SlackBotTokens:
|
||||
if app_token and bot_token:
|
||||
return SlackBotTokens(app_token=app_token, bot_token=bot_token)
|
||||
|
||||
dynamic_config_store = get_dynamic_config_store()
|
||||
dynamic_config_store = get_kv_store()
|
||||
return SlackBotTokens(
|
||||
**cast(dict, dynamic_config_store.load(key=KV_SLACK_BOT_TOKENS_CONFIG_KEY))
|
||||
)
|
||||
@@ -22,7 +22,7 @@ def fetch_tokens() -> SlackBotTokens:
|
||||
def save_tokens(
|
||||
tokens: SlackBotTokens,
|
||||
) -> None:
|
||||
dynamic_config_store = get_dynamic_config_store()
|
||||
dynamic_config_store = get_kv_store()
|
||||
dynamic_config_store.store(
|
||||
key=KV_SLACK_BOT_TOKENS_CONFIG_KEY, val=dict(tokens), encrypt=True
|
||||
)
|
||||
|
||||
@@ -26,9 +26,7 @@ from danswer.db.models import UserRole
|
||||
from danswer.server.models import StatusResponse
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.db.external_perm import delete_user__ext_group_for_cc_pair__no_commit
|
||||
from ee.danswer.external_permissions.permission_sync_function_map import (
|
||||
check_if_valid_sync_source,
|
||||
)
|
||||
from ee.danswer.external_permissions.sync_params import check_if_valid_sync_source
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@@ -104,6 +104,18 @@ def construct_document_select_for_connector_credential_pair(
|
||||
return stmt
|
||||
|
||||
|
||||
def get_document_ids_for_connector_credential_pair(
|
||||
db_session: Session, connector_id: int, credential_id: int, limit: int | None = None
|
||||
) -> list[str]:
|
||||
doc_ids_stmt = select(DocumentByConnectorCredentialPair.id).where(
|
||||
and_(
|
||||
DocumentByConnectorCredentialPair.connector_id == connector_id,
|
||||
DocumentByConnectorCredentialPair.credential_id == credential_id,
|
||||
)
|
||||
)
|
||||
return list(db_session.execute(doc_ids_stmt).scalars().all())
|
||||
|
||||
|
||||
def get_documents_for_connector_credential_pair(
|
||||
db_session: Session, connector_id: int, credential_id: int, limit: int | None = None
|
||||
) -> Sequence[DbDocument]:
|
||||
@@ -120,8 +132,8 @@ def get_documents_for_connector_credential_pair(
|
||||
|
||||
|
||||
def get_documents_by_ids(
|
||||
document_ids: list[str],
|
||||
db_session: Session,
|
||||
document_ids: list[str],
|
||||
) -> list[DbDocument]:
|
||||
stmt = select(DbDocument).where(DbDocument.id.in_(document_ids))
|
||||
documents = db_session.execute(stmt).scalars().all()
|
||||
|
||||
@@ -1,10 +1,18 @@
|
||||
import contextlib
|
||||
import contextvars
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import ContextManager
|
||||
|
||||
import jwt
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
from sqlalchemy import event
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.engine import create_engine
|
||||
@@ -17,6 +25,7 @@ from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from danswer.configs.app_configs import LOG_POSTGRES_CONN_COUNTS
|
||||
from danswer.configs.app_configs import LOG_POSTGRES_LATENCY
|
||||
from danswer.configs.app_configs import MULTI_TENANT
|
||||
from danswer.configs.app_configs import POSTGRES_DB
|
||||
from danswer.configs.app_configs import POSTGRES_HOST
|
||||
from danswer.configs.app_configs import POSTGRES_PASSWORD
|
||||
@@ -24,27 +33,24 @@ from danswer.configs.app_configs import POSTGRES_POOL_PRE_PING
|
||||
from danswer.configs.app_configs import POSTGRES_POOL_RECYCLE
|
||||
from danswer.configs.app_configs import POSTGRES_PORT
|
||||
from danswer.configs.app_configs import POSTGRES_USER
|
||||
from danswer.configs.app_configs import SECRET_JWT_KEY
|
||||
from danswer.configs.constants import POSTGRES_DEFAULT_SCHEMA
|
||||
from danswer.configs.constants import POSTGRES_UNKNOWN_APP_NAME
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
SYNC_DB_API = "psycopg2"
|
||||
ASYNC_DB_API = "asyncpg"
|
||||
|
||||
POSTGRES_APP_NAME = (
|
||||
POSTGRES_UNKNOWN_APP_NAME # helps to diagnose open connections in postgres
|
||||
)
|
||||
|
||||
# global so we don't create more than one engine per process
|
||||
# outside of being best practice, this is needed so we can properly pool
|
||||
# connections and not create a new pool on every request
|
||||
_SYNC_ENGINE: Engine | None = None
|
||||
|
||||
_ASYNC_ENGINE: AsyncEngine | None = None
|
||||
|
||||
SessionFactory: sessionmaker[Session] | None = None
|
||||
|
||||
|
||||
if LOG_POSTGRES_LATENCY:
|
||||
# Function to log before query execution
|
||||
@event.listens_for(Engine, "before_cursor_execute")
|
||||
@@ -108,6 +114,78 @@ def get_db_current_time(db_session: Session) -> datetime:
|
||||
return result
|
||||
|
||||
|
||||
# Regular expression to validate schema names to prevent SQL injection
SCHEMA_NAME_REGEX = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$")


def is_valid_schema_name(name: str) -> bool:
    return SCHEMA_NAME_REGEX.match(name) is not None
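A few quick checks to illustrate what the pattern admits; the example values are made up.

    assert is_valid_schema_name("tenant_42")            # letters, digits, underscores
    assert is_valid_schema_name("_internal")            # may start with an underscore
    assert not is_valid_schema_name("tenant-42")        # hyphens rejected
    assert not is_valid_schema_name('public"; DROP')    # quoting/injection attempts rejected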
class SqlEngine:
|
||||
"""Class to manage a global SQLAlchemy engine (needed for proper resource control).
|
||||
Will eventually subsume most of the standalone functions in this file.
|
||||
Sync only for now.
|
||||
"""
|
||||
|
||||
_engine: Engine | None = None
|
||||
_lock: threading.Lock = threading.Lock()
|
||||
_app_name: str = POSTGRES_UNKNOWN_APP_NAME
|
||||
|
||||
# Default parameters for engine creation
|
||||
DEFAULT_ENGINE_KWARGS = {
|
||||
"pool_size": 40,
|
||||
"max_overflow": 10,
|
||||
"pool_pre_ping": POSTGRES_POOL_PRE_PING,
|
||||
"pool_recycle": POSTGRES_POOL_RECYCLE,
|
||||
}
|
||||
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def _init_engine(cls, **engine_kwargs: Any) -> Engine:
|
||||
"""Private helper method to create and return an Engine."""
|
||||
connection_string = build_connection_string(
|
||||
db_api=SYNC_DB_API, app_name=cls._app_name + "_sync"
|
||||
)
|
||||
merged_kwargs = {**cls.DEFAULT_ENGINE_KWARGS, **engine_kwargs}
|
||||
return create_engine(connection_string, **merged_kwargs)
|
||||
|
||||
@classmethod
|
||||
def init_engine(cls, **engine_kwargs: Any) -> None:
|
||||
"""Allow the caller to init the engine with extra params. Different clients
|
||||
such as the API server and different Celery workers and tasks
|
||||
need different settings.
|
||||
"""
|
||||
with cls._lock:
|
||||
if not cls._engine:
|
||||
cls._engine = cls._init_engine(**engine_kwargs)
|
||||
|
||||
@classmethod
|
||||
def get_engine(cls) -> Engine:
|
||||
"""Gets the SQLAlchemy engine. Will init a default engine if init hasn't
|
||||
already been called. You probably want to init first!
|
||||
"""
|
||||
if not cls._engine:
|
||||
with cls._lock:
|
||||
if not cls._engine:
|
||||
cls._engine = cls._init_engine()
|
||||
return cls._engine
|
||||
|
||||
@classmethod
|
||||
def set_app_name(cls, app_name: str) -> None:
|
||||
"""Class method to set the app name."""
|
||||
cls._app_name = app_name
|
||||
|
||||
@classmethod
|
||||
def get_app_name(cls) -> str:
|
||||
"""Class method to get current app name."""
|
||||
if not cls._app_name:
|
||||
return ""
|
||||
return cls._app_name
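A sketch of how a worker process might drive the new class at startup; the app name and pool sizes here are illustrative, not values taken from the codebase.

    from sqlalchemy import text

    # Hypothetical startup sequence for an API server or Celery worker.
    SqlEngine.set_app_name("celery_worker_indexing")     # illustrative name
    SqlEngine.init_engine(pool_size=10, max_overflow=5)  # overrides DEFAULT_ENGINE_KWARGS

    engine = SqlEngine.get_engine()  # every later call returns the same engine
    with engine.connect() as conn:
        conn.execute(text("SELECT 1"))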
def build_connection_string(
|
||||
*,
|
||||
db_api: str = ASYNC_DB_API,
|
||||
@@ -120,69 +198,142 @@ def build_connection_string(
|
||||
) -> str:
|
||||
if app_name:
|
||||
return f"postgresql+{db_api}://{user}:{password}@{host}:{port}/{db}?application_name={app_name}"
|
||||
|
||||
return f"postgresql+{db_api}://{user}:{password}@{host}:{port}/{db}"
|
||||
|
||||
|
||||
def init_sqlalchemy_engine(app_name: str) -> None:
|
||||
global POSTGRES_APP_NAME
|
||||
POSTGRES_APP_NAME = app_name
|
||||
SqlEngine.set_app_name(app_name)
|
||||
|
||||
|
||||
def get_sqlalchemy_engine() -> Engine:
|
||||
global _SYNC_ENGINE
|
||||
if _SYNC_ENGINE is None:
|
||||
connection_string = build_connection_string(
|
||||
db_api=SYNC_DB_API, app_name=POSTGRES_APP_NAME + "_sync"
|
||||
)
|
||||
_SYNC_ENGINE = create_engine(
|
||||
connection_string,
|
||||
pool_size=5,
|
||||
max_overflow=0,
|
||||
pool_pre_ping=POSTGRES_POOL_PRE_PING,
|
||||
pool_recycle=POSTGRES_POOL_RECYCLE,
|
||||
)
|
||||
return _SYNC_ENGINE
|
||||
return SqlEngine.get_engine()
|
||||
|
||||
|
||||
def get_sqlalchemy_async_engine() -> AsyncEngine:
|
||||
global _ASYNC_ENGINE
|
||||
if _ASYNC_ENGINE is None:
|
||||
# underlying asyncpg cannot accept application_name directly in the connection string
|
||||
# Underlying asyncpg cannot accept application_name directly in the connection string
|
||||
# https://github.com/MagicStack/asyncpg/issues/798
|
||||
connection_string = build_connection_string()
|
||||
_ASYNC_ENGINE = create_async_engine(
|
||||
connection_string,
|
||||
connect_args={
|
||||
"server_settings": {"application_name": POSTGRES_APP_NAME + "_async"}
|
||||
"server_settings": {
|
||||
"application_name": SqlEngine.get_app_name() + "_async"
|
||||
}
|
||||
},
|
||||
pool_size=5,
|
||||
max_overflow=0,
|
||||
pool_size=40,
|
||||
max_overflow=10,
|
||||
pool_pre_ping=POSTGRES_POOL_PRE_PING,
|
||||
pool_recycle=POSTGRES_POOL_RECYCLE,
|
||||
)
|
||||
return _ASYNC_ENGINE
|
||||
|
||||
|
||||
def get_session_context_manager() -> ContextManager[Session]:
|
||||
return contextlib.contextmanager(get_session)()
|
||||
# Context variable to store the current tenant ID
|
||||
# This allows us to maintain tenant-specific context throughout the request lifecycle
|
||||
# The default value is set to POSTGRES_DEFAULT_SCHEMA for non-multi-tenant setups
|
||||
# This context variable works in both synchronous and asynchronous contexts
|
||||
# In async code, it's automatically carried across coroutines
|
||||
# In sync code, it's managed per thread
|
||||
current_tenant_id = contextvars.ContextVar(
|
||||
"current_tenant_id", default=POSTGRES_DEFAULT_SCHEMA
|
||||
)
|
||||
|
||||
|
||||
def get_session() -> Generator[Session, None, None]:
|
||||
# The line below was added to monitor the latency caused by Postgres connections
|
||||
# during API calls.
|
||||
# with tracer.trace("db.get_session"):
|
||||
with Session(get_sqlalchemy_engine(), expire_on_commit=False) as session:
|
||||
# Dependency to get the current tenant ID and set the context variable
|
||||
def get_current_tenant_id(request: Request) -> str:
|
||||
"""Dependency that extracts the tenant ID from the JWT token in the request and sets the context variable."""
|
||||
if not MULTI_TENANT:
|
||||
tenant_id = POSTGRES_DEFAULT_SCHEMA
|
||||
current_tenant_id.set(tenant_id)
|
||||
return tenant_id
|
||||
|
||||
token = request.cookies.get("tenant_details")
|
||||
if not token:
|
||||
# If no token is present, use the default schema or handle accordingly
|
||||
tenant_id = POSTGRES_DEFAULT_SCHEMA
|
||||
current_tenant_id.set(tenant_id)
|
||||
return tenant_id
|
||||
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET_JWT_KEY, algorithms=["HS256"])
|
||||
tenant_id = payload.get("tenant_id")
|
||||
if not tenant_id:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Invalid token: tenant_id missing"
|
||||
)
|
||||
if not is_valid_schema_name(tenant_id):
|
||||
raise ValueError("Invalid tenant ID format")
|
||||
current_tenant_id.set(tenant_id)
|
||||
return tenant_id
|
||||
except jwt.InvalidTokenError:
|
||||
raise HTTPException(status_code=401, detail="Invalid token format")
|
||||
except ValueError as e:
|
||||
# Let the 400 error bubble up
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
def get_session_with_tenant(tenant_id: str | None = None) -> Session:
|
||||
if tenant_id is None:
|
||||
tenant_id = current_tenant_id.get()
|
||||
|
||||
if not is_valid_schema_name(tenant_id):
|
||||
raise Exception("Invalid tenant ID")
|
||||
|
||||
engine = SqlEngine.get_engine()
|
||||
session = Session(engine, expire_on_commit=False)
|
||||
|
||||
@event.listens_for(session, "after_begin")
|
||||
def set_search_path(session: Session, transaction: Any, connection: Any) -> None:
|
||||
connection.execute(text("SET search_path TO :schema"), {"schema": tenant_id})
|
||||
|
||||
return session
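Unlike the FastAPI dependency below, this helper hands back a plain Session, so the caller owns commit and close. A hedged usage sketch; the tenant id is a placeholder.

    from sqlalchemy import text

    session = get_session_with_tenant("tenant_abc")  # placeholder schema name
    try:
        # search_path is switched to the tenant schema by the after_begin listener.
        session.execute(text("SELECT 1"))
        session.commit()
    finally:
        session.close()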
def get_session(
|
||||
tenant_id: str = Depends(get_current_tenant_id),
|
||||
) -> Generator[Session, None, None]:
|
||||
"""Generate a database session with the appropriate tenant schema set."""
|
||||
engine = get_sqlalchemy_engine()
|
||||
with Session(engine, expire_on_commit=False) as session:
|
||||
if MULTI_TENANT:
|
||||
if not is_valid_schema_name(tenant_id):
|
||||
raise HTTPException(status_code=400, detail="Invalid tenant ID")
|
||||
# Set the search_path to the tenant's schema
|
||||
session.execute(text(f'SET search_path = "{tenant_id}"'))
|
||||
yield session
|
||||
|
||||
|
||||
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
async with AsyncSession(
|
||||
get_sqlalchemy_async_engine(), expire_on_commit=False
|
||||
) as async_session:
|
||||
async def get_async_session(
|
||||
tenant_id: str = Depends(get_current_tenant_id),
|
||||
) -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Generate an async database session with the appropriate tenant schema set."""
|
||||
engine = get_sqlalchemy_async_engine()
|
||||
async with AsyncSession(engine, expire_on_commit=False) as async_session:
|
||||
if MULTI_TENANT:
|
||||
if not is_valid_schema_name(tenant_id):
|
||||
raise HTTPException(status_code=400, detail="Invalid tenant ID")
|
||||
# Set the search_path to the tenant's schema
|
||||
await async_session.execute(text(f'SET search_path = "{tenant_id}"'))
|
||||
yield async_session
|
||||
|
||||
|
||||
def get_session_context_manager() -> ContextManager[Session]:
|
||||
"""Context manager for database sessions."""
|
||||
return contextlib.contextmanager(get_session)()
|
||||
|
||||
|
||||
def get_session_factory() -> sessionmaker[Session]:
|
||||
"""Get a session factory."""
|
||||
global SessionFactory
|
||||
if SessionFactory is None:
|
||||
SessionFactory = sessionmaker(bind=get_sqlalchemy_engine())
|
||||
return SessionFactory
|
||||
|
||||
|
||||
async def warm_up_connections(
|
||||
sync_connections_to_warm_up: int = 20, async_connections_to_warm_up: int = 20
|
||||
) -> None:
|
||||
@@ -204,10 +355,3 @@ async def warm_up_connections(
|
||||
await async_conn.execute(text("SELECT 1"))
|
||||
for async_conn in async_connections:
|
||||
await async_conn.close()
|
||||
|
||||
|
||||
def get_session_factory() -> sessionmaker[Session]:
|
||||
global SessionFactory
|
||||
if SessionFactory is None:
|
||||
SessionFactory = sessionmaker(bind=get_sqlalchemy_engine())
|
||||
return SessionFactory
|
||||
|
||||
@@ -64,19 +64,12 @@ def upsert_cloud_embedding_provider(
|
||||
def upsert_llm_provider(
|
||||
llm_provider: LLMProviderUpsertRequest,
|
||||
db_session: Session,
|
||||
is_creation: bool = True,
|
||||
) -> FullLLMProvider:
|
||||
existing_llm_provider = db_session.scalar(
|
||||
select(LLMProviderModel).where(LLMProviderModel.name == llm_provider.name)
|
||||
)
|
||||
if existing_llm_provider and is_creation:
|
||||
raise ValueError(f"LLM Provider with name {llm_provider.name} already exists")
|
||||
|
||||
if not existing_llm_provider:
|
||||
if not is_creation:
|
||||
raise ValueError(
|
||||
f"LLM Provider with name {llm_provider.name} does not exist"
|
||||
)
|
||||
existing_llm_provider = LLMProviderModel(name=llm_provider.name)
|
||||
db_session.add(existing_llm_provider)
|
||||
|
||||
|
||||
@@ -50,7 +50,7 @@ from danswer.db.enums import IndexingStatus
|
||||
from danswer.db.enums import IndexModelStatus
|
||||
from danswer.db.enums import TaskStatus
|
||||
from danswer.db.pydantic_type import PydanticType
|
||||
from danswer.dynamic_configs.interface import JSON_ro
|
||||
from danswer.key_value_store.interface import JSON_ro
|
||||
from danswer.file_store.models import FileDescriptor
|
||||
from danswer.llm.override_models import LLMOverride
|
||||
from danswer.llm.override_models import PromptOverride
|
||||
@@ -1725,7 +1725,9 @@ class User__ExternalUserGroupId(Base):
|
||||
user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id"), primary_key=True)
|
||||
# These group ids have been prefixed by the source type
|
||||
external_user_group_id: Mapped[str] = mapped_column(String, primary_key=True)
|
||||
cc_pair_id: Mapped[int] = mapped_column(ForeignKey("connector_credential_pair.id"))
|
||||
cc_pair_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("connector_credential_pair.id"), primary_key=True
|
||||
)
|
||||
|
||||
|
||||
class UsageReport(Base):
|
||||
|
||||
@@ -11,7 +11,7 @@ from danswer.db.index_attempt import (
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.db.search_settings import get_secondary_search_settings
|
||||
from danswer.db.search_settings import update_search_settings_status
|
||||
from danswer.dynamic_configs.factory import get_dynamic_config_store
|
||||
from danswer.key_value_store.factory import get_kv_store
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -54,7 +54,7 @@ def check_index_swap(db_session: Session) -> None:
|
||||
)
|
||||
|
||||
if cc_pair_count > 0:
|
||||
kv_store = get_dynamic_config_store()
|
||||
kv_store = get_kv_store()
|
||||
kv_store.store(KV_REINDEX_KEY, False)
|
||||
|
||||
# Expire jobs for the now past index/embedding model
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from sqlalchemy import and_
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import or_
|
||||
@@ -107,12 +108,14 @@ def create_or_add_document_tag_list(
|
||||
return all_tags
|
||||
|
||||
|
||||
def get_tags_by_value_prefix_for_source_types(
|
||||
def find_tags(
|
||||
tag_key_prefix: str | None,
|
||||
tag_value_prefix: str | None,
|
||||
sources: list[DocumentSource] | None,
|
||||
limit: int | None,
|
||||
db_session: Session,
|
||||
# if set, both tag_key_prefix and tag_value_prefix must be a match
|
||||
require_both_to_match: bool = False,
|
||||
) -> list[Tag]:
|
||||
query = select(Tag)
|
||||
|
||||
@@ -122,7 +125,11 @@ def get_tags_by_value_prefix_for_source_types(
|
||||
conditions.append(Tag.tag_key.ilike(f"{tag_key_prefix}%"))
|
||||
if tag_value_prefix:
|
||||
conditions.append(Tag.tag_value.ilike(f"{tag_value_prefix}%"))
|
||||
query = query.where(or_(*conditions))
|
||||
|
||||
final_prefix_condition = (
|
||||
and_(*conditions) if require_both_to_match else or_(*conditions)
|
||||
)
|
||||
query = query.where(final_prefix_condition)
|
||||
|
||||
if sources:
|
||||
query = query.where(Tag.source.in_(sources))
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.document_index.interfaces import DocumentIndex
|
||||
from danswer.document_index.vespa.index import VespaIndex
|
||||
|
||||
@@ -13,3 +16,14 @@ def get_default_document_index(
|
||||
return VespaIndex(
|
||||
index_name=primary_index_name, secondary_index_name=secondary_index_name
|
||||
)
|
||||
|
||||
|
||||
def get_current_primary_default_document_index(db_session: Session) -> DocumentIndex:
|
||||
"""
|
||||
TODO: Use redis to cache this or something
|
||||
"""
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
return get_default_document_index(
|
||||
primary_index_name=search_settings.index_name,
|
||||
secondary_index_name=None,
|
||||
)
|
||||
|
||||
@@ -55,6 +55,21 @@ class DocumentMetadata:
|
||||
from_ingestion_api: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class VespaDocumentFields:
|
||||
"""
|
||||
Specifies fields in Vespa for a document. Fields set to None will be ignored.
|
||||
Perhaps we should name this in an implementation agnostic fashion, but it's more
|
||||
understandable like this for now.
|
||||
"""
|
||||
|
||||
# all other fields except these 4 will always be left alone by the update request
|
||||
access: DocumentAccess | None = None
|
||||
document_sets: set[str] | None = None
|
||||
boost: float | None = None
|
||||
hidden: bool | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class UpdateRequest:
|
||||
"""
|
||||
@@ -156,6 +171,16 @@ class Deletable(abc.ABC):
|
||||
Class must implement the ability to delete document by their unique document ids.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def delete_single(self, doc_id: str) -> None:
|
||||
"""
|
||||
Given a single document id, hard delete it from the document index
|
||||
|
||||
Parameters:
|
||||
- doc_id: document id as specified by the connector
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def delete(self, doc_ids: list[str]) -> None:
|
||||
"""
|
||||
@@ -178,11 +203,9 @@ class Updatable(abc.ABC):
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def update_single(self, update_request: UpdateRequest) -> None:
|
||||
def update_single(self, doc_id: str, fields: VespaDocumentFields) -> None:
|
||||
"""
|
||||
Updates some set of chunks for a document. The document and fields to update
|
||||
are specified in the update request. Each update request in the list applies
|
||||
its changes to a list of document ids.
|
||||
Updates all chunks for a document with the specified fields.
|
||||
None values mean that the field does not need an update.
|
||||
|
||||
The rationale for a single update function is that it allows retries and parallelism
|
||||
@@ -190,14 +213,10 @@ class Updatable(abc.ABC):
|
||||
us to individually handle error conditions per document.
|
||||
|
||||
Parameters:
|
||||
- update_request: for a list of document ids in the update request, apply the same updates
|
||||
to all of the documents with those ids.
|
||||
- fields: the fields to update in the document. Any field set to None will not be changed.
|
||||
|
||||
Return:
|
||||
- an HTTPStatus code. The code can used to decide whether to fail immediately,
|
||||
retry, etc. Although this method likely hits an HTTP API behind the
|
||||
scenes, the usage of HTTPStatus is a convenience and the interface is not
|
||||
actually HTTP specific.
|
||||
None
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import concurrent.futures
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
@@ -13,6 +14,7 @@ from typing import cast
|
||||
import httpx
|
||||
import requests
|
||||
|
||||
from danswer.configs.app_configs import DOCUMENT_INDEX_NAME
|
||||
from danswer.configs.chat_configs import DOC_TIME_DECAY
|
||||
from danswer.configs.chat_configs import NUM_RETURNED_HITS
|
||||
from danswer.configs.chat_configs import TITLE_CONTENT_RATIO
|
||||
@@ -22,6 +24,7 @@ from danswer.document_index.interfaces import DocumentIndex
|
||||
from danswer.document_index.interfaces import DocumentInsertionRecord
|
||||
from danswer.document_index.interfaces import UpdateRequest
|
||||
from danswer.document_index.interfaces import VespaChunkRequest
|
||||
from danswer.document_index.interfaces import VespaDocumentFields
|
||||
from danswer.document_index.vespa.chunk_retrieval import batch_search_api_retrieval
|
||||
from danswer.document_index.vespa.chunk_retrieval import (
|
||||
get_all_vespa_ids_for_document_id,
|
||||
@@ -58,8 +61,8 @@ from danswer.document_index.vespa_constants import VESPA_APPLICATION_ENDPOINT
|
||||
from danswer.document_index.vespa_constants import VESPA_DIM_REPLACEMENT_PAT
|
||||
from danswer.document_index.vespa_constants import VESPA_TIMEOUT
|
||||
from danswer.document_index.vespa_constants import YQL_BASE
|
||||
from danswer.dynamic_configs.factory import get_dynamic_config_store
|
||||
from danswer.indexing.models import DocMetadataAwareIndexChunk
|
||||
from danswer.key_value_store.factory import get_kv_store
|
||||
from danswer.search.models import IndexFilters
|
||||
from danswer.search.models import InferenceChunkUncleaned
|
||||
from danswer.utils.batching import batch_generator
|
||||
@@ -68,6 +71,10 @@ from shared_configs.model_server_models import Embedding
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Set the logging level to WARNING to ignore INFO and DEBUG logs
|
||||
httpx_logger = logging.getLogger("httpx")
|
||||
httpx_logger.setLevel(logging.WARNING)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _VespaUpdateRequest:
|
||||
@@ -140,7 +147,7 @@ class VespaIndex(DocumentIndex):
|
||||
SEARCH_THREAD_NUMBER_PAT, str(VESPA_SEARCHER_THREADS)
|
||||
)
|
||||
|
||||
kv_store = get_dynamic_config_store()
|
||||
kv_store = get_kv_store()
|
||||
|
||||
needs_reindexing = False
|
||||
try:
|
||||
@@ -377,89 +384,86 @@ class VespaIndex(DocumentIndex):
|
||||
time.monotonic() - update_start,
|
||||
)
|
||||
|
||||
def update_single(self, update_request: UpdateRequest) -> None:
|
||||
def update_single(self, doc_id: str, fields: VespaDocumentFields) -> None:
|
||||
"""Note: if the document id does not exist, the update will be a no-op and the
|
||||
function will complete with no errors or exceptions.
|
||||
Handle other exceptions if you wish to implement retry behavior
|
||||
"""
|
||||
if len(update_request.document_ids) != 1:
|
||||
raise ValueError("update_request must contain a single document id")
|
||||
|
||||
# Handle Vespa character limitations
|
||||
# Mutating update_request but it's not used later anyway
|
||||
update_request.document_ids = [
|
||||
replace_invalid_doc_id_characters(doc_id)
|
||||
for doc_id in update_request.document_ids
|
||||
]
|
||||
|
||||
# update_start = time.monotonic()
|
||||
|
||||
# Fetch all chunks for each document ahead of time
|
||||
index_names = [self.index_name]
|
||||
if self.secondary_index_name:
|
||||
index_names.append(self.secondary_index_name)
|
||||
|
||||
chunk_id_start_time = time.monotonic()
|
||||
all_doc_chunk_ids: list[str] = []
|
||||
for index_name in index_names:
|
||||
for document_id in update_request.document_ids:
|
||||
# this calls vespa and can raise http exceptions
|
||||
doc_chunk_ids = get_all_vespa_ids_for_document_id(
|
||||
document_id=document_id,
|
||||
index_name=index_name,
|
||||
filters=None,
|
||||
get_large_chunks=True,
|
||||
)
|
||||
all_doc_chunk_ids.extend(doc_chunk_ids)
|
||||
logger.debug(
|
||||
f"Took {time.monotonic() - chunk_id_start_time:.2f} seconds to fetch all Vespa chunk IDs"
|
||||
)
|
||||
normalized_doc_id = replace_invalid_doc_id_characters(doc_id)

# Build the _VespaUpdateRequest objects
update_dict: dict[str, dict] = {"fields": {}}
if update_request.boost is not None:
update_dict["fields"][BOOST] = {"assign": update_request.boost}
if update_request.document_sets is not None:
if fields.boost is not None:
update_dict["fields"][BOOST] = {"assign": fields.boost}
if fields.document_sets is not None:
update_dict["fields"][DOCUMENT_SETS] = {
"assign": {
document_set: 1 for document_set in update_request.document_sets
}
"assign": {document_set: 1 for document_set in fields.document_sets}
}
if update_request.access is not None:
if fields.access is not None:
update_dict["fields"][ACCESS_CONTROL_LIST] = {
"assign": {acl_entry: 1 for acl_entry in update_request.access.to_acl()}
"assign": {acl_entry: 1 for acl_entry in fields.access.to_acl()}
}
if update_request.hidden is not None:
update_dict["fields"][HIDDEN] = {"assign": update_request.hidden}
if fields.hidden is not None:
update_dict["fields"][HIDDEN] = {"assign": fields.hidden}

if not update_dict["fields"]:
logger.error("Update request received but nothing to update")
return

processed_update_requests: list[_VespaUpdateRequest] = []
for document_id in update_request.document_ids:
for doc_chunk_id in all_doc_chunk_ids:
processed_update_requests.append(
_VespaUpdateRequest(
document_id=document_id,
url=f"{DOCUMENT_ID_ENDPOINT.format(index_name=self.index_name)}/{doc_chunk_id}",
update_request=update_dict,
)
)
index_names = [self.index_name]
if self.secondary_index_name:
index_names.append(self.secondary_index_name)

with httpx.Client(http2=True) as http_client:
for update in processed_update_requests:
http_client.put(
update.url,
headers={"Content-Type": "application/json"},
json=update.update_request,
for index_name in index_names:
params = httpx.QueryParams(
{
"selection": f"{index_name}.document_id=='{normalized_doc_id}'",
"cluster": DOCUMENT_INDEX_NAME,
}
)

# logger.debug(
# "Finished updating Vespa documents in %.2f seconds",
# time.monotonic() - update_start,
# )
total_chunks_updated = 0
while True:
try:
resp = http_client.put(
f"{DOCUMENT_ID_ENDPOINT.format(index_name=self.index_name)}",
params=params,
headers={"Content-Type": "application/json"},
json=update_dict,
)

resp.raise_for_status()
except httpx.HTTPStatusError as e:
logger.error(
f"Failed to update chunks, details: {e.response.text}"
)
raise

resp_data = resp.json()

if "documentCount" in resp_data:
chunks_updated = resp_data["documentCount"]
total_chunks_updated += chunks_updated

# Check for continuation token to handle pagination
if "continuation" not in resp_data:
break # Exit loop if no continuation token

if not resp_data["continuation"]:
break # Exit loop if continuation token is empty

params = params.set("continuation", resp_data["continuation"])

logger.debug(
f"VespaIndex.update_single: "
f"index={index_name} "
f"doc={normalized_doc_id} "
f"chunks_deleted={total_chunks_updated}"
)
return

def delete(self, doc_ids: list[str]) -> None:
@@ -478,6 +482,68 @@ class VespaIndex(DocumentIndex):
delete_vespa_docs(
document_ids=doc_ids, index_name=index_name, http_client=http_client
)
return

def delete_single(self, doc_id: str) -> None:
"""Possibly faster overall than the delete method due to using a single
delete call with a selection query."""

# Vespa deletion is poorly documented ... luckily we found this
# https://docs.vespa.ai/en/operations/batch-delete.html#example

doc_id = replace_invalid_doc_id_characters(doc_id)

# NOTE: using `httpx` here since `requests` doesn't support HTTP2. This is beneficial for
# indexing / updates / deletes since we have to make a large volume of requests.
index_names = [self.index_name]
if self.secondary_index_name:
index_names.append(self.secondary_index_name)

with httpx.Client(http2=True) as http_client:
for index_name in index_names:
params = httpx.QueryParams(
{
"selection": f"{index_name}.document_id=='{doc_id}'",
"cluster": DOCUMENT_INDEX_NAME,
}
)

total_chunks_deleted = 0
while True:
try:
resp = http_client.delete(
f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}",
params=params,
)
resp.raise_for_status()
except httpx.HTTPStatusError as e:
logger.error(
f"Failed to delete chunk, details: {e.response.text}"
)
raise

resp_data = resp.json()

if "documentCount" in resp_data:
chunks_deleted = resp_data["documentCount"]
total_chunks_deleted += chunks_deleted

# Check for continuation token to handle pagination
if "continuation" not in resp_data:
break # Exit loop if no continuation token

if not resp_data["continuation"]:
break # Exit loop if continuation token is empty

params = params.set("continuation", resp_data["continuation"])

logger.debug(
f"VespaIndex.delete_single: "
f"index={index_name} "
f"doc={doc_id} "
f"chunks_deleted={total_chunks_deleted}"
)
return

def id_based_retrieval(
self,
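
The update_single and delete_single methods above share the same Vespa document/v1 pattern: send a selection-scoped request, then re-issue it with the returned continuation token until Vespa stops returning one. For reference, a minimal standalone sketch of that pagination loop in the same style (the function name and base_url parameter are illustrative, not part of the diff):

    import httpx

    def delete_by_selection(base_url: str, index_name: str, doc_id: str, cluster: str) -> int:
        """Delete every chunk whose document_id matches, following continuation tokens."""
        params = httpx.QueryParams(
            {"selection": f"{index_name}.document_id=='{doc_id}'", "cluster": cluster}
        )
        total = 0
        with httpx.Client(http2=True) as client:
            while True:
                resp = client.delete(base_url, params=params)
                resp.raise_for_status()
                data = resp.json()
                total += data.get("documentCount", 0)
                continuation = data.get("continuation")
                if not continuation:
                    break  # no or empty continuation means the last batch was handled
                params = params.set("continuation", continuation)
        return total
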
@@ -1,15 +0,0 @@
from danswer.configs.app_configs import DYNAMIC_CONFIG_STORE
from danswer.dynamic_configs.interface import DynamicConfigStore
from danswer.dynamic_configs.store import FileSystemBackedDynamicConfigStore
from danswer.dynamic_configs.store import PostgresBackedDynamicConfigStore


def get_dynamic_config_store() -> DynamicConfigStore:
dynamic_config_store_type = DYNAMIC_CONFIG_STORE
if dynamic_config_store_type == FileSystemBackedDynamicConfigStore.__name__:
raise NotImplementedError("File based config store no longer supported")
if dynamic_config_store_type == PostgresBackedDynamicConfigStore.__name__:
return PostgresBackedDynamicConfigStore()

# TODO: change exception type
raise Exception("Unknown dynamic config store type")

@@ -1,102 +0,0 @@
import json
import os
from collections.abc import Iterator
from contextlib import contextmanager
from pathlib import Path
from typing import cast

from filelock import FileLock
from sqlalchemy.orm import Session

from danswer.db.engine import get_session_factory
from danswer.db.models import KVStore
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.dynamic_configs.interface import DynamicConfigStore
from danswer.dynamic_configs.interface import JSON_ro


FILE_LOCK_TIMEOUT = 10


def _get_file_lock(file_name: Path) -> FileLock:
return FileLock(file_name.with_suffix(".lock"))


class FileSystemBackedDynamicConfigStore(DynamicConfigStore):
def __init__(self, dir_path: str) -> None:
# TODO (chris): maybe require all possible keys to be passed in
# at app start somehow to prevent key overlaps
self.dir_path = Path(dir_path)

def store(self, key: str, val: JSON_ro, encrypt: bool = False) -> None:
file_path = self.dir_path / key
lock = _get_file_lock(file_path)
with lock.acquire(timeout=FILE_LOCK_TIMEOUT):
with open(file_path, "w+") as f:
json.dump(val, f)

def load(self, key: str) -> JSON_ro:
file_path = self.dir_path / key
if not file_path.exists():
raise ConfigNotFoundError
lock = _get_file_lock(file_path)
with lock.acquire(timeout=FILE_LOCK_TIMEOUT):
with open(self.dir_path / key) as f:
return cast(JSON_ro, json.load(f))

def delete(self, key: str) -> None:
file_path = self.dir_path / key
if not file_path.exists():
raise ConfigNotFoundError
lock = _get_file_lock(file_path)
with lock.acquire(timeout=FILE_LOCK_TIMEOUT):
os.remove(file_path)


class PostgresBackedDynamicConfigStore(DynamicConfigStore):
@contextmanager
def get_session(self) -> Iterator[Session]:
factory = get_session_factory()
session: Session = factory()
try:
yield session
finally:
session.close()

def store(self, key: str, val: JSON_ro, encrypt: bool = False) -> None:
# The actual encryption/decryption is done in Postgres, we just need to choose
# which field to set
encrypted_val = val if encrypt else None
plain_val = val if not encrypt else None
with self.get_session() as session:
obj = session.query(KVStore).filter_by(key=key).first()
if obj:
obj.value = plain_val
obj.encrypted_value = encrypted_val
else:
obj = KVStore(
key=key, value=plain_val, encrypted_value=encrypted_val
) # type: ignore
session.query(KVStore).filter_by(key=key).delete() # just in case
session.add(obj)
session.commit()

def load(self, key: str) -> JSON_ro:
with self.get_session() as session:
obj = session.query(KVStore).filter_by(key=key).first()
if not obj:
raise ConfigNotFoundError

if obj.value is not None:
return cast(JSON_ro, obj.value)
if obj.encrypted_value is not None:
return cast(JSON_ro, obj.encrypted_value)

return None

def delete(self, key: str) -> None:
with self.get_session() as session:
result = session.query(KVStore).filter_by(key=key).delete() # type: ignore
if result == 0:
raise ConfigNotFoundError
session.commit()

@@ -20,6 +20,8 @@ from pypdf.errors import PdfStreamError

from danswer.configs.constants import DANSWER_METADATA_FILENAME
from danswer.file_processing.html_utils import parse_html_page_basic
from danswer.file_processing.unstructured import get_unstructured_api_key
from danswer.file_processing.unstructured import unstructured_to_text
from danswer.utils.logger import setup_logger

logger = setup_logger()

@@ -331,9 +333,10 @@ def file_io_to_text(file: IO[Any]) -> str:


def extract_file_text(
file_name: str | None,
file: IO[Any],
file_name: str,
break_on_unprocessable: bool = True,
extension: str | None = None,
) -> str:
extension_to_function: dict[str, Callable[[IO[Any]], str]] = {
".pdf": pdf_to_text,
@@ -345,22 +348,29 @@ def extract_file_text(
".html": parse_html_page_basic,
}

def _process_file() -> str:
if file_name:
extension = get_file_ext(file_name)
if check_file_ext_is_valid(extension):
return extension_to_function.get(extension, file_io_to_text)(file)
try:
if get_unstructured_api_key():
return unstructured_to_text(file, file_name)

# Either the file somehow has no name or the extension is not one that we are familiar with
if file_name or extension:
if extension is not None:
final_extension = extension
elif file_name is not None:
final_extension = get_file_ext(file_name)

if check_file_ext_is_valid(final_extension):
return extension_to_function.get(final_extension, file_io_to_text)(file)

# Either the file somehow has no name or the extension is not one that we recognize
if is_text_file(file):
return file_io_to_text(file)

raise ValueError("Unknown file extension and unknown text encoding")

try:
return _process_file()
except Exception as e:
if break_on_unprocessable:
raise RuntimeError(f"Failed to process file: {str(e)}") from e
logger.warning(f"Failed to process file: {str(e)}")
raise RuntimeError(
f"Failed to process file {file_name or 'Unknown'}: {str(e)}"
) from e
logger.warning(f"Failed to process file {file_name or 'Unknown'}: {str(e)}")
return ""
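
The extract_file_text hunk above flips the argument order, makes file_name required, adds an optional extension hint, and tries the Unstructured API first when a key is configured. A minimal usage sketch under those assumptions (the module path and file name are illustrative):

    from danswer.file_processing.extract_file_text import extract_file_text

    with open("handbook.pdf", "rb") as f:
        # file comes first and file_name is now required; pass extension explicitly
        # when the name alone is not a reliable indicator of the format
        text = extract_file_text(
            file=f,
            file_name="handbook.pdf",
            break_on_unprocessable=False,
        )
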
backend/danswer/file_processing/unstructured.py (new file, 67 lines)
@@ -0,0 +1,67 @@
from typing import Any
from typing import cast
from typing import IO

from unstructured.staging.base import dict_to_elements
from unstructured_client import UnstructuredClient # type: ignore
from unstructured_client.models import operations # type: ignore
from unstructured_client.models import shared

from danswer.configs.constants import KV_UNSTRUCTURED_API_KEY
from danswer.key_value_store.factory import get_kv_store
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.utils.logger import setup_logger


logger = setup_logger()


def get_unstructured_api_key() -> str | None:
kv_store = get_kv_store()
try:
return cast(str, kv_store.load(KV_UNSTRUCTURED_API_KEY))
except KvKeyNotFoundError:
return None


def update_unstructured_api_key(api_key: str) -> None:
kv_store = get_kv_store()
kv_store.store(KV_UNSTRUCTURED_API_KEY, api_key)


def delete_unstructured_api_key() -> None:
kv_store = get_kv_store()
kv_store.delete(KV_UNSTRUCTURED_API_KEY)


def _sdk_partition_request(
file: IO[Any], file_name: str, **kwargs: Any
) -> operations.PartitionRequest:
try:
request = operations.PartitionRequest(
partition_parameters=shared.PartitionParameters(
files=shared.Files(content=file.read(), file_name=file_name),
**kwargs,
),
)
return request
except Exception as e:
logger.error(f"Error creating partition request for file {file_name}: {str(e)}")
raise


def unstructured_to_text(file: IO[Any], file_name: str) -> str:
logger.debug(f"Starting to read file: {file_name}")
req = _sdk_partition_request(file, file_name, strategy="auto")

unstructured_client = UnstructuredClient(api_key_auth=get_unstructured_api_key())

response = unstructured_client.general.partition(req) # type: ignore
elements = dict_to_elements(response.elements)

if response.status_code != 200:
err = f"Received unexpected status code {response.status_code} from Unstructured API."
logger.error(err)
raise ValueError(err)

return "\n\n".join(str(el) for el in elements)

@@ -10,6 +10,7 @@ from danswer.connectors.cross_connector_utils.miscellaneous_utils import (
get_metadata_keys_to_ignore,
)
from danswer.connectors.models import Document
from danswer.indexing.indexing_heartbeat import Heartbeat
from danswer.indexing.models import DocAwareChunk
from danswer.natural_language_processing.utils import BaseTokenizer
from danswer.utils.logger import setup_logger
@@ -123,6 +124,7 @@ class Chunker:
chunk_token_limit: int = DOC_EMBEDDING_CONTEXT_SIZE,
chunk_overlap: int = CHUNK_OVERLAP,
mini_chunk_size: int = MINI_CHUNK_SIZE,
heartbeat: Heartbeat | None = None,
) -> None:
from llama_index.text_splitter import SentenceSplitter

@@ -131,6 +133,7 @@ class Chunker:
self.enable_multipass = enable_multipass
self.enable_large_chunks = enable_large_chunks
self.tokenizer = tokenizer
self.heartbeat = heartbeat

self.blurb_splitter = SentenceSplitter(
tokenizer=tokenizer.tokenize,
@@ -255,7 +258,7 @@ class Chunker:
# If the chunk does not have any useable content, it will not be indexed
return chunks

def chunk(self, document: Document) -> list[DocAwareChunk]:
def _handle_single_document(self, document: Document) -> list[DocAwareChunk]:
# Specifically for reproducing an issue with gmail
if document.source == DocumentSource.GMAIL:
logger.debug(f"Chunking {document.semantic_identifier}")
@@ -302,3 +305,13 @@ class Chunker:
normal_chunks.extend(large_chunks)

return normal_chunks

def chunk(self, documents: list[Document]) -> list[DocAwareChunk]:
final_chunks: list[DocAwareChunk] = []
for document in documents:
final_chunks.extend(self._handle_single_document(document))

if self.heartbeat:
self.heartbeat.heartbeat()

return final_chunks

@@ -1,12 +1,8 @@
from abc import ABC
from abc import abstractmethod

from sqlalchemy.orm import Session

from danswer.db.models import IndexModelStatus
from danswer.db.models import SearchSettings
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.indexing.indexing_heartbeat import Heartbeat
from danswer.indexing.models import ChunkEmbedding
from danswer.indexing.models import DocAwareChunk
from danswer.indexing.models import IndexChunk
@@ -24,6 +20,9 @@ logger = setup_logger()


class IndexingEmbedder(ABC):
"""Converts chunks into chunks with embeddings. Note that one chunk may have
multiple embeddings associated with it."""

def __init__(
self,
model_name: str,
@@ -33,6 +32,7 @@ class IndexingEmbedder(ABC):
provider_type: EmbeddingProvider | None,
api_key: str | None,
api_url: str | None,
heartbeat: Heartbeat | None,
):
self.model_name = model_name
self.normalize = normalize
@@ -54,6 +54,7 @@ class IndexingEmbedder(ABC):
server_host=INDEXING_MODEL_SERVER_HOST,
server_port=INDEXING_MODEL_SERVER_PORT,
retrim_content=True,
heartbeat=heartbeat,
)

@abstractmethod
@@ -74,6 +75,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
provider_type: EmbeddingProvider | None = None,
api_key: str | None = None,
api_url: str | None = None,
heartbeat: Heartbeat | None = None,
):
super().__init__(
model_name,
@@ -83,6 +85,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
provider_type,
api_key,
api_url,
heartbeat,
)

@log_function_time()
@@ -166,7 +169,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
title_embed_dict[title] = title_embedding

new_embedded_chunk = IndexChunk(
**chunk.dict(),
**chunk.model_dump(),
embeddings=ChunkEmbedding(
full_embedding=chunk_embeddings[0],
mini_chunk_embeddings=chunk_embeddings[1:],
@@ -180,7 +183,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):

@classmethod
def from_db_search_settings(
cls, search_settings: SearchSettings
cls, search_settings: SearchSettings, heartbeat: Heartbeat | None = None
) -> "DefaultIndexingEmbedder":
return cls(
model_name=search_settings.model_name,
@@ -190,28 +193,5 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
provider_type=search_settings.provider_type,
api_key=search_settings.api_key,
api_url=search_settings.api_url,
heartbeat=heartbeat,
)


def get_embedding_model_from_search_settings(
db_session: Session, index_model_status: IndexModelStatus = IndexModelStatus.PRESENT
) -> IndexingEmbedder:
search_settings: SearchSettings | None
if index_model_status == IndexModelStatus.PRESENT:
search_settings = get_current_search_settings(db_session)
elif index_model_status == IndexModelStatus.FUTURE:
search_settings = get_secondary_search_settings(db_session)
if not search_settings:
raise RuntimeError("No secondary index configured")
else:
raise RuntimeError("Not supporting embedding model rollbacks")

return DefaultIndexingEmbedder(
model_name=search_settings.model_name,
normalize=search_settings.normalize,
query_prefix=search_settings.query_prefix,
passage_prefix=search_settings.passage_prefix,
provider_type=search_settings.provider_type,
api_key=search_settings.api_key,
api_url=search_settings.api_url,
)
backend/danswer/indexing/indexing_heartbeat.py (new file, 41 lines)
@@ -0,0 +1,41 @@
import abc
from typing import Any

from sqlalchemy import func
from sqlalchemy.orm import Session

from danswer.db.index_attempt import get_index_attempt
from danswer.utils.logger import setup_logger

logger = setup_logger()


class Heartbeat(abc.ABC):
"""Useful for any long-running work that goes through a bunch of items
and needs to occasionally give updates on progress.
e.g. chunking, embedding, updating vespa, etc."""

@abc.abstractmethod
def heartbeat(self, metadata: Any = None) -> None:
raise NotImplementedError


class IndexingHeartbeat(Heartbeat):
def __init__(self, index_attempt_id: int, db_session: Session, freq: int):
self.cnt = 0

self.index_attempt_id = index_attempt_id
self.db_session = db_session
self.freq = freq

def heartbeat(self, metadata: Any = None) -> None:
self.cnt += 1
if self.cnt % self.freq == 0:
index_attempt = get_index_attempt(
db_session=self.db_session, index_attempt_id=self.index_attempt_id
)
if index_attempt:
index_attempt.time_updated = func.now()
self.db_session.commit()
else:
logger.error("Index attempt not found, this should not happen!")
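
The Heartbeat/IndexingHeartbeat pair above gives long batch loops a cheap way to bump index_attempt.time_updated every freq items, so a slow-but-healthy attempt is not mistaken for a stalled one. A small hypothetical caller, sketched after the pattern the chunker and embedder use in this diff:

    from danswer.indexing.indexing_heartbeat import IndexingHeartbeat

    def process_items(items: list[str], heartbeat: IndexingHeartbeat | None = None) -> None:
        for item in items:
            ...  # expensive per-item work (chunking, embedding, etc.)
            if heartbeat:
                # every `freq` calls this refreshes index_attempt.time_updated in Postgres
                heartbeat.heartbeat()
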
@@ -31,6 +31,7 @@ from danswer.document_index.interfaces import DocumentIndex
from danswer.document_index.interfaces import DocumentMetadata
from danswer.indexing.chunker import Chunker
from danswer.indexing.embedder import IndexingEmbedder
from danswer.indexing.indexing_heartbeat import IndexingHeartbeat
from danswer.indexing.models import DocAwareChunk
from danswer.indexing.models import DocMetadataAwareIndexChunk
from danswer.utils.logger import setup_logger
@@ -220,8 +221,8 @@ def index_doc_batch_prepare(

document_ids = [document.id for document in documents]
db_docs: list[DBDocument] = get_documents_by_ids(
document_ids=document_ids,
db_session=db_session,
document_ids=document_ids,
)

# Skip indexing docs that don't have a newer updated at
@@ -283,18 +284,10 @@ def index_doc_batch(
return 0, 0

logger.debug("Starting chunking")
chunks: list[DocAwareChunk] = []
for document in ctx.updatable_docs:
chunks.extend(chunker.chunk(document=document))
chunks: list[DocAwareChunk] = chunker.chunk(ctx.updatable_docs)

logger.debug("Starting embedding")
chunks_with_embeddings = (
embedder.embed_chunks(
chunks=chunks,
)
if chunks
else []
)
chunks_with_embeddings = embedder.embed_chunks(chunks) if chunks else []

updatable_ids = [doc.id for doc in ctx.updatable_docs]

@@ -406,6 +399,13 @@ def build_indexing_pipeline(
tokenizer=embedder.embedding_model.tokenizer,
enable_multipass=multipass,
enable_large_chunks=enable_large_chunks,
# after every doc, update status in case there are a bunch of
# really long docs
heartbeat=IndexingHeartbeat(
index_attempt_id=attempt_id, db_session=db_session, freq=1
)
if attempt_id
else None,
)

return partial(
backend/danswer/key_value_store/__init__.py (new file, empty)
backend/danswer/key_value_store/factory.py (new file, 7 lines)
@@ -0,0 +1,7 @@
from danswer.key_value_store.interface import KeyValueStore
from danswer.key_value_store.store import PgRedisKVStore


def get_kv_store() -> KeyValueStore:
# this is the only one supported currently
return PgRedisKVStore()

@@ -9,11 +9,11 @@ JSON_ro: TypeAlias = (
)

class ConfigNotFoundError(Exception):
class KvKeyNotFoundError(Exception):
pass

class DynamicConfigStore:
class KeyValueStore:
@abc.abstractmethod
def store(self, key: str, val: JSON_ro, encrypt: bool = False) -> None:
raise NotImplementedError
backend/danswer/key_value_store/store.py (new file, 99 lines)
@@ -0,0 +1,99 @@
import json
from collections.abc import Iterator
from contextlib import contextmanager
from typing import cast

from sqlalchemy.orm import Session

from danswer.db.engine import get_session_factory
from danswer.db.models import KVStore
from danswer.key_value_store.interface import JSON_ro
from danswer.key_value_store.interface import KeyValueStore
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger

logger = setup_logger()


REDIS_KEY_PREFIX = "danswer_kv_store:"
KV_REDIS_KEY_EXPIRATION = 60 * 60 * 24 # 1 Day


class PgRedisKVStore(KeyValueStore):
def __init__(self) -> None:
self.redis_client = get_redis_client()

@contextmanager
def get_session(self) -> Iterator[Session]:
factory = get_session_factory()
session: Session = factory()
try:
yield session
finally:
session.close()

def store(self, key: str, val: JSON_ro, encrypt: bool = False) -> None:
# Not encrypted in Redis, but encrypted in Postgres
try:
self.redis_client.set(
REDIS_KEY_PREFIX + key, json.dumps(val), ex=KV_REDIS_KEY_EXPIRATION
)
except Exception as e:
# Fallback gracefully to Postgres if Redis fails
logger.error(f"Failed to set value in Redis for key '{key}': {str(e)}")

encrypted_val = val if encrypt else None
plain_val = val if not encrypt else None
with self.get_session() as session:
obj = session.query(KVStore).filter_by(key=key).first()
if obj:
obj.value = plain_val
obj.encrypted_value = encrypted_val
else:
obj = KVStore(
key=key, value=plain_val, encrypted_value=encrypted_val
) # type: ignore
session.query(KVStore).filter_by(key=key).delete() # just in case
session.add(obj)
session.commit()

def load(self, key: str) -> JSON_ro:
try:
redis_value = self.redis_client.get(REDIS_KEY_PREFIX + key)
if redis_value:
assert isinstance(redis_value, bytes)
return json.loads(redis_value.decode("utf-8"))
except Exception as e:
logger.error(f"Failed to get value from Redis for key '{key}': {str(e)}")

with self.get_session() as session:
obj = session.query(KVStore).filter_by(key=key).first()
if not obj:
raise KvKeyNotFoundError

if obj.value is not None:
value = obj.value
elif obj.encrypted_value is not None:
value = obj.encrypted_value
else:
value = None

try:
self.redis_client.set(REDIS_KEY_PREFIX + key, json.dumps(value))
except Exception as e:
logger.error(f"Failed to set value in Redis for key '{key}': {str(e)}")

return cast(JSON_ro, value)

def delete(self, key: str) -> None:
try:
self.redis_client.delete(REDIS_KEY_PREFIX + key)
except Exception as e:
logger.error(f"Failed to delete value from Redis for key '{key}': {str(e)}")

with self.get_session() as session:
result = session.query(KVStore).filter_by(key=key).delete() # type: ignore
if result == 0:
raise KvKeyNotFoundError
session.commit()
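
PgRedisKVStore above is effectively a read-through cache: writes go best-effort to Redis with a TTL and authoritatively to Postgres (optionally in the encrypted column), while reads try Redis first and repopulate it from Postgres on a miss. A minimal caller sketch, with an illustrative key name:

    from danswer.key_value_store.factory import get_kv_store
    from danswer.key_value_store.interface import KvKeyNotFoundError

    kv_store = get_kv_store()
    kv_store.store("example_feature_flag", {"enabled": True}, encrypt=False)

    try:
        flag = kv_store.load("example_feature_flag")
    except KvKeyNotFoundError:
        flag = None  # key was never stored or has since been deleted
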
@@ -1,3 +1,4 @@
import itertools
from collections.abc import Callable
from collections.abc import Iterator
from typing import Any
@@ -554,8 +555,7 @@ class Answer:

def _stream() -> Iterator[str]:
nonlocal stream_stop_info
yield cast(str, message)
for item in stream:
for item in itertools.chain([message], stream):
if isinstance(item, StreamStopInfo):
stream_stop_info = item
return

@@ -290,10 +290,12 @@ class DefaultMultiLLM(LLM):
return litellm.completion(
# model choice
model=f"{self.config.model_provider}/{self.config.model_name}",
api_key=self._api_key,
base_url=self._api_base,
api_version=self._api_version,
custom_llm_provider=self._custom_llm_provider,
# NOTE: have to pass in None instead of empty string for these
# otherwise litellm can have some issues with bedrock
api_key=self._api_key or None,
base_url=self._api_base or None,
api_version=self._api_version or None,
custom_llm_provider=self._custom_llm_provider or None,
# actual input
messages=prompt,
tools=tools,

@@ -51,6 +51,7 @@ from danswer.db.credentials import create_initial_public_credential
from danswer.db.document import check_docs_exist
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.engine import init_sqlalchemy_engine
from danswer.db.engine import warm_up_connections
from danswer.db.index_attempt import cancel_indexing_attempts_past_model
from danswer.db.index_attempt import expire_index_attempts
from danswer.db.llm import fetch_default_provider
@@ -64,9 +65,9 @@ from danswer.db.search_settings import update_secondary_search_settings
from danswer.db.swap_index import check_index_swap
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import DocumentIndex
from danswer.dynamic_configs.factory import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.indexing.models import IndexingSetting
from danswer.key_value_store.factory import get_kv_store
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from danswer.natural_language_processing.search_nlp_models import warm_up_cross_encoder
@@ -255,7 +256,7 @@ def update_default_multipass_indexing(db_session: Session) -> None:


def translate_saved_search_settings(db_session: Session) -> None:
kv_store = get_dynamic_config_store()
kv_store = get_kv_store()

try:
search_settings_dict = kv_store.load(KV_SEARCH_SETTINGS)
@@ -293,17 +294,17 @@ def translate_saved_search_settings(db_session: Session) -> None:
logger.notice("Search settings updated and KV store entry deleted.")
else:
logger.notice("KV store search settings is empty.")
except ConfigNotFoundError:
except KvKeyNotFoundError:
logger.notice("No search config found in KV store.")


def mark_reindex_flag(db_session: Session) -> None:
kv_store = get_dynamic_config_store()
kv_store = get_kv_store()
try:
value = kv_store.load(KV_REINDEX_KEY)
logger.debug(f"Re-indexing flag has value {value}")
return
except ConfigNotFoundError:
except KvKeyNotFoundError:
# Only need to update the flag if it hasn't been set
pass

@@ -368,7 +369,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
logger.notice("Generative AI Q&A disabled")

# fill up Postgres connection pools
# await warm_up_connections()
await warm_up_connections()

# We cache this at the beginning so there is no delay in the first telemetry
get_or_generate_uuid()

@@ -16,6 +16,7 @@ from danswer.configs.model_configs import (
)
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.db.models import SearchSettings
from danswer.indexing.indexing_heartbeat import Heartbeat
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.natural_language_processing.utils import tokenizer_trim_content
from danswer.utils.logger import setup_logger
@@ -95,6 +96,7 @@ class EmbeddingModel:
api_url: str | None,
provider_type: EmbeddingProvider | None,
retrim_content: bool = False,
heartbeat: Heartbeat | None = None,
) -> None:
self.api_key = api_key
self.provider_type = provider_type
@@ -107,6 +109,7 @@ class EmbeddingModel:
self.tokenizer = get_tokenizer(
model_name=model_name, provider_type=provider_type
)
self.heartbeat = heartbeat

model_server_url = build_model_server_url(server_host, server_port)
self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"
@@ -166,6 +169,9 @@ class EmbeddingModel:

response = self._make_model_server_request(embed_request)
embeddings.extend(response.embeddings)

if self.heartbeat:
self.heartbeat.heartbeat()
return embeddings

def encode(

@@ -3,23 +3,23 @@ from typing import Optional

import redis
from redis.client import Redis
from redis.connection import ConnectionPool

from danswer.configs.app_configs import REDIS_DB_NUMBER
from danswer.configs.app_configs import REDIS_HEALTH_CHECK_INTERVAL
from danswer.configs.app_configs import REDIS_HOST
from danswer.configs.app_configs import REDIS_PASSWORD
from danswer.configs.app_configs import REDIS_POOL_MAX_CONNECTIONS
from danswer.configs.app_configs import REDIS_PORT
from danswer.configs.app_configs import REDIS_SSL
from danswer.configs.app_configs import REDIS_SSL_CA_CERTS
from danswer.configs.app_configs import REDIS_SSL_CERT_REQS

REDIS_POOL_MAX_CONNECTIONS = 10
from danswer.configs.constants import REDIS_SOCKET_KEEPALIVE_OPTIONS


class RedisPool:
_instance: Optional["RedisPool"] = None
_lock: threading.Lock = threading.Lock()
_pool: ConnectionPool
_pool: redis.BlockingConnectionPool

def __new__(cls) -> "RedisPool":
if not cls._instance:
@@ -42,33 +42,52 @@ class RedisPool:
db: int = REDIS_DB_NUMBER,
password: str = REDIS_PASSWORD,
max_connections: int = REDIS_POOL_MAX_CONNECTIONS,
ssl_ca_certs: str = REDIS_SSL_CA_CERTS,
ssl_ca_certs: str | None = REDIS_SSL_CA_CERTS,
ssl_cert_reqs: str = REDIS_SSL_CERT_REQS,
ssl: bool = False,
) -> redis.ConnectionPool:
) -> redis.BlockingConnectionPool:
"""We use BlockingConnectionPool because it will block and wait for a connection
rather than error if max_connections is reached. This is far more deterministic
behavior and aligned with how we want to use Redis."""

# Using ConnectionPool is not well documented.
# Useful examples: https://github.com/redis/redis-py/issues/780
if ssl:
return redis.ConnectionPool(
return redis.BlockingConnectionPool(
host=host,
port=port,
db=db,
password=password,
max_connections=max_connections,
timeout=None,
health_check_interval=REDIS_HEALTH_CHECK_INTERVAL,
socket_keepalive=True,
socket_keepalive_options=REDIS_SOCKET_KEEPALIVE_OPTIONS,
connection_class=redis.SSLConnection,
ssl_ca_certs=ssl_ca_certs,
ssl_cert_reqs=ssl_cert_reqs,
)

return redis.ConnectionPool(
return redis.BlockingConnectionPool(
host=host,
port=port,
db=db,
password=password,
max_connections=max_connections,
timeout=None,
health_check_interval=REDIS_HEALTH_CHECK_INTERVAL,
socket_keepalive=True,
socket_keepalive_options=REDIS_SOCKET_KEEPALIVE_OPTIONS,
)


redis_pool = RedisPool()


def get_redis_client() -> Redis:
return redis_pool.get_client()


# # Usage example
# redis_pool = RedisPool()
# redis_client = redis_pool.get_client()
@@ -1,8 +1,8 @@
from typing import cast

from danswer.configs.constants import KV_SEARCH_SETTINGS
from danswer.dynamic_configs.factory import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.key_value_store.factory import get_kv_store
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.search.models import SavedSearchSettings
from danswer.utils.logger import setup_logger

@@ -17,10 +17,10 @@ def get_kv_search_settings() -> SavedSearchSettings | None:
if the value is updated by another process/instance of the API server. If this reads from an in memory cache like
reddis then it will be ok. Until then this has some performance implications (though minor)
"""
kv_store = get_dynamic_config_store()
kv_store = get_kv_store()
try:
return SavedSearchSettings(**cast(dict, kv_store.load(KV_SEARCH_SETTINGS)))
except ConfigNotFoundError:
except KvKeyNotFoundError:
return None
except Exception as e:
logger.error(f"Error loading search settings: {e}")

@@ -1,4 +1,5 @@
import math
from http import HTTPStatus

from fastapi import APIRouter
from fastapi import Depends
@@ -10,6 +11,8 @@ from sqlalchemy.orm import Session
from danswer.auth.users import current_curator_or_admin_user
from danswer.auth.users import current_user
from danswer.background.celery.celery_utils import get_deletion_attempt_snapshot
from danswer.background.celery.celery_utils import skip_cc_pair_pruning_by_task
from danswer.background.task_utils import name_cc_prune_task
from danswer.db.connector_credential_pair import add_credential_to_connector
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.connector_credential_pair import remove_credential_from_connector
@@ -26,17 +29,22 @@ from danswer.db.index_attempt import count_index_attempts_for_connector
from danswer.db.index_attempt import get_latest_index_attempt_for_cc_pair_id
from danswer.db.index_attempt import get_paginated_index_attempts_for_cc_pair_id
from danswer.db.models import User
from danswer.db.tasks import check_task_is_live_and_not_timed_out
from danswer.db.tasks import get_latest_task
from danswer.server.documents.models import CCPairFullInfo
from danswer.server.documents.models import CCStatusUpdateRequest
from danswer.server.documents.models import CeleryTaskStatus
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
from danswer.server.documents.models import ConnectorCredentialPairMetadata
from danswer.server.documents.models import PaginatedIndexAttempts
from danswer.server.models import StatusResponse
from danswer.utils.logger import setup_logger
from ee.danswer.background.task_name_builders import (
name_sync_external_doc_permissions_task,
)
from ee.danswer.db.user_group import validate_user_creation_permissions

logger = setup_logger()

router = APIRouter(prefix="/manage")


@@ -190,6 +198,181 @@ def update_cc_pair_name(
raise HTTPException(status_code=400, detail="Name must be unique")


@router.get("/admin/cc-pair/{cc_pair_id}/prune")
def get_cc_pair_latest_prune(
cc_pair_id: int,
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> CeleryTaskStatus:
cc_pair = get_connector_credential_pair_from_id(
cc_pair_id=cc_pair_id,
db_session=db_session,
user=user,
get_editable=False,
)
if not cc_pair:
raise HTTPException(
status_code=400,
detail="Connection not found for current user's permissions",
)

# look up the last prune task for this connector (if it exists)
pruning_task_name = name_cc_prune_task(
connector_id=cc_pair.connector_id, credential_id=cc_pair.credential_id
)
last_pruning_task = get_latest_task(pruning_task_name, db_session)
if not last_pruning_task:
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND,
detail="No pruning task found.",
)

return CeleryTaskStatus(
id=last_pruning_task.task_id,
name=last_pruning_task.task_name,
status=last_pruning_task.status,
start_time=last_pruning_task.start_time,
register_time=last_pruning_task.register_time,
)


@router.post("/admin/cc-pair/{cc_pair_id}/prune")
def prune_cc_pair(
cc_pair_id: int,
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> StatusResponse[list[int]]:
# avoiding circular refs
from danswer.background.celery.tasks.pruning.tasks import prune_documents_task

cc_pair = get_connector_credential_pair_from_id(
cc_pair_id=cc_pair_id,
db_session=db_session,
user=user,
get_editable=False,
)
if not cc_pair:
raise HTTPException(
status_code=400,
detail="Connection not found for current user's permissions",
)

pruning_task_name = name_cc_prune_task(
connector_id=cc_pair.connector_id, credential_id=cc_pair.credential_id
)
last_pruning_task = get_latest_task(pruning_task_name, db_session)
if skip_cc_pair_pruning_by_task(
last_pruning_task,
db_session=db_session,
):
raise HTTPException(
status_code=HTTPStatus.CONFLICT,
detail="Pruning task already in progress.",
)

logger.info(f"Pruning the {cc_pair.connector.name} connector.")
prune_documents_task.apply_async(
kwargs=dict(
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
)
)

return StatusResponse(
success=True,
message="Successfully created the pruning task.",
)


@router.get("/admin/cc-pair/{cc_pair_id}/sync")
def get_cc_pair_latest_sync(
cc_pair_id: int,
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> CeleryTaskStatus:
cc_pair = get_connector_credential_pair_from_id(
cc_pair_id=cc_pair_id,
db_session=db_session,
user=user,
get_editable=False,
)
if not cc_pair:
raise HTTPException(
status_code=400,
detail="Connection not found for current user's permissions",
)

# look up the last sync task for this connector (if it exists)
sync_task_name = name_sync_external_doc_permissions_task(cc_pair_id=cc_pair_id)
last_sync_task = get_latest_task(sync_task_name, db_session)
if not last_sync_task:
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND,
detail="No sync task found.",
)

return CeleryTaskStatus(
id=last_sync_task.task_id,
name=last_sync_task.task_name,
status=last_sync_task.status,
start_time=last_sync_task.start_time,
register_time=last_sync_task.register_time,
)


@router.post("/admin/cc-pair/{cc_pair_id}/sync")
def sync_cc_pair(
cc_pair_id: int,
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> StatusResponse[list[int]]:
# avoiding circular refs
from ee.danswer.background.celery.celery_app import (
sync_external_doc_permissions_task,
)

cc_pair = get_connector_credential_pair_from_id(
cc_pair_id=cc_pair_id,
db_session=db_session,
user=user,
get_editable=False,
)
if not cc_pair:
raise HTTPException(
status_code=400,
detail="Connection not found for current user's permissions",
)

sync_task_name = name_sync_external_doc_permissions_task(cc_pair_id=cc_pair_id)
last_sync_task = get_latest_task(sync_task_name, db_session)

if last_sync_task and check_task_is_live_and_not_timed_out(
last_sync_task, db_session
):
raise HTTPException(
status_code=HTTPStatus.CONFLICT,
detail="Sync task already in progress.",
)
if skip_cc_pair_pruning_by_task(
last_sync_task,
db_session=db_session,
):
raise HTTPException(
status_code=HTTPStatus.CONFLICT,
detail="Sync task already in progress.",
)

logger.info(f"Syncing the {cc_pair.connector.name} connector.")
sync_external_doc_permissions_task.apply_async(
kwargs=dict(cc_pair_id=cc_pair_id),
)

return StatusResponse(
success=True,
message="Successfully created the sync task.",
)


@router.put("/connector/{connector_id}/credential/{credential_id}")
def associate_credential_to_connector(
connector_id: int,
@@ -74,8 +74,8 @@ from danswer.db.models import IndexingStatus
from danswer.db.models import User
from danswer.db.models import UserRole
from danswer.db.search_settings import get_current_search_settings
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.file_store.file_store import get_default_file_store
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.server.documents.models import AuthStatus
from danswer.server.documents.models import AuthUrl
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
@@ -116,7 +116,7 @@ def check_google_app_gmail_credentials_exist(
) -> dict[str, str]:
try:
return {"client_id": get_google_app_gmail_cred().web.client_id}
except ConfigNotFoundError:
except KvKeyNotFoundError:
raise HTTPException(status_code=404, detail="Google App Credentials not found")


@@ -140,7 +140,7 @@ def delete_google_app_gmail_credentials(
) -> StatusResponse:
try:
delete_google_app_gmail_cred()
except ConfigNotFoundError as e:
except KvKeyNotFoundError as e:
raise HTTPException(status_code=400, detail=str(e))

return StatusResponse(
@@ -154,7 +154,7 @@ def check_google_app_credentials_exist(
) -> dict[str, str]:
try:
return {"client_id": get_google_app_cred().web.client_id}
except ConfigNotFoundError:
except KvKeyNotFoundError:
raise HTTPException(status_code=404, detail="Google App Credentials not found")


@@ -178,7 +178,7 @@ def delete_google_app_credentials(
) -> StatusResponse:
try:
delete_google_app_cred()
except ConfigNotFoundError as e:
except KvKeyNotFoundError as e:
raise HTTPException(status_code=400, detail=str(e))

return StatusResponse(
@@ -192,7 +192,7 @@ def check_google_service_gmail_account_key_exist(
) -> dict[str, str]:
try:
return {"service_account_email": get_gmail_service_account_key().client_email}
except ConfigNotFoundError:
except KvKeyNotFoundError:
raise HTTPException(
status_code=404, detail="Google Service Account Key not found"
)
@@ -218,7 +218,7 @@ def delete_google_service_gmail_account_key(
) -> StatusResponse:
try:
delete_gmail_service_account_key()
except ConfigNotFoundError as e:
except KvKeyNotFoundError as e:
raise HTTPException(status_code=400, detail=str(e))

return StatusResponse(
@@ -232,7 +232,7 @@ def check_google_service_account_key_exist(
) -> dict[str, str]:
try:
return {"service_account_email": get_service_account_key().client_email}
except ConfigNotFoundError:
except KvKeyNotFoundError:
raise HTTPException(
status_code=404, detail="Google Service Account Key not found"
)
@@ -258,7 +258,7 @@ def delete_google_service_account_key(
) -> StatusResponse:
try:
delete_service_account_key()
except ConfigNotFoundError as e:
except KvKeyNotFoundError as e:
raise HTTPException(status_code=400, detail=str(e))

return StatusResponse(
@@ -280,7 +280,7 @@ def upsert_service_account_credential(
DocumentSource.GOOGLE_DRIVE,
delegated_user_email=service_account_credential_request.google_drive_delegated_user,
)
except ConfigNotFoundError as e:
except KvKeyNotFoundError as e:
raise HTTPException(status_code=400, detail=str(e))

# first delete all existing service account credentials
@@ -306,7 +306,7 @@ def upsert_gmail_service_account_credential(
DocumentSource.GMAIL,
delegated_user_email=service_account_credential_request.gmail_delegated_user,
)
except ConfigNotFoundError as e:
except KvKeyNotFoundError as e:
raise HTTPException(status_code=400, detail=str(e))

# first delete all existing service account credentials
@@ -781,6 +781,7 @@ def connector_run_once(
detail="Connector has no valid credentials, cannot create index attempts.",
)

# Prevents index attempts for cc pairs that already have an index attempt currently running
skipped_credentials = [
credential_id
for credential_id in credential_ids
@@ -790,15 +791,15 @@ def connector_run_once(
credential_id=credential_id,
),
only_current=True,
disinclude_finished=True,
db_session=db_session,
disinclude_finished=True,
)
]

search_settings = get_current_search_settings(db_session)

connector_credential_pairs = [
get_connector_credential_pair(run_info.connector_id, credential_id, db_session)
get_connector_credential_pair(connector_id, credential_id, db_session)
for credential_id in credential_ids
if credential_id not in skipped_credentials
]

@@ -268,6 +268,14 @@ class CCPairFullInfo(BaseModel):
)


class CeleryTaskStatus(BaseModel):
id: str
name: str
status: TaskStatus
start_time: datetime | None
register_time: datetime | None


class FailedConnectorIndexingStatus(BaseModel):
"""Simplified version of ConnectorIndexingStatus for failed indexing attempts"""


@@ -10,6 +10,7 @@ from sqlalchemy.orm import Session

from danswer.auth.users import current_admin_user
from danswer.auth.users import current_curator_or_admin_user
from danswer.background.celery.celery_app import celery_app
from danswer.configs.app_configs import GENERATIVE_MODEL_ACCESS_CHECK_FREQ
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DocumentSource
@@ -28,9 +29,9 @@ from danswer.db.index_attempt import cancel_indexing_attempts_for_ccpair
from danswer.db.models import User
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.dynamic_configs.factory import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.file_store.file_store import get_default_file_store
from danswer.key_value_store.factory import get_kv_store
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.llm.factory import get_default_llms
from danswer.llm.utils import test_llm
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
@@ -113,7 +114,7 @@ def validate_existing_genai_api_key(
_: User = Depends(current_admin_user),
) -> None:
# Only validate every so often
kv_store = get_dynamic_config_store()
kv_store = get_kv_store()
curr_time = datetime.now(tz=timezone.utc)
try:
last_check = datetime.fromtimestamp(
@@ -122,7 +123,7 @@ def validate_existing_genai_api_key(
check_freq_sec = timedelta(seconds=GENERATIVE_MODEL_ACCESS_CHECK_FREQ)
if curr_time - last_check < check_freq_sec:
return
except ConfigNotFoundError:
except KvKeyNotFoundError:
# First time checking the key, nothing unusual
pass

@@ -146,10 +147,6 @@ def create_deletion_attempt_for_connector_id(
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> None:
from danswer.background.celery.celery_app import (
check_for_connector_deletion_task,
)

connector_id = connector_credential_pair_identifier.connector_id
credential_id = connector_credential_pair_identifier.credential_id

@@ -193,8 +190,11 @@ def create_deletion_attempt_for_connector_id(
status=ConnectorCredentialPairStatus.DELETING,
)

# run the beat task to pick up this deletion early
check_for_connector_deletion_task.apply_async(
db_session.commit()

# run the beat task to pick up this deletion from the db immediately
celery_app.send_task(
"check_for_connector_deletion_task",
priority=DanswerCeleryPriority.HIGH,
)


@@ -10,6 +10,7 @@ from danswer.auth.users import current_admin_user
from danswer.auth.users import current_user
from danswer.db.engine import get_session
from danswer.db.llm import fetch_existing_llm_providers
from danswer.db.llm import fetch_provider
from danswer.db.llm import remove_llm_provider
from danswer.db.llm import update_default_provider
from danswer.db.llm import upsert_llm_provider
@@ -124,17 +125,26 @@ def list_llm_providers(
def put_llm_provider(
llm_provider: LLMProviderUpsertRequest,
is_creation: bool = Query(
True,
False,
description="True if updating an existing provider, False if creating a new one",
),
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> FullLLMProvider:
# validate request (e.g. if we're intending to create but the name already exists we should throw an error)
# NOTE: may involve duplicate fetching to Postgres, but we're assuming SQLAlchemy is smart enough to cache
# the result
existing_provider = fetch_provider(db_session, llm_provider.name)
if existing_provider and is_creation:
raise HTTPException(
status_code=400,
detail=f"LLM Provider with name {llm_provider.name} already exists",
)

try:
return upsert_llm_provider(
llm_provider=llm_provider,
db_session=db_session,
is_creation=is_creation,
)
except ValueError as e:
logger.exception("Failed to upsert LLM Provider")
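
The put_llm_provider change above adds an is_creation query flag and rejects a create whose name already exists instead of silently upserting. A hypothetical client call showing the intent (route prefix, auth handling, and body fields are assumptions, not taken from this diff):

    import httpx

    # Illustrative admin request; a duplicate name with is_creation=true now yields a 400.
    resp = httpx.put(
        "http://localhost:8080/api/admin/llm/provider",
        params={"is_creation": "true"},
        json={"name": "my-openai-provider", "provider": "openai", "api_key": "sk-example"},
    )
    resp.raise_for_status()
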
@@ -21,6 +21,9 @@ from danswer.db.search_settings import get_secondary_search_settings
|
||||
from danswer.db.search_settings import update_current_search_settings
|
||||
from danswer.db.search_settings import update_search_settings_status
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.file_processing.unstructured import delete_unstructured_api_key
|
||||
from danswer.file_processing.unstructured import get_unstructured_api_key
|
||||
from danswer.file_processing.unstructured import update_unstructured_api_key
|
||||
from danswer.natural_language_processing.search_nlp_models import clean_model_name
|
||||
from danswer.search.models import SavedSearchSettings
|
||||
from danswer.search.models import SearchSettingsCreationRequest
|
||||
@@ -30,7 +33,6 @@ from danswer.server.models import IdReturn
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import ALT_INDEX_SUFFIX
|
||||
|
||||
|
||||
router = APIRouter(prefix="/search-settings")
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -196,3 +198,27 @@ def update_saved_search_settings(
|
||||
update_current_search_settings(
|
||||
search_settings=search_settings, db_session=db_session
|
||||
)
|
||||
|
||||
|
||||
@router.get("/unstructured-api-key-set")
|
||||
def unstructured_api_key_set(
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> bool:
|
||||
api_key = get_unstructured_api_key()
|
||||
print(api_key)
|
||||
return api_key is not None
|
||||
|
||||
|
||||
@router.put("/upsert-unstructured-api-key")
|
||||
def upsert_unstructured_api_key(
|
||||
unstructured_api_key: str,
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> None:
|
||||
update_unstructured_api_key(unstructured_api_key)
|
||||
|
||||
|
||||
@router.delete("/delete-unstructured-api-key")
|
||||
def delete_unstructured_api_key_endpoint(
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> None:
|
||||
delete_unstructured_api_key()
|
||||
|
||||
@@ -18,7 +18,7 @@ from danswer.db.slack_bot_config import fetch_slack_bot_configs
from danswer.db.slack_bot_config import insert_slack_bot_config
from danswer.db.slack_bot_config import remove_slack_bot_config
from danswer.db.slack_bot_config import update_slack_bot_config
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.server.manage.models import SlackBotConfig
from danswer.server.manage.models import SlackBotConfigCreationRequest
from danswer.server.manage.models import SlackBotTokens

@@ -212,5 +212,5 @@ def put_tokens(
def get_tokens(_: User | None = Depends(current_admin_user)) -> SlackBotTokens:
    try:
        return fetch_tokens()
    except ConfigNotFoundError:
    except KvKeyNotFoundError:
        raise HTTPException(status_code=404, detail="No tokens found")
@@ -38,7 +38,7 @@ from danswer.db.models import User
from danswer.db.models import User__UserGroup
from danswer.db.users import get_user_by_email
from danswer.db.users import list_users
from danswer.dynamic_configs.factory import get_dynamic_config_store
from danswer.key_value_store.factory import get_kv_store
from danswer.server.manage.models import AllUsersResponse
from danswer.server.manage.models import UserByEmail
from danswer.server.manage.models import UserInfo

@@ -367,7 +367,7 @@ def verify_user_logged_in(
    # if auth type is disabled, return a dummy user with preferences from
    # the key-value store
    if AUTH_TYPE == AuthType.DISABLED:
        store = get_dynamic_config_store()
        store = get_kv_store()
        return fetch_no_auth_user(store)

    raise HTTPException(

@@ -405,7 +405,7 @@ def update_user_default_model(
) -> None:
    if user is None:
        if AUTH_TYPE == AuthType.DISABLED:
            store = get_dynamic_config_store()
            store = get_kv_store()
            no_auth_user = fetch_no_auth_user(store)
            no_auth_user.preferences.default_model = request.default_model
            set_no_auth_user_preferences(store, no_auth_user.preferences)

@@ -433,7 +433,7 @@ def update_user_assistant_list(
) -> None:
    if user is None:
        if AUTH_TYPE == AuthType.DISABLED:
            store = get_dynamic_config_store()
            store = get_kv_store()

            no_auth_user = fetch_no_auth_user(store)
            no_auth_user.preferences.chosen_assistants = request.chosen_assistants

@@ -487,7 +487,7 @@ def update_user_assistant_visibility(
) -> None:
    if user is None:
        if AUTH_TYPE == AuthType.DISABLED:
            store = get_dynamic_config_store()
            store = get_kv_store()
            no_auth_user = fetch_no_auth_user(store)
            preferences = no_auth_user.preferences
            updated_preferences = update_assistant_list(preferences, assistant_id, show)

@@ -588,7 +588,10 @@ def upload_files_for_chat(
        # if the file is a doc, extract text and store that so we don't need
        # to re-extract it every time we send a message
        if file_type == ChatFileType.DOC:
            extracted_text = extract_file_text(file_name=file.filename, file=file.file)
            extracted_text = extract_file_text(
                file=file.file,
                file_name=file.filename or "",
            )
            text_file_id = str(uuid.uuid4())
            file_store.save_file(
                file_name=text_file_id,
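The reworked extract_file_text call above adds an `or ""` guard because an uploaded file can arrive without a name (UploadFile.filename is optional in FastAPI/Starlette). A trivial, self-contained illustration of that guard, not taken from the repository:

from typing import Optional


def pick_file_name(filename: Optional[str]) -> str:
    # mirrors the `file.filename or ""` fallback in the hunk above
    return filename or ""


assert pick_file_name(None) == ""
assert pick_file_name("report.pdf") == "report.pdf"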
@@ -18,7 +18,7 @@ from danswer.db.chat import translate_db_search_doc_to_server_search_doc
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.db.search_settings import get_current_search_settings
from danswer.db.tag import get_tags_by_value_prefix_for_source_types
from danswer.db.tag import find_tags
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.vespa.index import VespaIndex
from danswer.one_shot_answer.answer_question import stream_search_answer

@@ -99,12 +99,25 @@ def get_tags(
    if not allow_prefix:
        raise NotImplementedError("Cannot disable prefix match for now")

    db_tags = get_tags_by_value_prefix_for_source_types(
        tag_key_prefix=match_pattern,
        tag_value_prefix=match_pattern,
    key_prefix = match_pattern
    value_prefix = match_pattern
    require_both_to_match = False

    # split on = to allow the user to type in "author=bob"
    EQUAL_PAT = "="
    if match_pattern and EQUAL_PAT in match_pattern:
        split_pattern = match_pattern.split(EQUAL_PAT)
        key_prefix = split_pattern[0]
        value_prefix = EQUAL_PAT.join(split_pattern[1:])
        require_both_to_match = True

    db_tags = find_tags(
        tag_key_prefix=key_prefix,
        tag_value_prefix=value_prefix,
        sources=sources,
        limit=limit,
        db_session=db_session,
        require_both_to_match=require_both_to_match,
    )
    server_tags = [
        SourceTag(
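To make the new prefix handling above easier to follow, here is a small self-contained sketch of the same splitting behaviour; the helper name split_tag_pattern is hypothetical and only the logic mirrors the hunk:

EQUAL_PAT = "="


def split_tag_pattern(match_pattern: str) -> tuple[str, str, bool]:
    # by default the pattern is matched against both tag keys and tag values
    key_prefix = match_pattern
    value_prefix = match_pattern
    require_both_to_match = False

    # "author=bob" -> key prefix "author", value prefix "bob", and both must match
    if match_pattern and EQUAL_PAT in match_pattern:
        parts = match_pattern.split(EQUAL_PAT)
        key_prefix = parts[0]
        value_prefix = EQUAL_PAT.join(parts[1:])
        require_both_to_match = True

    return key_prefix, value_prefix, require_both_to_match


# example: split_tag_pattern("author=bob") returns ("author", "bob", True)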
@@ -19,8 +19,8 @@ from danswer.db.notification import dismiss_notification
from danswer.db.notification import get_notification_by_id
from danswer.db.notification import get_notifications
from danswer.db.notification import update_notification_last_shown
from danswer.dynamic_configs.factory import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.key_value_store.factory import get_kv_store
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.server.settings.models import Notification
from danswer.server.settings.models import Settings
from danswer.server.settings.models import UserSettings

@@ -58,9 +58,9 @@ def fetch_settings(
    user_notifications = get_user_notifications(user, db_session)

    try:
        kv_store = get_dynamic_config_store()
        kv_store = get_kv_store()
        needs_reindexing = cast(bool, kv_store.load(KV_REINDEX_KEY))
    except ConfigNotFoundError:
    except KvKeyNotFoundError:
        needs_reindexing = False

    return UserSettings(

@@ -97,7 +97,7 @@ def get_user_notifications(
        # Reindexing flag should only be shown to admins, basic users can't trigger it anyway
        return []

    kv_store = get_dynamic_config_store()
    kv_store = get_kv_store()
    try:
        needs_index = cast(bool, kv_store.load(KV_REINDEX_KEY))
        if not needs_index:

@@ -105,7 +105,7 @@ def get_user_notifications(
                notif_type=NotificationType.REINDEX, db_session=db_session
            )
            return []
    except ConfigNotFoundError:
    except KvKeyNotFoundError:
        # If something goes wrong and the flag is gone, better to not start a reindexing
        # it's a heavyweight long running job and maybe this flag is cleaned up later
        logger.warning("Could not find reindex flag")
@@ -1,16 +1,16 @@
from typing import cast

from danswer.configs.constants import KV_SETTINGS_KEY
from danswer.dynamic_configs.factory import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.key_value_store.factory import get_kv_store
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.server.settings.models import Settings


def load_settings() -> Settings:
    dynamic_config_store = get_dynamic_config_store()
    dynamic_config_store = get_kv_store()
    try:
        settings = Settings(**cast(dict, dynamic_config_store.load(KV_SETTINGS_KEY)))
    except ConfigNotFoundError:
    except KvKeyNotFoundError:
        settings = Settings()
        dynamic_config_store.store(KV_SETTINGS_KEY, settings.model_dump())

@@ -18,4 +18,4 @@ def load_settings() -> Settings:


def store_settings(settings: Settings) -> None:
    get_dynamic_config_store().store(KV_SETTINGS_KEY, settings.model_dump())
    get_kv_store().store(KV_SETTINGS_KEY, settings.model_dump())
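The rename above (get_dynamic_config_store/ConfigNotFoundError to get_kv_store/KvKeyNotFoundError) repeats across many files in this diff. As a rough, self-contained sketch of the load-or-default pattern those call sites share, using an in-memory stand-in for the real key-value store (the class below is illustrative only, not Danswer's implementation):

class KvKeyNotFoundError(Exception):
    """Raised when a key is missing from the store (stand-in for the real exception)."""


class InMemoryKvStore:
    """Illustrative stand-in for the object returned by get_kv_store()."""

    def __init__(self) -> None:
        self._data: dict[str, dict] = {}

    def load(self, key: str) -> dict:
        if key not in self._data:
            raise KvKeyNotFoundError(key)
        return self._data[key]

    def store(self, key: str, value: dict) -> None:
        self._data[key] = value


def load_or_default(store: InMemoryKvStore, key: str, default: dict) -> dict:
    # mirrors load_settings(): read the stored value, or persist and return the default
    try:
        return store.load(key)
    except KvKeyNotFoundError:
        store.store(key, default)
        return default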
@@ -8,7 +8,7 @@ from langchain_core.messages import HumanMessage
from langchain_core.messages import SystemMessage
from pydantic import BaseModel

from danswer.dynamic_configs.interface import JSON_ro
from danswer.key_value_store.interface import JSON_ro
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.interfaces import LLM
from danswer.tools.custom.base_tool_types import ToolResultType

@@ -9,7 +9,7 @@ from pydantic import BaseModel

from danswer.chat.chat_utils import combine_message_chain
from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
from danswer.dynamic_configs.interface import JSON_ro
from danswer.key_value_store.interface import JSON_ro
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.headers import build_llm_extra_headers
from danswer.llm.interfaces import LLM

@@ -10,7 +10,7 @@ from danswer.chat.chat_utils import combine_message_chain
from danswer.chat.models import LlmDoc
from danswer.configs.constants import DocumentSource
from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
from danswer.dynamic_configs.interface import JSON_ro
from danswer.key_value_store.interface import JSON_ro
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.interfaces import LLM
from danswer.llm.utils import message_to_string

@@ -16,7 +16,7 @@ from danswer.configs.chat_configs import CONTEXT_CHUNKS_BELOW
from danswer.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
from danswer.db.models import Persona
from danswer.db.models import User
from danswer.dynamic_configs.interface import JSON_ro
from danswer.key_value_store.interface import JSON_ro
from danswer.llm.answering.models import ContextualPruningConfig
from danswer.llm.answering.models import DocumentPruningConfig
from danswer.llm.answering.models import PreviousMessage

@@ -2,7 +2,7 @@ import abc
from collections.abc import Generator
from typing import Any

from danswer.dynamic_configs.interface import JSON_ro
from danswer.key_value_store.interface import JSON_ro
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.interfaces import LLM
from danswer.tools.models import ToolResponse
@@ -12,8 +12,8 @@ from danswer.configs.constants import KV_CUSTOMER_UUID_KEY
from danswer.configs.constants import KV_INSTANCE_DOMAIN_KEY
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import User
from danswer.dynamic_configs.factory import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.key_value_store.factory import get_kv_store
from danswer.key_value_store.interface import KvKeyNotFoundError

_DANSWER_TELEMETRY_ENDPOINT = "https://telemetry.danswer.ai/anonymous_telemetry"
_CACHED_UUID: str | None = None

@@ -34,11 +34,11 @@ def get_or_generate_uuid() -> str:
    if _CACHED_UUID is not None:
        return _CACHED_UUID

    kv_store = get_dynamic_config_store()
    kv_store = get_kv_store()

    try:
        _CACHED_UUID = cast(str, kv_store.load(KV_CUSTOMER_UUID_KEY))
    except ConfigNotFoundError:
    except KvKeyNotFoundError:
        _CACHED_UUID = str(uuid.uuid4())
        kv_store.store(KV_CUSTOMER_UUID_KEY, _CACHED_UUID, encrypt=True)

@@ -51,11 +51,11 @@ def _get_or_generate_instance_domain() -> str | None:
    if _CACHED_INSTANCE_DOMAIN is not None:
        return _CACHED_INSTANCE_DOMAIN

    kv_store = get_dynamic_config_store()
    kv_store = get_kv_store()

    try:
        _CACHED_INSTANCE_DOMAIN = cast(str, kv_store.load(KV_INSTANCE_DOMAIN_KEY))
    except ConfigNotFoundError:
    except KvKeyNotFoundError:
        with Session(get_sqlalchemy_engine()) as db_session:
            first_user = db_session.query(User).first()
            if first_user:
@@ -11,12 +11,25 @@ from danswer.server.settings.store import load_settings
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import global_version
from ee.danswer.background.celery_utils import should_perform_chat_ttl_check
from ee.danswer.background.celery_utils import should_perform_external_permissions_check
from ee.danswer.background.celery_utils import (
    should_perform_external_doc_permissions_check,
)
from ee.danswer.background.celery_utils import (
    should_perform_external_group_permissions_check,
)
from ee.danswer.background.task_name_builders import name_chat_ttl_task
from ee.danswer.background.task_name_builders import name_sync_external_permissions_task
from ee.danswer.background.task_name_builders import (
    name_sync_external_doc_permissions_task,
)
from ee.danswer.background.task_name_builders import (
    name_sync_external_group_permissions_task,
)
from ee.danswer.db.connector_credential_pair import get_all_auto_sync_cc_pairs
from ee.danswer.external_permissions.permission_sync import (
    run_permission_sync_entrypoint,
    run_external_doc_permission_sync,
)
from ee.danswer.external_permissions.permission_sync import (
    run_external_group_permission_sync,
)
from ee.danswer.server.reporting.usage_export_generation import create_new_usage_report

@@ -26,11 +39,18 @@ logger = setup_logger()
global_version.set_ee()


@build_celery_task_wrapper(name_sync_external_permissions_task)
@build_celery_task_wrapper(name_sync_external_doc_permissions_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def sync_external_permissions_task(cc_pair_id: int) -> None:
def sync_external_doc_permissions_task(cc_pair_id: int) -> None:
    with Session(get_sqlalchemy_engine()) as db_session:
        run_permission_sync_entrypoint(db_session=db_session, cc_pair_id=cc_pair_id)
        run_external_doc_permission_sync(db_session=db_session, cc_pair_id=cc_pair_id)


@build_celery_task_wrapper(name_sync_external_group_permissions_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def sync_external_group_permissions_task(cc_pair_id: int) -> None:
    with Session(get_sqlalchemy_engine()) as db_session:
        run_external_group_permission_sync(db_session=db_session, cc_pair_id=cc_pair_id)


@build_celery_task_wrapper(name_chat_ttl_task)

@@ -44,18 +64,35 @@ def perform_ttl_management_task(retention_limit_days: int) -> None:
# Periodic Tasks
#####
@celery_app.task(
    name="check_sync_external_permissions_task",
    name="check_sync_external_doc_permissions_task",
    soft_time_limit=JOB_TIMEOUT,
)
def check_sync_external_permissions_task() -> None:
def check_sync_external_doc_permissions_task() -> None:
    """Runs periodically to sync external permissions"""
    with Session(get_sqlalchemy_engine()) as db_session:
        cc_pairs = get_all_auto_sync_cc_pairs(db_session)
        for cc_pair in cc_pairs:
            if should_perform_external_permissions_check(
            if should_perform_external_doc_permissions_check(
                cc_pair=cc_pair, db_session=db_session
            ):
                sync_external_permissions_task.apply_async(
                sync_external_doc_permissions_task.apply_async(
                    kwargs=dict(cc_pair_id=cc_pair.id),
                )


@celery_app.task(
    name="check_sync_external_group_permissions_task",
    soft_time_limit=JOB_TIMEOUT,
)
def check_sync_external_group_permissions_task() -> None:
    """Runs periodically to sync external group permissions"""
    with Session(get_sqlalchemy_engine()) as db_session:
        cc_pairs = get_all_auto_sync_cc_pairs(db_session)
        for cc_pair in cc_pairs:
            if should_perform_external_group_permissions_check(
                cc_pair=cc_pair, db_session=db_session
            ):
                sync_external_group_permissions_task.apply_async(
                    kwargs=dict(cc_pair_id=cc_pair.id),
                )

@@ -94,9 +131,13 @@ def autogenerate_usage_report_task() -> None:
# Celery Beat (Periodic Tasks) Settings
#####
celery_app.conf.beat_schedule = {
    "sync-external-permissions": {
        "task": "check_sync_external_permissions_task",
        "schedule": timedelta(seconds=60),  # TODO: optimize this
    "sync-external-doc-permissions": {
        "task": "check_sync_external_doc_permissions_task",
        "schedule": timedelta(seconds=5),  # TODO: optimize this
    },
    "sync-external-group-permissions": {
        "task": "check_sync_external_group_permissions_task",
        "schedule": timedelta(seconds=5),  # TODO: optimize this
    },
    "autogenerate_usage_report": {
        "task": "autogenerate_usage_report_task",
backend/ee/danswer/background/celery/tasks/vespa/tasks.py (new file, 52 lines)
@@ -0,0 +1,52 @@
from typing import cast

from redis import Redis
from sqlalchemy.orm import Session

from danswer.background.celery.celery_app import task_logger
from danswer.background.celery.celery_redis import RedisUserGroup
from danswer.utils.logger import setup_logger
from ee.danswer.db.user_group import delete_user_group
from ee.danswer.db.user_group import fetch_user_group
from ee.danswer.db.user_group import mark_user_group_as_synced

logger = setup_logger()


def monitor_usergroup_taskset(key_bytes: bytes, r: Redis, db_session: Session) -> None:
    """This function is likely to move in the worker refactor happening next."""
    key = key_bytes.decode("utf-8")
    usergroup_id = RedisUserGroup.get_id_from_fence_key(key)
    if not usergroup_id:
        task_logger.warning("Could not parse usergroup id from {key}")
        return

    rug = RedisUserGroup(usergroup_id)
    fence_value = r.get(rug.fence_key)
    if fence_value is None:
        return

    try:
        initial_count = int(cast(int, fence_value))
    except ValueError:
        task_logger.error("The value is not an integer.")
        return

    count = cast(int, r.scard(rug.taskset_key))
    task_logger.info(
        f"User group sync progress: usergroup_id={usergroup_id} remaining={count} initial={initial_count}"
    )
    if count > 0:
        return

    user_group = fetch_user_group(db_session=db_session, user_group_id=usergroup_id)
    if user_group:
        if user_group.is_up_for_deletion:
            delete_user_group(db_session=db_session, user_group=user_group)
            task_logger.info(f"Deleted usergroup. id='{usergroup_id}'")
        else:
            mark_user_group_as_synced(db_session=db_session, user_group=user_group)
            task_logger.info(f"Synced usergroup. id='{usergroup_id}'")

    r.delete(rug.taskset_key)
    r.delete(rug.fence_key)
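The new module above relies on a fence/taskset convention: a fence key in Redis records how many sub-tasks were queued for a user group sync, and a Redis set tracks the ones still outstanding. A rough, self-contained sketch of that convention using redis-py; the key names and task ids below are hypothetical, and the real keys come from RedisUserGroup.fence_key and RedisUserGroup.taskset_key:

import redis

r = redis.Redis(host="localhost", port=6379)  # assumes a locally reachable Redis

fence_key = "usergroup_fence_1"      # hypothetical fence key for user group id 1
taskset_key = "usergroup_taskset_1"  # hypothetical set of outstanding task ids

# when the sync is kicked off: record the initial count and the outstanding task ids
r.set(fence_key, 3)
r.sadd(taskset_key, "task-a", "task-b", "task-c")

# each sub-task removes itself from the set as it completes
r.srem(taskset_key, "task-a")

# the monitor treats an empty set as "sync finished" and clears both keys
if r.scard(taskset_key) == 0:
    r.delete(taskset_key)
    r.delete(fence_key)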
Some files were not shown because too many files have changed in this diff.