mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-17 07:45:47 +00:00
Compare commits
285 Commits
final_grap
...
a
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f47d6798e1 | ||
|
|
8f67f1715c | ||
|
|
3b365509e2 | ||
|
|
022cbdfccf | ||
|
|
ebec6f6b10 | ||
|
|
1cad9c7b3d | ||
|
|
b4e975013c | ||
|
|
dd26f92206 | ||
|
|
4d00ec45ad | ||
|
|
1a81c67a67 | ||
|
|
04f965e656 | ||
|
|
277d37e0ee | ||
|
|
3cd260131b | ||
|
|
ad21ee0e9a | ||
|
|
c7dc0e9af0 | ||
|
|
75c5de802b | ||
|
|
c39f590d0d | ||
|
|
82a9fda846 | ||
|
|
842d4ab2a8 | ||
|
|
cddcec4ea4 | ||
|
|
09dd7b424c | ||
|
|
a2fd8d5e0a | ||
|
|
802dc00f78 | ||
|
|
f745ca1e03 | ||
|
|
eaaa135f90 | ||
|
|
457e7992a4 | ||
|
|
2fb1d06fbf | ||
|
|
8f9d4335ce | ||
|
|
ee1cb084ac | ||
|
|
2c77ad2aab | ||
|
|
f7d77a3c76 | ||
|
|
8b220d2dba | ||
|
|
6913efef90 | ||
|
|
12cbbe6cee | ||
|
|
55de519364 | ||
|
|
36134021c5 | ||
|
|
5b78299880 | ||
|
|
59364aadd7 | ||
|
|
e12785d277 | ||
|
|
7906d9edc8 | ||
|
|
6e54c97326 | ||
|
|
61424de531 | ||
|
|
4c2cf8b132 | ||
|
|
b169f78699 | ||
|
|
e48086b1c2 | ||
|
|
6b8ecb3a4b | ||
|
|
deb66a88aa | ||
|
|
90bd535c48 | ||
|
|
0de487064a | ||
|
|
114326d11a | ||
|
|
389c7b72db | ||
|
|
28ad01a51a | ||
|
|
0c102ebb5c | ||
|
|
5063b944ec | ||
|
|
15afe4dc78 | ||
|
|
a159779d39 | ||
|
|
44ebe3ae31 | ||
|
|
938a65628d | ||
|
|
5d390b65eb | ||
|
|
33974fc12c | ||
|
|
db0779dd02 | ||
|
|
f3fb7c572e | ||
|
|
0a0215ceee | ||
|
|
1a9921f63e | ||
|
|
a385234c0e | ||
|
|
65573210f1 | ||
|
|
c148fa5bfa | ||
|
|
11372aac8f | ||
|
|
f23a89ccfd | ||
|
|
e022e77b6d | ||
|
|
02cc211e91 | ||
|
|
bfe963988e | ||
|
|
0e6c2f0b51 | ||
|
|
98e88e2715 | ||
|
|
da46f61123 | ||
|
|
aa5be37f97 | ||
|
|
efe2e79f27 | ||
|
|
6f9740d026 | ||
|
|
dee197570d | ||
|
|
f8a7749b46 | ||
|
|
494fda906d | ||
|
|
89eaa8bc30 | ||
|
|
9537a2581e | ||
|
|
3ccd951307 | ||
|
|
ba712d447d | ||
|
|
a9bcc89a2c | ||
|
|
ded42e2036 | ||
|
|
86ecf8e0fc | ||
|
|
b393af676c | ||
|
|
26bdb41e8f | ||
|
|
3365e0b16e | ||
|
|
40dc4708d2 | ||
|
|
20df20ae51 | ||
|
|
7eafdae17f | ||
|
|
301032f59e | ||
|
|
b75b8334a6 | ||
|
|
d25de6e1cb | ||
|
|
d892203821 | ||
|
|
35d32ea3b0 | ||
|
|
1581d35476 | ||
|
|
1f4fe42f4b | ||
|
|
101b010c5c | ||
|
|
b212b228fb | ||
|
|
85d5e6c02f | ||
|
|
f40c5ca9bd | ||
|
|
9be54a2b4c | ||
|
|
b4417fabd7 | ||
|
|
2d74d44538 | ||
|
|
30d17ef9ee | ||
|
|
804de3248e | ||
|
|
1cbc067483 | ||
|
|
6c0a0b6454 | ||
|
|
ca88100f38 | ||
|
|
7c9f605a99 | ||
|
|
fbf09c7859 | ||
|
|
28fe0d12ca | ||
|
|
d403840507 | ||
|
|
174dabf52f | ||
|
|
03807688e6 | ||
|
|
8bbf5053de | ||
|
|
d6b4c08d24 | ||
|
|
af8e361fc2 | ||
|
|
7ce276bbe1 | ||
|
|
95df136104 | ||
|
|
6b57e68226 | ||
|
|
cbd4481838 | ||
|
|
80343d6d75 | ||
|
|
d5b9a6e552 | ||
|
|
10f221cd37 | ||
|
|
f83e6806b6 | ||
|
|
8f61505437 | ||
|
|
a47d27de6c | ||
|
|
aa187c86e2 | ||
|
|
c72c5619f0 | ||
|
|
78e7710f17 | ||
|
|
672f5cc5ce | ||
|
|
7b3c433ff8 | ||
|
|
057321a59f | ||
|
|
5cc46341f7 | ||
|
|
21a3921790 | ||
|
|
3586f9b565 | ||
|
|
aa69fe762b | ||
|
|
3ef72b8d1a | ||
|
|
a0124e4e50 | ||
|
|
a52485bda2 | ||
|
|
79d37156c6 | ||
|
|
6fa8fabb47 | ||
|
|
4214a3a6e2 | ||
|
|
1a3469d2c5 | ||
|
|
30dc408028 | ||
|
|
5d356cc971 | ||
|
|
e4c7cfde42 | ||
|
|
1900a390d8 | ||
|
|
150dcc2883 | ||
|
|
3404c7eb1d | ||
|
|
64909d74f9 | ||
|
|
83bc7d4656 | ||
|
|
3206bb27ce | ||
|
|
f189eda904 | ||
|
|
7aaf822430 | ||
|
|
0ff5180d7b | ||
|
|
089c734f63 | ||
|
|
0da736bed9 | ||
|
|
e00f4678df | ||
|
|
e56fd43ba6 | ||
|
|
28e65669b4 | ||
|
|
493c3d7314 | ||
|
|
b04e9e9b67 | ||
|
|
3755e575a5 | ||
|
|
63655cfbed | ||
|
|
7f788e4b1e | ||
|
|
1362d4b583 | ||
|
|
4f47004d47 | ||
|
|
3fdd233e84 | ||
|
|
0c54d9d57d | ||
|
|
c2088602e1 | ||
|
|
b3c367d09c | ||
|
|
457d32fef0 | ||
|
|
af187c6cfe | ||
|
|
a0235b7b7b | ||
|
|
a30de693cb | ||
|
|
07aeea69e7 | ||
|
|
bd40328a73 | ||
|
|
b8232e0681 | ||
|
|
fffb9c155a | ||
|
|
f513c5bbed | ||
|
|
9a4e51a18e | ||
|
|
2f2fc08553 | ||
|
|
c68c6fdc44 | ||
|
|
834c76e30a | ||
|
|
ec02665ffa | ||
|
|
3fa1b18306 | ||
|
|
c9bdf4c443 | ||
|
|
e229d27734 | ||
|
|
140c5b3957 | ||
|
|
3e511497d2 | ||
|
|
b0056907fb | ||
|
|
728a41a35a | ||
|
|
ef8dda2d47 | ||
|
|
15283b3140 | ||
|
|
e159b2e947 | ||
|
|
9155800fab | ||
|
|
a392ef0541 | ||
|
|
5679f0af61 | ||
|
|
ff8db71cb5 | ||
|
|
1cff2b82fd | ||
|
|
50dd3c8beb | ||
|
|
66a459234d | ||
|
|
19e57474dc | ||
|
|
f9638f2ea5 | ||
|
|
fbf51b70d0 | ||
|
|
b97cc01bb2 | ||
|
|
6d48fd5d99 | ||
|
|
1f61447b4b | ||
|
|
deee2b3513 | ||
|
|
b73d66c84a | ||
|
|
c5a61f4820 | ||
|
|
ea4a3cbf86 | ||
|
|
166514cedf | ||
|
|
be50ae1e71 | ||
|
|
f89504ec53 | ||
|
|
6b3213b1e4 | ||
|
|
48577bf0e4 | ||
|
|
c59d1ff0a5 | ||
|
|
ba38dec592 | ||
|
|
f5adc3063e | ||
|
|
8cfe80c53a | ||
|
|
487250320b | ||
|
|
c8d13922a9 | ||
|
|
cb75449cec | ||
|
|
b66514cd21 | ||
|
|
77650c9ee3 | ||
|
|
316b6b99ea | ||
|
|
34c2aa0860 | ||
|
|
45f67368a2 | ||
|
|
014ba9e220 | ||
|
|
ba64543dd7 | ||
|
|
18c62a0c24 | ||
|
|
33f555922c | ||
|
|
05f6f6d5b5 | ||
|
|
19dae1d870 | ||
|
|
6d859bd37c | ||
|
|
122e3fa3fa | ||
|
|
87b542b335 | ||
|
|
00229d2abe | ||
|
|
5f2644985c | ||
|
|
c82a36ad68 | ||
|
|
16d1c19d9f | ||
|
|
9f179940f8 | ||
|
|
8a8e2b310e | ||
|
|
2274cab554 | ||
|
|
ef104e9a82 | ||
|
|
a575d7f1eb | ||
|
|
f404c4b448 | ||
|
|
3884f1d70a | ||
|
|
bc9d5fece7 | ||
|
|
bb279a8580 | ||
|
|
a9403016c9 | ||
|
|
f3cea79c1c | ||
|
|
54bb79303c | ||
|
|
d3dfabb20e | ||
|
|
7d1ec1095c | ||
|
|
f531d071af | ||
|
|
4218814385 | ||
|
|
e662e3b57d | ||
|
|
2073820e33 | ||
|
|
5f25b243c5 | ||
|
|
a9427f190a | ||
|
|
18fbe9d7e8 | ||
|
|
75c9b1cafe | ||
|
|
632a8f700b | ||
|
|
cd58c96014 | ||
|
|
c5032d25c9 | ||
|
|
72acde6fd4 | ||
|
|
5596a68d08 | ||
|
|
5b18409c89 | ||
|
|
84272af5ac | ||
|
|
6bef70c8b7 | ||
|
|
7f7559e3d2 | ||
|
|
7ba829a585 | ||
|
|
8b2ecb4eab | ||
|
|
2dd3870504 | ||
|
|
df464fc54b | ||
|
|
96b98fbc4a | ||
|
|
66cf67d04d |
47
.github/actions/custom-build-and-push/action.yml
vendored
47
.github/actions/custom-build-and-push/action.yml
vendored
@@ -32,16 +32,20 @@ inputs:
|
||||
description: 'Cache destinations'
|
||||
required: false
|
||||
retry-wait-time:
|
||||
description: 'Time to wait before retry in seconds'
|
||||
description: 'Time to wait before attempt 2 in seconds'
|
||||
required: false
|
||||
default: '5'
|
||||
default: '60'
|
||||
retry-wait-time-2:
|
||||
description: 'Time to wait before attempt 3 in seconds'
|
||||
required: false
|
||||
default: '120'
|
||||
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Build and push Docker image (First Attempt)
|
||||
- name: Build and push Docker image (Attempt 1 of 3)
|
||||
id: buildx1
|
||||
uses: docker/build-push-action@v5
|
||||
uses: docker/build-push-action@v6
|
||||
continue-on-error: true
|
||||
with:
|
||||
context: ${{ inputs.context }}
|
||||
@@ -54,16 +58,17 @@ runs:
|
||||
cache-from: ${{ inputs.cache-from }}
|
||||
cache-to: ${{ inputs.cache-to }}
|
||||
|
||||
- name: Wait to retry
|
||||
- name: Wait before attempt 2
|
||||
if: steps.buildx1.outcome != 'success'
|
||||
run: |
|
||||
echo "First attempt failed. Waiting ${{ inputs.retry-wait-time }} seconds before retry..."
|
||||
sleep ${{ inputs.retry-wait-time }}
|
||||
shell: bash
|
||||
|
||||
- name: Build and push Docker image (Retry Attempt)
|
||||
- name: Build and push Docker image (Attempt 2 of 3)
|
||||
id: buildx2
|
||||
if: steps.buildx1.outcome != 'success'
|
||||
uses: docker/build-push-action@v5
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: ${{ inputs.context }}
|
||||
file: ${{ inputs.file }}
|
||||
@@ -74,3 +79,31 @@ runs:
|
||||
tags: ${{ inputs.tags }}
|
||||
cache-from: ${{ inputs.cache-from }}
|
||||
cache-to: ${{ inputs.cache-to }}
|
||||
|
||||
- name: Wait before attempt 3
|
||||
if: steps.buildx1.outcome != 'success' && steps.buildx2.outcome != 'success'
|
||||
run: |
|
||||
echo "Second attempt failed. Waiting ${{ inputs.retry-wait-time-2 }} seconds before retry..."
|
||||
sleep ${{ inputs.retry-wait-time-2 }}
|
||||
shell: bash
|
||||
|
||||
- name: Build and push Docker image (Attempt 3 of 3)
|
||||
id: buildx3
|
||||
if: steps.buildx1.outcome != 'success' && steps.buildx2.outcome != 'success'
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: ${{ inputs.context }}
|
||||
file: ${{ inputs.file }}
|
||||
platforms: ${{ inputs.platforms }}
|
||||
pull: ${{ inputs.pull }}
|
||||
push: ${{ inputs.push }}
|
||||
load: ${{ inputs.load }}
|
||||
tags: ${{ inputs.tags }}
|
||||
cache-from: ${{ inputs.cache-from }}
|
||||
cache-to: ${{ inputs.cache-to }}
|
||||
|
||||
- name: Report failure
|
||||
if: steps.buildx1.outcome != 'success' && steps.buildx2.outcome != 'success' && steps.buildx3.outcome != 'success'
|
||||
run: |
|
||||
echo "All attempts failed. Possible transient infrastucture issues? Try again later or inspect logs for details."
|
||||
shell: bash
|
||||
|
||||
@@ -7,16 +7,17 @@ on:
|
||||
|
||||
env:
|
||||
REGISTRY_IMAGE: danswer/danswer-backend
|
||||
|
||||
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
|
||||
|
||||
jobs:
|
||||
build-and-push:
|
||||
# TODO: make this a matrix build like the web containers
|
||||
runs-on:
|
||||
group: amd64-image-builders
|
||||
# TODO: investigate a matrix build like the web container
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v2
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
@@ -27,6 +28,11 @@ jobs:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Install build-essential
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y build-essential
|
||||
|
||||
- name: Backend Image Docker Build and Push
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
@@ -36,12 +42,20 @@ jobs:
|
||||
push: true
|
||||
tags: |
|
||||
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
${{ env.REGISTRY_IMAGE }}:latest
|
||||
${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
|
||||
build-args: |
|
||||
DANSWER_VERSION=${{ github.ref_name }}
|
||||
|
||||
# trivy has their own rate limiting issues causing this action to flake
|
||||
# we worked around it by hardcoding to different db repos in env
|
||||
# can re-enable when they figure it out
|
||||
# https://github.com/aquasecurity/trivy/discussions/7538
|
||||
# https://github.com/aquasecurity/trivy-action/issues/389
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@master
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
|
||||
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
|
||||
with:
|
||||
# To run locally: trivy image --severity HIGH,CRITICAL danswer/danswer-backend
|
||||
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
|
||||
@@ -5,14 +5,18 @@ on:
|
||||
tags:
|
||||
- '*'
|
||||
|
||||
env:
|
||||
REGISTRY_IMAGE: danswer/danswer-model-server
|
||||
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
|
||||
|
||||
jobs:
|
||||
build-and-push:
|
||||
runs-on:
|
||||
group: amd64-image-builders
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v2
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
@@ -31,13 +35,21 @@ jobs:
|
||||
platforms: linux/amd64,linux/arm64
|
||||
push: true
|
||||
tags: |
|
||||
danswer/danswer-model-server:${{ github.ref_name }}
|
||||
danswer/danswer-model-server:latest
|
||||
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
|
||||
build-args: |
|
||||
DANSWER_VERSION=${{ github.ref_name }}
|
||||
|
||||
# trivy has their own rate limiting issues causing this action to flake
|
||||
# we worked around it by hardcoding to different db repos in env
|
||||
# can re-enable when they figure it out
|
||||
# https://github.com/aquasecurity/trivy/discussions/7538
|
||||
# https://github.com/aquasecurity/trivy-action/issues/389
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@master
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
|
||||
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
|
||||
with:
|
||||
image-ref: docker.io/danswer/danswer-model-server:${{ github.ref_name }}
|
||||
severity: 'CRITICAL,HIGH'
|
||||
|
||||
@@ -7,11 +7,15 @@ on:
|
||||
|
||||
env:
|
||||
REGISTRY_IMAGE: danswer/danswer-web-server
|
||||
|
||||
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on:
|
||||
group: ${{ matrix.platform == 'linux/amd64' && 'amd64-image-builders' || 'arm64-image-builders' }}
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=${{ matrix.platform == 'linux/amd64' && '8cpu-linux-x64' || '8cpu-linux-arm64' }}
|
||||
- run-id=${{ github.run_id }}
|
||||
- tag=platform-${{ matrix.platform }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
@@ -35,7 +39,7 @@ jobs:
|
||||
images: ${{ env.REGISTRY_IMAGE }}
|
||||
tags: |
|
||||
type=raw,value=${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
type=raw,value=${{ env.REGISTRY_IMAGE }}:latest
|
||||
type=raw,value=${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
@@ -112,8 +116,16 @@ jobs:
|
||||
run: |
|
||||
docker buildx imagetools inspect ${{ env.REGISTRY_IMAGE }}:${{ steps.meta.outputs.version }}
|
||||
|
||||
# trivy has their own rate limiting issues causing this action to flake
|
||||
# we worked around it by hardcoding to different db repos in env
|
||||
# can re-enable when they figure it out
|
||||
# https://github.com/aquasecurity/trivy/discussions/7538
|
||||
# https://github.com/aquasecurity/trivy-action/issues/389
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@master
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
|
||||
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
|
||||
with:
|
||||
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
severity: 'CRITICAL,HIGH'
|
||||
|
||||
7
.github/workflows/docker-tag-latest.yml
vendored
7
.github/workflows/docker-tag-latest.yml
vendored
@@ -1,3 +1,6 @@
|
||||
# This workflow is set up to be manually triggered via the GitHub Action tab.
|
||||
# Given a version, it will tag those backend and webserver images as "latest".
|
||||
|
||||
name: Tag Latest Version
|
||||
|
||||
on:
|
||||
@@ -9,7 +12,9 @@ on:
|
||||
|
||||
jobs:
|
||||
tag:
|
||||
runs-on: ubuntu-latest
|
||||
# See https://runs-on.com/runners/linux/
|
||||
# use a lower powered instance since this just does i/o to docker hub
|
||||
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}"]
|
||||
steps:
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v1
|
||||
|
||||
172
.github/workflows/hotfix-release-branches.yml
vendored
Normal file
172
.github/workflows/hotfix-release-branches.yml
vendored
Normal file
@@ -0,0 +1,172 @@
|
||||
# This workflow is intended to be manually triggered via the GitHub Action tab.
|
||||
# Given a hotfix branch, it will attempt to open a PR to all release branches and
|
||||
# by default auto merge them
|
||||
|
||||
name: Hotfix release branches
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
hotfix_commit:
|
||||
description: 'Hotfix commit hash'
|
||||
required: true
|
||||
hotfix_suffix:
|
||||
description: 'Hotfix branch suffix (e.g. hotfix/v0.8-{suffix})'
|
||||
required: true
|
||||
release_branch_pattern:
|
||||
description: 'Release branch pattern (regex)'
|
||||
required: true
|
||||
default: 'release/.*'
|
||||
auto_merge:
|
||||
description: 'Automatically merge the hotfix PRs'
|
||||
required: true
|
||||
type: choice
|
||||
default: 'true'
|
||||
options:
|
||||
- true
|
||||
- false
|
||||
|
||||
jobs:
|
||||
hotfix_release_branches:
|
||||
permissions: write-all
|
||||
# See https://runs-on.com/runners/linux/
|
||||
# use a lower powered instance since this just does i/o to docker hub
|
||||
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}"]
|
||||
steps:
|
||||
|
||||
# needs RKUO_DEPLOY_KEY for write access to merge PR's
|
||||
- name: Checkout Repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ssh-key: "${{ secrets.RKUO_DEPLOY_KEY }}"
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Git user
|
||||
run: |
|
||||
git config user.name "Richard Kuo [bot]"
|
||||
git config user.email "rkuo[bot]@danswer.ai"
|
||||
|
||||
- name: Fetch All Branches
|
||||
run: |
|
||||
git fetch --all --prune
|
||||
|
||||
- name: Verify Hotfix Commit Exists
|
||||
run: |
|
||||
git rev-parse --verify "${{ github.event.inputs.hotfix_commit }}" || { echo "Commit not found: ${{ github.event.inputs.hotfix_commit }}"; exit 1; }
|
||||
|
||||
- name: Get Release Branches
|
||||
id: get_release_branches
|
||||
run: |
|
||||
BRANCHES=$(git branch -r | grep -E "${{ github.event.inputs.release_branch_pattern }}" | sed 's|origin/||' | tr -d ' ')
|
||||
if [ -z "$BRANCHES" ]; then
|
||||
echo "No release branches found matching pattern '${{ github.event.inputs.release_branch_pattern }}'."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Found release branches:"
|
||||
echo "$BRANCHES"
|
||||
|
||||
# Join the branches into a single line separated by commas
|
||||
BRANCHES_JOINED=$(echo "$BRANCHES" | tr '\n' ',' | sed 's/,$//')
|
||||
|
||||
# Set the branches as an output
|
||||
echo "branches=$BRANCHES_JOINED" >> $GITHUB_OUTPUT
|
||||
|
||||
# notes on all the vagaries of wiring up automated PR's
|
||||
# https://github.com/peter-evans/create-pull-request/blob/main/docs/concepts-guidelines.md#triggering-further-workflow-runs
|
||||
# we must use a custom token for GH_TOKEN to trigger the subsequent PR checks
|
||||
- name: Create and Merge Pull Requests to Matching Release Branches
|
||||
env:
|
||||
HOTFIX_COMMIT: ${{ github.event.inputs.hotfix_commit }}
|
||||
HOTFIX_SUFFIX: ${{ github.event.inputs.hotfix_suffix }}
|
||||
AUTO_MERGE: ${{ github.event.inputs.auto_merge }}
|
||||
GH_TOKEN: ${{ secrets.RKUO_PERSONAL_ACCESS_TOKEN }}
|
||||
run: |
|
||||
# Get the branches from the previous step
|
||||
BRANCHES="${{ steps.get_release_branches.outputs.branches }}"
|
||||
|
||||
# Convert BRANCHES to an array
|
||||
IFS=$',' read -ra BRANCH_ARRAY <<< "$BRANCHES"
|
||||
|
||||
# Loop through each release branch and create and merge a PR
|
||||
for RELEASE_BRANCH in "${BRANCH_ARRAY[@]}"; do
|
||||
echo "Processing $RELEASE_BRANCH..."
|
||||
|
||||
# Parse out the release version by removing "release/" from the branch name
|
||||
RELEASE_VERSION=${RELEASE_BRANCH#release/}
|
||||
echo "Release version parsed: $RELEASE_VERSION"
|
||||
|
||||
HOTFIX_BRANCH="hotfix/${RELEASE_VERSION}-${HOTFIX_SUFFIX}"
|
||||
echo "Creating PR from $HOTFIX_BRANCH to $RELEASE_BRANCH"
|
||||
|
||||
# Checkout the release branch
|
||||
echo "Checking out $RELEASE_BRANCH"
|
||||
git checkout "$RELEASE_BRANCH"
|
||||
|
||||
# Create the new hotfix branch
|
||||
if git rev-parse --verify "$HOTFIX_BRANCH" >/dev/null 2>&1; then
|
||||
echo "Hotfix branch $HOTFIX_BRANCH already exists. Skipping branch creation."
|
||||
else
|
||||
echo "Branching $RELEASE_BRANCH to $HOTFIX_BRANCH"
|
||||
git checkout -b "$HOTFIX_BRANCH"
|
||||
fi
|
||||
|
||||
# Check if the hotfix commit is a merge commit
|
||||
if git rev-list --merges -n 1 "$HOTFIX_COMMIT" >/dev/null 2>&1; then
|
||||
# -m 1 uses the target branch as the base (which is what we want)
|
||||
echo "Hotfix commit $HOTFIX_COMMIT is a merge commit, using -m 1 for cherry-pick"
|
||||
CHERRY_PICK_CMD="git cherry-pick -m 1 $HOTFIX_COMMIT"
|
||||
else
|
||||
CHERRY_PICK_CMD="git cherry-pick $HOTFIX_COMMIT"
|
||||
fi
|
||||
|
||||
# Perform the cherry-pick
|
||||
echo "Executing: $CHERRY_PICK_CMD"
|
||||
eval "$CHERRY_PICK_CMD"
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Cherry-pick failed for $HOTFIX_COMMIT on $HOTFIX_BRANCH. Aborting..."
|
||||
git cherry-pick --abort
|
||||
continue
|
||||
fi
|
||||
|
||||
# Push the hotfix branch to the remote
|
||||
echo "Pushing $HOTFIX_BRANCH..."
|
||||
git push origin "$HOTFIX_BRANCH"
|
||||
echo "Hotfix branch $HOTFIX_BRANCH created and pushed."
|
||||
|
||||
# Check if PR already exists
|
||||
EXISTING_PR=$(gh pr list --head "$HOTFIX_BRANCH" --base "$RELEASE_BRANCH" --state open --json number --jq '.[0].number')
|
||||
|
||||
if [ -n "$EXISTING_PR" ]; then
|
||||
echo "An open PR already exists: #$EXISTING_PR. Skipping..."
|
||||
continue
|
||||
fi
|
||||
|
||||
# Create a new PR and capture the output
|
||||
PR_OUTPUT=$(gh pr create --title "Merge $HOTFIX_BRANCH into $RELEASE_BRANCH" \
|
||||
--body "Automated PR to merge \`$HOTFIX_BRANCH\` into \`$RELEASE_BRANCH\`." \
|
||||
--head "$HOTFIX_BRANCH" --base "$RELEASE_BRANCH")
|
||||
|
||||
# Extract the URL from the output
|
||||
PR_URL=$(echo "$PR_OUTPUT" | grep -Eo 'https://github.com/[^ ]+')
|
||||
echo "Pull request created: $PR_URL"
|
||||
|
||||
# Extract PR number from URL
|
||||
PR_NUMBER=$(basename "$PR_URL")
|
||||
echo "Pull request created: $PR_NUMBER"
|
||||
|
||||
if [ "$AUTO_MERGE" == "true" ]; then
|
||||
echo "Attempting to merge pull request #$PR_NUMBER"
|
||||
|
||||
# Attempt to merge the PR
|
||||
gh pr merge "$PR_NUMBER" --merge --auto --delete-branch
|
||||
|
||||
if [ $? -eq 0 ]; then
|
||||
echo "Pull request #$PR_NUMBER merged successfully."
|
||||
else
|
||||
# Optionally, handle the error or continue
|
||||
echo "Failed to merge pull request #$PR_NUMBER."
|
||||
fi
|
||||
fi
|
||||
done
|
||||
@@ -1,20 +1,23 @@
|
||||
name: Run Integration Tests
|
||||
name: Run Integration Tests v2
|
||||
concurrency:
|
||||
group: Run-Integration-Tests-${{ github.head_ref }}
|
||||
group: Run-Integration-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
branches:
|
||||
- main
|
||||
- 'release/**'
|
||||
|
||||
env:
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
|
||||
|
||||
jobs:
|
||||
integration-tests:
|
||||
runs-on:
|
||||
group: 'arm64-image-builders'
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=8cpu-linux-x64,ram=16,"run-id=${{ github.run_id }}"]
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
@@ -28,49 +31,59 @@ jobs:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
# NOTE: we don't need to build the Web Docker image since it's not used
|
||||
# during the IT for now. We have a separate action to verify it builds
|
||||
# succesfully
|
||||
# tag every docker image with "test" so that we can spin up the correct set
|
||||
# of images during testing
|
||||
|
||||
# We don't need to build the Web Docker image since it's not yet used
|
||||
# in the integration tests. We have a separate action to verify that it builds
|
||||
# successfully.
|
||||
- name: Pull Web Docker image
|
||||
run: |
|
||||
docker pull danswer/danswer-web-server:latest
|
||||
docker tag danswer/danswer-web-server:latest danswer/danswer-web-server:it
|
||||
docker tag danswer/danswer-web-server:latest danswer/danswer-web-server:test
|
||||
|
||||
# we use the runs-on cache for docker builds
|
||||
# in conjunction with runs-on runners, it has better speed and unlimited caching
|
||||
# https://runs-on.com/caching/s3-cache-for-github-actions/
|
||||
# https://runs-on.com/caching/docker/
|
||||
# https://github.com/moby/buildkit#s3-cache-experimental
|
||||
|
||||
# images are built and run locally for testing purposes. Not pushed.
|
||||
- name: Build Backend Docker image
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
platforms: linux/arm64
|
||||
tags: danswer/danswer-backend:it
|
||||
cache-from: type=registry,ref=danswer/danswer-backend:it
|
||||
cache-to: |
|
||||
type=registry,ref=danswer/danswer-backend:it,mode=max
|
||||
type=inline
|
||||
platforms: linux/amd64
|
||||
tags: danswer/danswer-backend:test
|
||||
push: false
|
||||
load: true
|
||||
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
|
||||
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
|
||||
|
||||
- name: Build Model Server Docker image
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile.model_server
|
||||
platforms: linux/arm64
|
||||
tags: danswer/danswer-model-server:it
|
||||
cache-from: type=registry,ref=danswer/danswer-model-server:it
|
||||
cache-to: |
|
||||
type=registry,ref=danswer/danswer-model-server:it,mode=max
|
||||
type=inline
|
||||
platforms: linux/amd64
|
||||
tags: danswer/danswer-model-server:test
|
||||
push: false
|
||||
load: true
|
||||
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
|
||||
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
|
||||
|
||||
- name: Build integration test Docker image
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/tests/integration/Dockerfile
|
||||
platforms: linux/arm64
|
||||
tags: danswer/integration-test-runner:it
|
||||
cache-from: type=registry,ref=danswer/integration-test-runner:it
|
||||
cache-to: |
|
||||
type=registry,ref=danswer/integration-test-runner:it,mode=max
|
||||
type=inline
|
||||
platforms: linux/amd64
|
||||
tags: danswer/danswer-integration:test
|
||||
push: false
|
||||
load: true
|
||||
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
|
||||
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
|
||||
|
||||
- name: Start Docker containers
|
||||
run: |
|
||||
@@ -79,7 +92,7 @@ jobs:
|
||||
AUTH_TYPE=basic \
|
||||
REQUIRE_EMAIL_VERIFICATION=false \
|
||||
DISABLE_TELEMETRY=true \
|
||||
IMAGE_TAG=it \
|
||||
IMAGE_TAG=test \
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack up -d
|
||||
id: start_docker
|
||||
|
||||
@@ -121,6 +134,7 @@ jobs:
|
||||
run: |
|
||||
echo "Running integration tests..."
|
||||
docker run --rm --network danswer-stack_default \
|
||||
--name test-runner \
|
||||
-e POSTGRES_HOST=relational_db \
|
||||
-e POSTGRES_USER=postgres \
|
||||
-e POSTGRES_PASSWORD=password \
|
||||
@@ -129,7 +143,9 @@ jobs:
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
danswer/integration-test-runner:it
|
||||
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
danswer/danswer-integration:test
|
||||
continue-on-error: true
|
||||
id: run_tests
|
||||
|
||||
@@ -151,7 +167,7 @@ jobs:
|
||||
|
||||
- name: Upload logs
|
||||
if: success() || failure()
|
||||
uses: actions/upload-artifact@v3
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: docker-logs
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
@@ -12,7 +12,8 @@ on:
|
||||
|
||||
jobs:
|
||||
lint-test:
|
||||
runs-on: Amd64
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=8cpu-linux-x64,hdd=256,"run-id=${{ github.run_id }}"]
|
||||
|
||||
# fetch-depth 0 is required for helm/chart-testing-action
|
||||
steps:
|
||||
@@ -37,9 +38,9 @@ jobs:
|
||||
backend/requirements/model_server.txt
|
||||
- run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -r backend/requirements/default.txt
|
||||
pip install -r backend/requirements/dev.txt
|
||||
pip install -r backend/requirements/model_server.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
|
||||
|
||||
- name: Set up chart-testing
|
||||
uses: helm/chart-testing-action@v2.6.1
|
||||
|
||||
17
.github/workflows/pr-python-checks.yml
vendored
17
.github/workflows/pr-python-checks.yml
vendored
@@ -3,18 +3,21 @@ name: Python Checks
|
||||
on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
branches:
|
||||
- main
|
||||
- 'release/**'
|
||||
|
||||
jobs:
|
||||
mypy-check:
|
||||
runs-on: ubuntu-latest
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.11'
|
||||
cache: 'pip'
|
||||
@@ -24,9 +27,9 @@ jobs:
|
||||
backend/requirements/model_server.txt
|
||||
- run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -r backend/requirements/default.txt
|
||||
pip install -r backend/requirements/dev.txt
|
||||
pip install -r backend/requirements/model_server.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
|
||||
|
||||
- name: Run MyPy
|
||||
run: |
|
||||
|
||||
12
.github/workflows/pr-python-connector-tests.yml
vendored
12
.github/workflows/pr-python-connector-tests.yml
vendored
@@ -15,10 +15,14 @@ env:
|
||||
CONFLUENCE_TEST_PAGE_ID: ${{ secrets.CONFLUENCE_TEST_PAGE_ID }}
|
||||
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
|
||||
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
|
||||
# Jira
|
||||
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
|
||||
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
|
||||
|
||||
jobs:
|
||||
connectors-check:
|
||||
runs-on: ubuntu-latest
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
|
||||
|
||||
env:
|
||||
PYTHONPATH: ./backend
|
||||
@@ -28,7 +32,7 @@ jobs:
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
cache: "pip"
|
||||
@@ -39,8 +43,8 @@ jobs:
|
||||
- name: Install Dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -r backend/requirements/default.txt
|
||||
pip install -r backend/requirements/dev.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
|
||||
|
||||
- name: Run Tests
|
||||
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
|
||||
|
||||
58
.github/workflows/pr-python-model-tests.yml
vendored
Normal file
58
.github/workflows/pr-python-model-tests.yml
vendored
Normal file
@@ -0,0 +1,58 @@
|
||||
name: Connector Tests
|
||||
|
||||
on:
|
||||
schedule:
|
||||
# This cron expression runs the job daily at 16:00 UTC (9am PT)
|
||||
- cron: "0 16 * * *"
|
||||
|
||||
env:
|
||||
# Bedrock
|
||||
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
|
||||
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
|
||||
AWS_REGION_NAME: ${{ secrets.AWS_REGION_NAME }}
|
||||
|
||||
# OpenAI
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
|
||||
jobs:
|
||||
connectors-check:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
|
||||
|
||||
env:
|
||||
PYTHONPATH: ./backend
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
cache: "pip"
|
||||
cache-dependency-path: |
|
||||
backend/requirements/default.txt
|
||||
backend/requirements/dev.txt
|
||||
|
||||
- name: Install Dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
|
||||
|
||||
- name: Run Tests
|
||||
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
|
||||
run: |
|
||||
py.test -o junit_family=xunit2 -xv --ff backend/tests/daily/llm
|
||||
py.test -o junit_family=xunit2 -xv --ff backend/tests/daily/embedding
|
||||
|
||||
- name: Alert on Failure
|
||||
if: failure() && github.event_name == 'schedule'
|
||||
env:
|
||||
SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK }}
|
||||
run: |
|
||||
curl -X POST \
|
||||
-H 'Content-type: application/json' \
|
||||
--data '{"text":"Scheduled Model Tests failed! Check the run at: https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}"}' \
|
||||
$SLACK_WEBHOOK
|
||||
13
.github/workflows/pr-python-tests.yml
vendored
13
.github/workflows/pr-python-tests.yml
vendored
@@ -3,11 +3,14 @@ name: Python Unit Tests
|
||||
on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
branches:
|
||||
- main
|
||||
- 'release/**'
|
||||
|
||||
jobs:
|
||||
backend-check:
|
||||
runs-on: ubuntu-latest
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
|
||||
|
||||
env:
|
||||
PYTHONPATH: ./backend
|
||||
@@ -18,7 +21,7 @@ jobs:
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.11'
|
||||
cache: 'pip'
|
||||
@@ -29,8 +32,8 @@ jobs:
|
||||
- name: Install Dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -r backend/requirements/default.txt
|
||||
pip install -r backend/requirements/dev.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
|
||||
|
||||
- name: Run Tests
|
||||
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
|
||||
|
||||
7
.github/workflows/pr-quality-checks.yml
vendored
7
.github/workflows/pr-quality-checks.yml
vendored
@@ -1,6 +1,6 @@
|
||||
name: Quality Checks PR
|
||||
concurrency:
|
||||
group: Quality-Checks-PR-${{ github.head_ref }}
|
||||
group: Quality-Checks-PR-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
on:
|
||||
@@ -9,7 +9,8 @@ on:
|
||||
|
||||
jobs:
|
||||
quality-checks:
|
||||
runs-on: ubuntu-latest
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
@@ -17,6 +18,6 @@ jobs:
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- uses: pre-commit/action@v3.0.0
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
with:
|
||||
extra_args: ${{ github.event_name == 'pull_request' && format('--from-ref {0} --to-ref {1}', github.event.pull_request.base.sha, github.event.pull_request.head.sha) || '' }}
|
||||
|
||||
54
.github/workflows/tag-nightly.yml
vendored
Normal file
54
.github/workflows/tag-nightly.yml
vendored
Normal file
@@ -0,0 +1,54 @@
|
||||
name: Nightly Tag Push
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: '0 10 * * *' # Runs every day at 2 AM PST / 3 AM PDT / 10 AM UTC
|
||||
|
||||
permissions:
|
||||
contents: write # Allows pushing tags to the repository
|
||||
|
||||
jobs:
|
||||
create-and-push-tag:
|
||||
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}"]
|
||||
|
||||
steps:
|
||||
# actions using GITHUB_TOKEN cannot trigger another workflow, but we do want this to trigger docker pushes
|
||||
# see https://github.com/orgs/community/discussions/27028#discussioncomment-3254367 for the workaround we
|
||||
# implement here which needs an actual user's deploy key
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ssh-key: "${{ secrets.RKUO_DEPLOY_KEY }}"
|
||||
|
||||
- name: Set up Git user
|
||||
run: |
|
||||
git config user.name "Richard Kuo [bot]"
|
||||
git config user.email "rkuo[bot]@danswer.ai"
|
||||
|
||||
- name: Check for existing nightly tag
|
||||
id: check_tag
|
||||
run: |
|
||||
if git tag --points-at HEAD --list "nightly-latest*" | grep -q .; then
|
||||
echo "A tag starting with 'nightly-latest' already exists on HEAD."
|
||||
echo "tag_exists=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "No tag starting with 'nightly-latest' exists on HEAD."
|
||||
echo "tag_exists=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
# don't tag again if HEAD already has a nightly-latest tag on it
|
||||
- name: Create Nightly Tag
|
||||
if: steps.check_tag.outputs.tag_exists == 'false'
|
||||
env:
|
||||
DATE: ${{ github.run_id }}
|
||||
run: |
|
||||
TAG_NAME="nightly-latest-$(date +'%Y%m%d')"
|
||||
echo "Creating tag: $TAG_NAME"
|
||||
git tag $TAG_NAME
|
||||
|
||||
- name: Push Tag
|
||||
if: steps.check_tag.outputs.tag_exists == 'false'
|
||||
run: |
|
||||
TAG_NAME="nightly-latest-$(date +'%Y%m%d')"
|
||||
git push origin $TAG_NAME
|
||||
|
||||
1
.prettierignore
Normal file
1
.prettierignore
Normal file
@@ -0,0 +1 @@
|
||||
backend/tests/integration/tests/pruning/website
|
||||
@@ -22,7 +22,7 @@ Your input is vital to making sure that Danswer moves in the right direction.
|
||||
Before starting on implementation, please raise a GitHub issue.
|
||||
|
||||
And always feel free to message us (Chris Weaver / Yuhong Sun) on
|
||||
[Slack](https://join.slack.com/t/danswer/shared_invite/zt-2afut44lv-Rw3kSWu6_OmdAXRpCv80DQ) /
|
||||
[Slack](https://join.slack.com/t/danswer/shared_invite/zt-2lcmqw703-071hBuZBfNEOGUsLa5PXvQ) /
|
||||
[Discord](https://discord.gg/TDJ59cGV2X) directly about anything at all.
|
||||
|
||||
|
||||
|
||||
@@ -74,7 +74,7 @@ We also have built-in support for deployment on Kubernetes. Files for that can b
|
||||
* Organizational understanding and ability to locate and suggest experts from your team.
|
||||
|
||||
|
||||
## Other Noteable Benefits of Danswer
|
||||
## Other Notable Benefits of Danswer
|
||||
* User Authentication with document level access management.
|
||||
* Best in class Hybrid Search across all sources (BM-25 + prefix aware embedding models).
|
||||
* Admin Dashboard to configure connectors, document-sets, access, etc.
|
||||
|
||||
@@ -8,10 +8,12 @@ Edition features outside of personal development or testing purposes. Please rea
|
||||
founders@danswer.ai for more information. Please visit https://github.com/danswer-ai/danswer"
|
||||
|
||||
# Default DANSWER_VERSION, typically overriden during builds by GitHub Actions.
|
||||
ARG DANSWER_VERSION=0.3-dev
|
||||
ARG DANSWER_VERSION=0.8-dev
|
||||
ENV DANSWER_VERSION=${DANSWER_VERSION} \
|
||||
DANSWER_RUNNING_IN_DOCKER="true"
|
||||
|
||||
ARG CA_CERT_CONTENT=""
|
||||
|
||||
RUN echo "DANSWER_VERSION: ${DANSWER_VERSION}"
|
||||
# Install system dependencies
|
||||
# cmake needed for psycopg (postgres)
|
||||
@@ -36,11 +38,24 @@ RUN apt-get update && \
|
||||
rm -rf /var/lib/apt/lists/* && \
|
||||
apt-get clean
|
||||
|
||||
|
||||
# Conditionally write the CA certificate and update certificates
|
||||
RUN if [ -n "$CA_CERT_CONTENT" ]; then \
|
||||
echo "Adding custom CA certificate"; \
|
||||
echo "$CA_CERT_CONTENT" > /usr/local/share/ca-certificates/my-ca.crt && \
|
||||
chmod 644 /usr/local/share/ca-certificates/my-ca.crt && \
|
||||
update-ca-certificates; \
|
||||
else \
|
||||
echo "No custom CA certificate provided"; \
|
||||
fi
|
||||
|
||||
# Install Python dependencies
|
||||
# Remove py which is pulled in by retry, py is not needed and is a CVE
|
||||
COPY ./requirements/default.txt /tmp/requirements.txt
|
||||
COPY ./requirements/ee.txt /tmp/ee-requirements.txt
|
||||
RUN pip install --no-cache-dir --upgrade \
|
||||
--retries 5 \
|
||||
--timeout 30 \
|
||||
-r /tmp/requirements.txt \
|
||||
-r /tmp/ee-requirements.txt && \
|
||||
pip uninstall -y py && \
|
||||
@@ -90,6 +105,7 @@ COPY supervisord.conf /etc/supervisor/conf.d/supervisord.conf
|
||||
COPY ./danswer /app/danswer
|
||||
COPY ./shared_configs /app/shared_configs
|
||||
COPY ./alembic /app/alembic
|
||||
COPY ./alembic_tenants /app/alembic_tenants
|
||||
COPY ./alembic.ini /app/alembic.ini
|
||||
COPY supervisord.conf /usr/etc/supervisord.conf
|
||||
|
||||
@@ -99,7 +115,7 @@ COPY ./scripts/force_delete_connector_by_id.py /app/scripts/force_delete_connect
|
||||
# Put logo in assets
|
||||
COPY ./assets /app/assets
|
||||
|
||||
ENV PYTHONPATH /app
|
||||
ENV PYTHONPATH=/app
|
||||
|
||||
# Default command which does nothing
|
||||
# This container is used by api server and background which specify their own CMD
|
||||
|
||||
@@ -7,7 +7,7 @@ You can find it at https://hub.docker.com/r/danswer/danswer-model-server. For mo
|
||||
visit https://github.com/danswer-ai/danswer."
|
||||
|
||||
# Default DANSWER_VERSION, typically overriden during builds by GitHub Actions.
|
||||
ARG DANSWER_VERSION=0.3-dev
|
||||
ARG DANSWER_VERSION=0.8-dev
|
||||
ENV DANSWER_VERSION=${DANSWER_VERSION} \
|
||||
DANSWER_RUNNING_IN_DOCKER="true"
|
||||
|
||||
@@ -15,7 +15,10 @@ ENV DANSWER_VERSION=${DANSWER_VERSION} \
|
||||
RUN echo "DANSWER_VERSION: ${DANSWER_VERSION}"
|
||||
|
||||
COPY ./requirements/model_server.txt /tmp/requirements.txt
|
||||
RUN pip install --no-cache-dir --upgrade -r /tmp/requirements.txt
|
||||
RUN pip install --no-cache-dir --upgrade \
|
||||
--retries 5 \
|
||||
--timeout 30 \
|
||||
-r /tmp/requirements.txt
|
||||
|
||||
RUN apt-get remove -y --allow-remove-essential perl-base && \
|
||||
apt-get autoremove -y
|
||||
@@ -52,6 +55,6 @@ COPY ./shared_configs /app/shared_configs
|
||||
# Model Server main code
|
||||
COPY ./model_server /app/model_server
|
||||
|
||||
ENV PYTHONPATH /app
|
||||
ENV PYTHONPATH=/app
|
||||
|
||||
CMD ["uvicorn", "model_server.main:app", "--host", "0.0.0.0", "--port", "9000"]
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 27 KiB |
@@ -1,6 +1,6 @@
|
||||
# A generic, single database configuration.
|
||||
|
||||
[alembic]
|
||||
[DEFAULT]
|
||||
# path to migration scripts
|
||||
script_location = alembic
|
||||
|
||||
@@ -47,7 +47,8 @@ prepend_sys_path = .
|
||||
# version_path_separator = :
|
||||
# version_path_separator = ;
|
||||
# version_path_separator = space
|
||||
version_path_separator = os # Use os.pathsep. Default configuration used for new projects.
|
||||
version_path_separator = os
|
||||
# Use os.pathsep. Default configuration used for new projects.
|
||||
|
||||
# set to 'true' to search source files recursively
|
||||
# in each "version_locations" directory
|
||||
@@ -106,3 +107,12 @@ formatter = generic
|
||||
[formatter_generic]
|
||||
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||
datefmt = %H:%M:%S
|
||||
|
||||
|
||||
[alembic]
|
||||
script_location = alembic
|
||||
version_locations = %(script_location)s/versions
|
||||
|
||||
[schema_private]
|
||||
script_location = alembic_tenants
|
||||
version_locations = %(script_location)s/versions
|
||||
|
||||
@@ -1,107 +1,198 @@
|
||||
from sqlalchemy.engine.base import Connection
|
||||
from typing import Any
|
||||
import asyncio
|
||||
from logging.config import fileConfig
|
||||
import logging
|
||||
|
||||
from alembic import context
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy.sql import text
|
||||
|
||||
from danswer.configs.app_configs import MULTI_TENANT
|
||||
from danswer.db.engine import build_connection_string
|
||||
from danswer.db.models import Base
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.engine import Connection
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from celery.backends.database.session import ResultModelBase # type: ignore
|
||||
from sqlalchemy.schema import SchemaItem
|
||||
from danswer.background.celery.celery_app import get_all_tenant_ids
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
# Alembic Config object
|
||||
config = context.config
|
||||
|
||||
# Interpret the config file for Python logging.
|
||||
# This line sets up loggers basically.
|
||||
if config.config_file_name is not None and config.attributes.get(
|
||||
"configure_logger", True
|
||||
):
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
# add your model's MetaData object here
|
||||
# for 'autogenerate' support
|
||||
# from myapp import mymodel
|
||||
# target_metadata = mymodel.Base.metadata
|
||||
# Add your model's MetaData object here for 'autogenerate' support
|
||||
target_metadata = [Base.metadata, ResultModelBase.metadata]
|
||||
|
||||
# other values from the config, defined by the needs of env.py,
|
||||
# can be acquired:
|
||||
# my_important_option = config.get_main_option("my_important_option")
|
||||
# ... etc.
|
||||
|
||||
EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
|
||||
|
||||
# Set up logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def include_object(
|
||||
object: SchemaItem,
|
||||
name: str,
|
||||
type_: str,
|
||||
reflected: bool,
|
||||
compare_to: SchemaItem | None,
|
||||
object: Any, name: str, type_: str, reflected: bool, compare_to: Any
|
||||
) -> bool:
|
||||
"""
|
||||
Determines whether a database object should be included in migrations.
|
||||
Excludes specified tables from migrations.
|
||||
"""
|
||||
if type_ == "table" and name in EXCLUDE_TABLES:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
"""Run migrations in 'offline' mode.
|
||||
|
||||
This configures the context with just a URL
|
||||
and not an Engine, though an Engine is acceptable
|
||||
here as well. By skipping the Engine creation
|
||||
we don't even need a DBAPI to be available.
|
||||
|
||||
Calls to context.execute() here emit the given string to the
|
||||
script output.
|
||||
|
||||
def get_schema_options() -> tuple[str, bool, bool]:
|
||||
"""
|
||||
url = build_connection_string()
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
)
|
||||
Parses command-line options passed via '-x' in Alembic commands.
|
||||
Recognizes 'schema', 'create_schema', and 'upgrade_all_tenants' options.
|
||||
"""
|
||||
x_args_raw = context.get_x_argument()
|
||||
x_args = {}
|
||||
for arg in x_args_raw:
|
||||
for pair in arg.split(","):
|
||||
if "=" in pair:
|
||||
key, value = pair.split("=", 1)
|
||||
x_args[key.strip()] = value.strip()
|
||||
schema_name = x_args.get("schema", "public")
|
||||
create_schema = x_args.get("create_schema", "true").lower() == "true"
|
||||
upgrade_all_tenants = x_args.get("upgrade_all_tenants", "false").lower() == "true"
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
if MULTI_TENANT and schema_name == "public":
|
||||
raise ValueError(
|
||||
"Cannot run default migrations in public schema when multi-tenancy is enabled. "
|
||||
"Please specify a tenant-specific schema."
|
||||
)
|
||||
|
||||
return schema_name, create_schema, upgrade_all_tenants
|
||||
|
||||
|
||||
def do_run_migrations(connection: Connection) -> None:
|
||||
def do_run_migrations(
|
||||
connection: Connection, schema_name: str, create_schema: bool
|
||||
) -> None:
|
||||
"""
|
||||
Executes migrations in the specified schema.
|
||||
"""
|
||||
logger.info(f"About to migrate schema: {schema_name}")
|
||||
|
||||
if create_schema:
|
||||
connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema_name}"'))
|
||||
connection.execute(text("COMMIT"))
|
||||
|
||||
# Set search_path to the target schema
|
||||
connection.execute(text(f'SET search_path TO "{schema_name}"'))
|
||||
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
include_object=include_object,
|
||||
) # type: ignore
|
||||
version_table_schema=schema_name,
|
||||
include_schemas=True,
|
||||
compare_type=True,
|
||||
compare_server_default=True,
|
||||
script_location=config.get_main_option("script_location"),
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
async def run_async_migrations() -> None:
|
||||
"""In this scenario we need to create an Engine
|
||||
and associate a connection with the context.
|
||||
|
||||
"""
|
||||
Determines whether to run migrations for a single schema or all schemas,
|
||||
and executes migrations accordingly.
|
||||
"""
|
||||
schema_name, create_schema, upgrade_all_tenants = get_schema_options()
|
||||
|
||||
connectable = create_async_engine(
|
||||
engine = create_async_engine(
|
||||
build_connection_string(),
|
||||
poolclass=pool.NullPool,
|
||||
)
|
||||
|
||||
async with connectable.connect() as connection:
|
||||
await connection.run_sync(do_run_migrations)
|
||||
if upgrade_all_tenants:
|
||||
# Run migrations for all tenant schemas sequentially
|
||||
tenant_schemas = get_all_tenant_ids()
|
||||
|
||||
await connectable.dispose()
|
||||
for schema in tenant_schemas:
|
||||
try:
|
||||
logger.info(f"Migrating schema: {schema}")
|
||||
async with engine.connect() as connection:
|
||||
await connection.run_sync(
|
||||
do_run_migrations,
|
||||
schema_name=schema,
|
||||
create_schema=create_schema,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error migrating schema {schema}: {e}")
|
||||
raise
|
||||
else:
|
||||
try:
|
||||
logger.info(f"Migrating schema: {schema_name}")
|
||||
async with engine.connect() as connection:
|
||||
await connection.run_sync(
|
||||
do_run_migrations,
|
||||
schema_name=schema_name,
|
||||
create_schema=create_schema,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error migrating schema {schema_name}: {e}")
|
||||
raise
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
"""
|
||||
Run migrations in 'offline' mode.
|
||||
"""
|
||||
schema_name, _, upgrade_all_tenants = get_schema_options()
|
||||
url = build_connection_string()
|
||||
|
||||
if upgrade_all_tenants:
|
||||
# Run offline migrations for all tenant schemas
|
||||
engine = create_async_engine(url)
|
||||
tenant_schemas = get_all_tenant_ids()
|
||||
engine.sync_engine.dispose()
|
||||
|
||||
for schema in tenant_schemas:
|
||||
logger.info(f"Migrating schema: {schema}")
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
literal_binds=True,
|
||||
include_object=include_object,
|
||||
version_table_schema=schema,
|
||||
include_schemas=True,
|
||||
script_location=config.get_main_option("script_location"),
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
else:
|
||||
logger.info(f"Migrating schema: {schema_name}")
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
literal_binds=True,
|
||||
include_object=include_object,
|
||||
version_table_schema=schema_name,
|
||||
include_schemas=True,
|
||||
script_location=config.get_main_option("script_location"),
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
|
||||
"""Run migrations in 'online' mode."""
|
||||
|
||||
"""
|
||||
Runs migrations in 'online' mode using an asynchronous engine.
|
||||
"""
|
||||
asyncio.run(run_async_migrations())
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
"""add additional data to notifications
|
||||
|
||||
Revision ID: 1b10e1fda030
|
||||
Revises: 6756efa39ada
|
||||
Create Date: 2024-10-15 19:26:44.071259
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "1b10e1fda030"
|
||||
down_revision = "6756efa39ada"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"notification", sa.Column("additional_data", postgresql.JSONB(), nullable=True)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("notification", "additional_data")
|
||||
@@ -0,0 +1,102 @@
|
||||
"""add_user_delete_cascades
|
||||
|
||||
Revision ID: 1b8206b29c5d
|
||||
Revises: 35e6853a51d5
|
||||
Create Date: 2024-09-18 11:48:59.418726
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "1b8206b29c5d"
|
||||
down_revision = "35e6853a51d5"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.drop_constraint("credential_user_id_fkey", "credential", type_="foreignkey")
|
||||
op.create_foreign_key(
|
||||
"credential_user_id_fkey",
|
||||
"credential",
|
||||
"user",
|
||||
["user_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
op.drop_constraint("chat_session_user_id_fkey", "chat_session", type_="foreignkey")
|
||||
op.create_foreign_key(
|
||||
"chat_session_user_id_fkey",
|
||||
"chat_session",
|
||||
"user",
|
||||
["user_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
op.drop_constraint("chat_folder_user_id_fkey", "chat_folder", type_="foreignkey")
|
||||
op.create_foreign_key(
|
||||
"chat_folder_user_id_fkey",
|
||||
"chat_folder",
|
||||
"user",
|
||||
["user_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
op.drop_constraint("prompt_user_id_fkey", "prompt", type_="foreignkey")
|
||||
op.create_foreign_key(
|
||||
"prompt_user_id_fkey", "prompt", "user", ["user_id"], ["id"], ondelete="CASCADE"
|
||||
)
|
||||
|
||||
op.drop_constraint("notification_user_id_fkey", "notification", type_="foreignkey")
|
||||
op.create_foreign_key(
|
||||
"notification_user_id_fkey",
|
||||
"notification",
|
||||
"user",
|
||||
["user_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
op.drop_constraint("inputprompt_user_id_fkey", "inputprompt", type_="foreignkey")
|
||||
op.create_foreign_key(
|
||||
"inputprompt_user_id_fkey",
|
||||
"inputprompt",
|
||||
"user",
|
||||
["user_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_constraint("credential_user_id_fkey", "credential", type_="foreignkey")
|
||||
op.create_foreign_key(
|
||||
"credential_user_id_fkey", "credential", "user", ["user_id"], ["id"]
|
||||
)
|
||||
|
||||
op.drop_constraint("chat_session_user_id_fkey", "chat_session", type_="foreignkey")
|
||||
op.create_foreign_key(
|
||||
"chat_session_user_id_fkey", "chat_session", "user", ["user_id"], ["id"]
|
||||
)
|
||||
|
||||
op.drop_constraint("chat_folder_user_id_fkey", "chat_folder", type_="foreignkey")
|
||||
op.create_foreign_key(
|
||||
"chat_folder_user_id_fkey", "chat_folder", "user", ["user_id"], ["id"]
|
||||
)
|
||||
|
||||
op.drop_constraint("prompt_user_id_fkey", "prompt", type_="foreignkey")
|
||||
op.create_foreign_key("prompt_user_id_fkey", "prompt", "user", ["user_id"], ["id"])
|
||||
|
||||
op.drop_constraint("notification_user_id_fkey", "notification", type_="foreignkey")
|
||||
op.create_foreign_key(
|
||||
"notification_user_id_fkey", "notification", "user", ["user_id"], ["id"]
|
||||
)
|
||||
|
||||
op.drop_constraint("inputprompt_user_id_fkey", "inputprompt", type_="foreignkey")
|
||||
op.create_foreign_key(
|
||||
"inputprompt_user_id_fkey", "inputprompt", "user", ["user_id"], ["id"]
|
||||
)
|
||||
@@ -0,0 +1,64 @@
|
||||
"""server default chosen assistants
|
||||
|
||||
Revision ID: 35e6853a51d5
|
||||
Revises: c99d76fcd298
|
||||
Create Date: 2024-09-13 13:20:32.885317
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "35e6853a51d5"
|
||||
down_revision = "c99d76fcd298"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
DEFAULT_ASSISTANTS = [-2, -1, 0]
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Step 1: Update any NULL values to the default value
|
||||
# This upgrades existing users without ordered assistant
|
||||
# to have default assistants set to visible assistants which are
|
||||
# accessible by them.
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE "user" u
|
||||
SET chosen_assistants = (
|
||||
SELECT jsonb_agg(
|
||||
p.id ORDER BY
|
||||
COALESCE(p.display_priority, 2147483647) ASC,
|
||||
p.id ASC
|
||||
)
|
||||
FROM persona p
|
||||
LEFT JOIN persona__user pu ON p.id = pu.persona_id AND pu.user_id = u.id
|
||||
WHERE p.is_visible = true
|
||||
AND (p.is_public = true OR pu.user_id IS NOT NULL)
|
||||
)
|
||||
WHERE chosen_assistants IS NULL
|
||||
OR chosen_assistants = 'null'
|
||||
OR jsonb_typeof(chosen_assistants) = 'null'
|
||||
OR (jsonb_typeof(chosen_assistants) = 'string' AND chosen_assistants = '"null"')
|
||||
"""
|
||||
)
|
||||
|
||||
# Step 2: Alter the column to make it non-nullable
|
||||
op.alter_column(
|
||||
"user",
|
||||
"chosen_assistants",
|
||||
type_=postgresql.JSONB(astext_type=sa.Text()),
|
||||
nullable=False,
|
||||
server_default=sa.text(f"'{DEFAULT_ASSISTANTS}'::jsonb"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.alter_column(
|
||||
"user",
|
||||
"chosen_assistants",
|
||||
type_=postgresql.JSONB(astext_type=sa.Text()),
|
||||
nullable=True,
|
||||
server_default=None,
|
||||
)
|
||||
@@ -0,0 +1,46 @@
|
||||
"""fix_user__external_user_group_id_fk
|
||||
|
||||
Revision ID: 46b7a812670f
|
||||
Revises: f32615f71aeb
|
||||
Create Date: 2024-09-23 12:58:03.894038
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "46b7a812670f"
|
||||
down_revision = "f32615f71aeb"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Drop the existing primary key
|
||||
op.drop_constraint(
|
||||
"user__external_user_group_id_pkey",
|
||||
"user__external_user_group_id",
|
||||
type_="primary",
|
||||
)
|
||||
|
||||
# Add the new composite primary key
|
||||
op.create_primary_key(
|
||||
"user__external_user_group_id_pkey",
|
||||
"user__external_user_group_id",
|
||||
["user_id", "external_user_group_id", "cc_pair_id"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop the composite primary key
|
||||
op.drop_constraint(
|
||||
"user__external_user_group_id_pkey",
|
||||
"user__external_user_group_id",
|
||||
type_="primary",
|
||||
)
|
||||
# Delete all entries from the table
|
||||
op.execute("DELETE FROM user__external_user_group_id")
|
||||
|
||||
# Recreate the original primary key on user_id
|
||||
op.create_primary_key(
|
||||
"user__external_user_group_id_pkey", "user__external_user_group_id", ["user_id"]
|
||||
)
|
||||
@@ -1,65 +0,0 @@
|
||||
"""single tool call per message
|
||||
|
||||
|
||||
Revision ID: 4e8e7ae58189
|
||||
Revises: 5c7fdadae813
|
||||
Create Date: 2024-09-09 10:07:58.008838
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "4e8e7ae58189"
|
||||
down_revision = "5c7fdadae813"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create the new column
|
||||
op.add_column(
|
||||
"chat_message", sa.Column("tool_call_id", sa.Integer(), nullable=True)
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"fk_chat_message_tool_call",
|
||||
"chat_message",
|
||||
"tool_call",
|
||||
["tool_call_id"],
|
||||
["id"],
|
||||
)
|
||||
|
||||
# Migrate existing data
|
||||
op.execute(
|
||||
"UPDATE chat_message SET tool_call_id = (SELECT id FROM tool_call WHERE tool_call.message_id = chat_message.id LIMIT 1)"
|
||||
)
|
||||
|
||||
# Drop the old relationship
|
||||
op.drop_constraint("tool_call_message_id_fkey", "tool_call", type_="foreignkey")
|
||||
op.drop_column("tool_call", "message_id")
|
||||
|
||||
# Add a unique constraint to ensure one-to-one relationship
|
||||
op.create_unique_constraint(
|
||||
"uq_chat_message_tool_call_id", "chat_message", ["tool_call_id"]
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Add back the old column
|
||||
op.add_column(
|
||||
"tool_call",
|
||||
sa.Column("message_id", sa.INTEGER(), autoincrement=False, nullable=True),
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"tool_call_message_id_fkey", "tool_call", "chat_message", ["message_id"], ["id"]
|
||||
)
|
||||
|
||||
# Migrate data back
|
||||
op.execute(
|
||||
"UPDATE tool_call SET message_id = (SELECT id FROM chat_message WHERE chat_message.tool_call_id = tool_call.id)"
|
||||
)
|
||||
|
||||
# Drop the new column
|
||||
op.drop_constraint("fk_chat_message_tool_call", "chat_message", type_="foreignkey")
|
||||
op.drop_column("chat_message", "tool_call_id")
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Add last synced and last modified to document table
|
||||
|
||||
Revision ID: 52a219fb5233
|
||||
Revises: f17bf3b0d9f1
|
||||
Revises: f7e58d357687
|
||||
Create Date: 2024-08-28 17:40:46.077470
|
||||
|
||||
"""
|
||||
|
||||
79
backend/alembic/versions/55546a7967ee_assistant_rework.py
Normal file
79
backend/alembic/versions/55546a7967ee_assistant_rework.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""assistant_rework
|
||||
|
||||
Revision ID: 55546a7967ee
|
||||
Revises: 61ff3651add4
|
||||
Create Date: 2024-09-18 17:00:23.755399
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "55546a7967ee"
|
||||
down_revision = "61ff3651add4"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Reworking persona and user tables for new assistant features
|
||||
# keep track of user's chosen assistants separate from their `ordering`
|
||||
op.add_column("persona", sa.Column("builtin_persona", sa.Boolean(), nullable=True))
|
||||
op.execute("UPDATE persona SET builtin_persona = default_persona")
|
||||
op.alter_column("persona", "builtin_persona", nullable=False)
|
||||
op.drop_index("_default_persona_name_idx", table_name="persona")
|
||||
op.create_index(
|
||||
"_builtin_persona_name_idx",
|
||||
"persona",
|
||||
["name"],
|
||||
unique=True,
|
||||
postgresql_where=sa.text("builtin_persona = true"),
|
||||
)
|
||||
|
||||
op.add_column(
|
||||
"user", sa.Column("visible_assistants", postgresql.JSONB(), nullable=True)
|
||||
)
|
||||
op.add_column(
|
||||
"user", sa.Column("hidden_assistants", postgresql.JSONB(), nullable=True)
|
||||
)
|
||||
op.execute(
|
||||
"UPDATE \"user\" SET visible_assistants = '[]'::jsonb, hidden_assistants = '[]'::jsonb"
|
||||
)
|
||||
op.alter_column(
|
||||
"user",
|
||||
"visible_assistants",
|
||||
nullable=False,
|
||||
server_default=sa.text("'[]'::jsonb"),
|
||||
)
|
||||
op.alter_column(
|
||||
"user",
|
||||
"hidden_assistants",
|
||||
nullable=False,
|
||||
server_default=sa.text("'[]'::jsonb"),
|
||||
)
|
||||
op.drop_column("persona", "default_persona")
|
||||
op.add_column(
|
||||
"persona", sa.Column("is_default_persona", sa.Boolean(), nullable=True)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Reverting changes made in upgrade
|
||||
op.drop_column("user", "hidden_assistants")
|
||||
op.drop_column("user", "visible_assistants")
|
||||
op.drop_index("_builtin_persona_name_idx", table_name="persona")
|
||||
|
||||
op.drop_column("persona", "is_default_persona")
|
||||
op.add_column("persona", sa.Column("default_persona", sa.Boolean(), nullable=True))
|
||||
op.execute("UPDATE persona SET default_persona = builtin_persona")
|
||||
op.alter_column("persona", "default_persona", nullable=False)
|
||||
op.drop_column("persona", "builtin_persona")
|
||||
op.create_index(
|
||||
"_default_persona_name_idx",
|
||||
"persona",
|
||||
["name"],
|
||||
unique=True,
|
||||
postgresql_where=sa.text("default_persona = true"),
|
||||
)
|
||||
@@ -0,0 +1,30 @@
|
||||
"""add api_version and deployment_name to search settings
|
||||
|
||||
Revision ID: 5d12a446f5c0
|
||||
Revises: e4334d5b33ba
|
||||
Create Date: 2024-10-08 15:56:07.975636
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "5d12a446f5c0"
|
||||
down_revision = "e4334d5b33ba"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"embedding_provider", sa.Column("api_version", sa.String(), nullable=True)
|
||||
)
|
||||
op.add_column(
|
||||
"embedding_provider", sa.Column("deployment_name", sa.String(), nullable=True)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("embedding_provider", "deployment_name")
|
||||
op.drop_column("embedding_provider", "api_version")
|
||||
162
backend/alembic/versions/61ff3651add4_add_permission_syncing.py
Normal file
162
backend/alembic/versions/61ff3651add4_add_permission_syncing.py
Normal file
@@ -0,0 +1,162 @@
|
||||
"""Add Permission Syncing
|
||||
|
||||
Revision ID: 61ff3651add4
|
||||
Revises: 1b8206b29c5d
|
||||
Create Date: 2024-09-05 13:57:11.770413
|
||||
|
||||
"""
|
||||
import fastapi_users_db_sqlalchemy
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "61ff3651add4"
|
||||
down_revision = "1b8206b29c5d"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Admin user who set up connectors will lose access to the docs temporarily
|
||||
# only way currently to give back access is to rerun from beginning
|
||||
op.add_column(
|
||||
"connector_credential_pair",
|
||||
sa.Column(
|
||||
"access_type",
|
||||
sa.String(),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
op.execute(
|
||||
"UPDATE connector_credential_pair SET access_type = 'PUBLIC' WHERE is_public = true"
|
||||
)
|
||||
op.execute(
|
||||
"UPDATE connector_credential_pair SET access_type = 'PRIVATE' WHERE is_public = false"
|
||||
)
|
||||
op.alter_column("connector_credential_pair", "access_type", nullable=False)
|
||||
|
||||
op.add_column(
|
||||
"connector_credential_pair",
|
||||
sa.Column(
|
||||
"auto_sync_options",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"connector_credential_pair",
|
||||
sa.Column("last_time_perm_sync", sa.DateTime(timezone=True), nullable=True),
|
||||
)
|
||||
op.drop_column("connector_credential_pair", "is_public")
|
||||
|
||||
op.add_column(
|
||||
"document",
|
||||
sa.Column("external_user_emails", postgresql.ARRAY(sa.String()), nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
"document",
|
||||
sa.Column(
|
||||
"external_user_group_ids", postgresql.ARRAY(sa.String()), nullable=True
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"document",
|
||||
sa.Column("is_public", sa.Boolean(), nullable=True),
|
||||
)
|
||||
|
||||
op.create_table(
|
||||
"user__external_user_group_id",
|
||||
sa.Column(
|
||||
"user_id", fastapi_users_db_sqlalchemy.generics.GUID(), nullable=False
|
||||
),
|
||||
sa.Column("external_user_group_id", sa.String(), nullable=False),
|
||||
sa.Column("cc_pair_id", sa.Integer(), nullable=False),
|
||||
sa.PrimaryKeyConstraint("user_id"),
|
||||
)
|
||||
|
||||
op.drop_column("external_permission", "user_id")
|
||||
op.drop_column("email_to_external_user_cache", "user_id")
|
||||
op.drop_table("permission_sync_run")
|
||||
op.drop_table("external_permission")
|
||||
op.drop_table("email_to_external_user_cache")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.add_column(
|
||||
"connector_credential_pair",
|
||||
sa.Column("is_public", sa.BOOLEAN(), nullable=True),
|
||||
)
|
||||
op.execute(
|
||||
"UPDATE connector_credential_pair SET is_public = (access_type = 'PUBLIC')"
|
||||
)
|
||||
op.alter_column("connector_credential_pair", "is_public", nullable=False)
|
||||
|
||||
op.drop_column("connector_credential_pair", "auto_sync_options")
|
||||
op.drop_column("connector_credential_pair", "access_type")
|
||||
op.drop_column("connector_credential_pair", "last_time_perm_sync")
|
||||
op.drop_column("document", "external_user_emails")
|
||||
op.drop_column("document", "external_user_group_ids")
|
||||
op.drop_column("document", "is_public")
|
||||
|
||||
op.drop_table("user__external_user_group_id")
|
||||
|
||||
# Drop the enum type at the end of the downgrade
|
||||
op.create_table(
|
||||
"permission_sync_run",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column(
|
||||
"source_type",
|
||||
sa.String(),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("update_type", sa.String(), nullable=False),
|
||||
sa.Column("cc_pair_id", sa.Integer(), nullable=True),
|
||||
sa.Column(
|
||||
"status",
|
||||
sa.String(),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("error_msg", sa.Text(), nullable=True),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["cc_pair_id"],
|
||||
["connector_credential_pair.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_table(
|
||||
"external_permission",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("user_id", sa.UUID(), nullable=True),
|
||||
sa.Column("user_email", sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
"source_type",
|
||||
sa.String(),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("external_permission_group", sa.String(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["user.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_table(
|
||||
"email_to_external_user_cache",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("external_user_id", sa.String(), nullable=False),
|
||||
sa.Column("user_id", sa.UUID(), nullable=True),
|
||||
sa.Column("user_email", sa.String(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["user.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
@@ -0,0 +1,153 @@
|
||||
"""
|
||||
Revision ID: 6756efa39ada
|
||||
Revises: 5d12a446f5c0
|
||||
Create Date: 2024-10-15 17:47:44.108537
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "6756efa39ada"
|
||||
down_revision = "5d12a446f5c0"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
"""
|
||||
Migrate chat_session and chat_message tables to use UUID primary keys.
|
||||
|
||||
This script:
|
||||
1. Adds UUID columns to chat_session and chat_message
|
||||
2. Populates new columns with UUIDs
|
||||
3. Updates foreign key relationships
|
||||
4. Removes old integer ID columns
|
||||
|
||||
Note: Downgrade will assign new integer IDs, not restore original ones.
|
||||
"""
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute("CREATE EXTENSION IF NOT EXISTS pgcrypto;")
|
||||
|
||||
op.add_column(
|
||||
"chat_session",
|
||||
sa.Column(
|
||||
"new_id",
|
||||
sa.UUID(as_uuid=True),
|
||||
server_default=sa.text("gen_random_uuid()"),
|
||||
nullable=False,
|
||||
),
|
||||
)
|
||||
|
||||
op.execute("UPDATE chat_session SET new_id = gen_random_uuid();")
|
||||
|
||||
op.add_column(
|
||||
"chat_message",
|
||||
sa.Column("new_chat_session_id", sa.UUID(as_uuid=True), nullable=True),
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE chat_message
|
||||
SET new_chat_session_id = cs.new_id
|
||||
FROM chat_session cs
|
||||
WHERE chat_message.chat_session_id = cs.id;
|
||||
"""
|
||||
)
|
||||
|
||||
op.drop_constraint(
|
||||
"chat_message_chat_session_id_fkey", "chat_message", type_="foreignkey"
|
||||
)
|
||||
|
||||
op.drop_column("chat_message", "chat_session_id")
|
||||
op.alter_column(
|
||||
"chat_message", "new_chat_session_id", new_column_name="chat_session_id"
|
||||
)
|
||||
|
||||
op.drop_constraint("chat_session_pkey", "chat_session", type_="primary")
|
||||
op.drop_column("chat_session", "id")
|
||||
op.alter_column("chat_session", "new_id", new_column_name="id")
|
||||
|
||||
op.create_primary_key("chat_session_pkey", "chat_session", ["id"])
|
||||
|
||||
op.create_foreign_key(
|
||||
"chat_message_chat_session_id_fkey",
|
||||
"chat_message",
|
||||
"chat_session",
|
||||
["chat_session_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_constraint(
|
||||
"chat_message_chat_session_id_fkey", "chat_message", type_="foreignkey"
|
||||
)
|
||||
|
||||
op.add_column(
|
||||
"chat_session",
|
||||
sa.Column("old_id", sa.Integer, autoincrement=True, nullable=True),
|
||||
)
|
||||
|
||||
op.execute("CREATE SEQUENCE chat_session_old_id_seq OWNED BY chat_session.old_id;")
|
||||
op.execute(
|
||||
"ALTER TABLE chat_session ALTER COLUMN old_id SET DEFAULT nextval('chat_session_old_id_seq');"
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"UPDATE chat_session SET old_id = nextval('chat_session_old_id_seq') WHERE old_id IS NULL;"
|
||||
)
|
||||
|
||||
op.alter_column("chat_session", "old_id", nullable=False)
|
||||
|
||||
op.drop_constraint("chat_session_pkey", "chat_session", type_="primary")
|
||||
op.create_primary_key("chat_session_pkey", "chat_session", ["old_id"])
|
||||
|
||||
op.add_column(
|
||||
"chat_message",
|
||||
sa.Column("old_chat_session_id", sa.Integer, nullable=True),
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE chat_message
|
||||
SET old_chat_session_id = cs.old_id
|
||||
FROM chat_session cs
|
||||
WHERE chat_message.chat_session_id = cs.id;
|
||||
"""
|
||||
)
|
||||
|
||||
op.drop_column("chat_message", "chat_session_id")
|
||||
op.alter_column(
|
||||
"chat_message", "old_chat_session_id", new_column_name="chat_session_id"
|
||||
)
|
||||
|
||||
op.create_foreign_key(
|
||||
"chat_message_chat_session_id_fkey",
|
||||
"chat_message",
|
||||
"chat_session",
|
||||
["chat_session_id"],
|
||||
["old_id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
op.drop_column("chat_session", "id")
|
||||
op.alter_column("chat_session", "old_id", new_column_name="id")
|
||||
|
||||
op.alter_column(
|
||||
"chat_session",
|
||||
"id",
|
||||
type_=sa.Integer(),
|
||||
existing_type=sa.Integer(),
|
||||
existing_nullable=False,
|
||||
existing_server_default=False,
|
||||
)
|
||||
|
||||
# Rename the sequence
|
||||
op.execute("ALTER SEQUENCE chat_session_old_id_seq RENAME TO chat_session_id_seq;")
|
||||
|
||||
# Update the default value to use the renamed sequence
|
||||
op.alter_column(
|
||||
"chat_session",
|
||||
"id",
|
||||
server_default=sa.text("nextval('chat_session_id_seq'::regclass)"),
|
||||
)
|
||||
@@ -9,7 +9,7 @@ import json
|
||||
from typing import cast
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from danswer.dynamic_configs.factory import get_dynamic_config_store
|
||||
from danswer.key_value_store.factory import get_kv_store
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "703313b75876"
|
||||
@@ -54,9 +54,7 @@ def upgrade() -> None:
|
||||
)
|
||||
|
||||
try:
|
||||
settings_json = cast(
|
||||
str, get_dynamic_config_store().load("token_budget_settings")
|
||||
)
|
||||
settings_json = cast(str, get_kv_store().load("token_budget_settings"))
|
||||
settings = json.loads(settings_json)
|
||||
|
||||
is_enabled = settings.get("enable_token_budget", False)
|
||||
@@ -71,7 +69,7 @@ def upgrade() -> None:
|
||||
)
|
||||
|
||||
# Delete the dynamic config
|
||||
get_dynamic_config_store().delete("token_budget_settings")
|
||||
get_kv_store().delete("token_budget_settings")
|
||||
|
||||
except Exception:
|
||||
# Ignore if the dynamic config is not found
|
||||
|
||||
27
backend/alembic/versions/797089dfb4d2_persona_start_date.py
Normal file
27
backend/alembic/versions/797089dfb4d2_persona_start_date.py
Normal file
@@ -0,0 +1,27 @@
|
||||
"""persona_start_date
|
||||
|
||||
Revision ID: 797089dfb4d2
|
||||
Revises: 55546a7967ee
|
||||
Create Date: 2024-09-11 14:51:49.785835
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "797089dfb4d2"
|
||||
down_revision = "55546a7967ee"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column("search_start_date", sa.DateTime(timezone=True), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("persona", "search_start_date")
|
||||
@@ -0,0 +1,27 @@
|
||||
"""add last_pruned to the connector_credential_pair table
|
||||
|
||||
Revision ID: ac5eaac849f9
|
||||
Revises: 52a219fb5233
|
||||
Create Date: 2024-09-10 15:04:26.437118
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "ac5eaac849f9"
|
||||
down_revision = "46b7a812670f"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# last pruned represents the last time the connector was pruned
|
||||
op.add_column(
|
||||
"connector_credential_pair",
|
||||
sa.Column("last_pruned", sa.DateTime(timezone=True), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("connector_credential_pair", "last_pruned")
|
||||
@@ -0,0 +1,43 @@
|
||||
"""non nullable default persona
|
||||
|
||||
Revision ID: bd2921608c3a
|
||||
Revises: 797089dfb4d2
|
||||
Create Date: 2024-09-20 10:28:37.992042
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "bd2921608c3a"
|
||||
down_revision = "797089dfb4d2"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Set existing NULL values to False
|
||||
op.execute(
|
||||
"UPDATE persona SET is_default_persona = FALSE WHERE is_default_persona IS NULL"
|
||||
)
|
||||
|
||||
# Alter the column to be not nullable with a default value of False
|
||||
op.alter_column(
|
||||
"persona",
|
||||
"is_default_persona",
|
||||
existing_type=sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.text("false"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Revert the changes
|
||||
op.alter_column(
|
||||
"persona",
|
||||
"is_default_persona",
|
||||
existing_type=sa.Boolean(),
|
||||
nullable=True,
|
||||
server_default=None,
|
||||
)
|
||||
@@ -0,0 +1,31 @@
|
||||
"""add nullable to persona id in Chat Session
|
||||
|
||||
Revision ID: c99d76fcd298
|
||||
Revises: 5c7fdadae813
|
||||
Create Date: 2024-07-09 19:27:01.579697
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "c99d76fcd298"
|
||||
down_revision = "5c7fdadae813"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.alter_column(
|
||||
"chat_session", "persona_id", existing_type=sa.INTEGER(), nullable=True
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.alter_column(
|
||||
"chat_session",
|
||||
"persona_id",
|
||||
existing_type=sa.INTEGER(),
|
||||
nullable=False,
|
||||
)
|
||||
@@ -20,7 +20,7 @@ depends_on: None = None
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
existing_ids_and_chosen_assistants = conn.execute(
|
||||
sa.text("select id, chosen_assistants from public.user")
|
||||
sa.text('select id, chosen_assistants from "user"')
|
||||
)
|
||||
op.drop_column(
|
||||
"user",
|
||||
@@ -37,7 +37,7 @@ def upgrade() -> None:
|
||||
for id, chosen_assistants in existing_ids_and_chosen_assistants:
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"update public.user set chosen_assistants = :chosen_assistants where id = :id"
|
||||
'update "user" set chosen_assistants = :chosen_assistants where id = :id'
|
||||
),
|
||||
{"chosen_assistants": json.dumps(chosen_assistants), "id": id},
|
||||
)
|
||||
@@ -46,7 +46,7 @@ def upgrade() -> None:
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
existing_ids_and_chosen_assistants = conn.execute(
|
||||
sa.text("select id, chosen_assistants from public.user")
|
||||
sa.text('select id, chosen_assistants from "user"')
|
||||
)
|
||||
op.drop_column(
|
||||
"user",
|
||||
@@ -59,7 +59,7 @@ def downgrade() -> None:
|
||||
for id, chosen_assistants in existing_ids_and_chosen_assistants:
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"update public.user set chosen_assistants = :chosen_assistants where id = :id"
|
||||
'update "user" set chosen_assistants = :chosen_assistants where id = :id'
|
||||
),
|
||||
{"chosen_assistants": chosen_assistants, "id": id},
|
||||
)
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
"""add_deployment_name_to_llmprovider
|
||||
|
||||
Revision ID: e4334d5b33ba
|
||||
Revises: ac5eaac849f9
|
||||
Create Date: 2024-10-04 09:52:34.896867
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "e4334d5b33ba"
|
||||
down_revision = "ac5eaac849f9"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"llm_provider", sa.Column("deployment_name", sa.String(), nullable=True)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("llm_provider", "deployment_name")
|
||||
@@ -1,7 +1,7 @@
|
||||
"""standard answer match_regex flag
|
||||
|
||||
Revision ID: efb35676026c
|
||||
Revises: 52a219fb5233
|
||||
Revises: 0ebb1d516877
|
||||
Create Date: 2024-09-11 13:55:46.101149
|
||||
|
||||
"""
|
||||
@@ -19,7 +19,9 @@ def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column(
|
||||
"standard_answer",
|
||||
sa.Column("match_regex", sa.Boolean(), nullable=False, default=False),
|
||||
sa.Column(
|
||||
"match_regex", sa.Boolean(), nullable=False, server_default=sa.false()
|
||||
),
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
"""add custom headers to tools
|
||||
|
||||
Revision ID: f32615f71aeb
|
||||
Revises: bd2921608c3a
|
||||
Create Date: 2024-09-12 20:26:38.932377
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "f32615f71aeb"
|
||||
down_revision = "bd2921608c3a"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"tool", sa.Column("custom_headers", postgresql.JSONB(), nullable=True)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("tool", "custom_headers")
|
||||
@@ -1,7 +1,7 @@
|
||||
"""add has_web_login column to user
|
||||
|
||||
Revision ID: f7e58d357687
|
||||
Revises: bceb1e139447
|
||||
Revises: ba98eba0f66a
|
||||
Create Date: 2024-09-07 20:20:54.522620
|
||||
|
||||
"""
|
||||
|
||||
3
backend/alembic_tenants/README.md
Normal file
3
backend/alembic_tenants/README.md
Normal file
@@ -0,0 +1,3 @@
|
||||
These files are for public table migrations when operating with multi tenancy.
|
||||
|
||||
If you are not a Danswer developer, you can ignore this directory entirely.
|
||||
111
backend/alembic_tenants/env.py
Normal file
111
backend/alembic_tenants/env.py
Normal file
@@ -0,0 +1,111 @@
|
||||
import asyncio
|
||||
from logging.config import fileConfig
|
||||
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.engine import Connection
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy.schema import SchemaItem
|
||||
|
||||
from alembic import context
|
||||
from danswer.db.engine import build_connection_string
|
||||
from danswer.db.models import PublicBase
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
config = context.config
|
||||
|
||||
# Interpret the config file for Python logging.
|
||||
# This line sets up loggers basically.
|
||||
if config.config_file_name is not None and config.attributes.get(
|
||||
"configure_logger", True
|
||||
):
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
# add your model's MetaData object here
|
||||
# for 'autogenerate' support
|
||||
# from myapp import mymodel
|
||||
# target_metadata = mymodel.Base.metadata
|
||||
target_metadata = [PublicBase.metadata]
|
||||
|
||||
# other values from the config, defined by the needs of env.py,
|
||||
# can be acquired:
|
||||
# my_important_option = config.get_main_option("my_important_option")
|
||||
# ... etc.
|
||||
|
||||
EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
|
||||
|
||||
|
||||
def include_object(
|
||||
object: SchemaItem,
|
||||
name: str,
|
||||
type_: str,
|
||||
reflected: bool,
|
||||
compare_to: SchemaItem | None,
|
||||
) -> bool:
|
||||
if type_ == "table" and name in EXCLUDE_TABLES:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
"""Run migrations in 'offline' mode.
|
||||
|
||||
This configures the context with just a URL
|
||||
and not an Engine, though an Engine is acceptable
|
||||
here as well. By skipping the Engine creation
|
||||
we don't even need a DBAPI to be available.
|
||||
|
||||
Calls to context.execute() here emit the given string to the
|
||||
script output.
|
||||
|
||||
"""
|
||||
url = build_connection_string()
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def do_run_migrations(connection: Connection) -> None:
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
include_object=include_object,
|
||||
) # type: ignore
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
async def run_async_migrations() -> None:
|
||||
"""In this scenario we need to create an Engine
|
||||
and associate a connection with the context.
|
||||
|
||||
"""
|
||||
|
||||
connectable = create_async_engine(
|
||||
build_connection_string(),
|
||||
poolclass=pool.NullPool,
|
||||
)
|
||||
|
||||
async with connectable.connect() as connection:
|
||||
await connection.run_sync(do_run_migrations)
|
||||
|
||||
await connectable.dispose()
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
|
||||
"""Run migrations in 'online' mode."""
|
||||
|
||||
asyncio.run(run_async_migrations())
|
||||
|
||||
|
||||
if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
run_migrations_online()
|
||||
24
backend/alembic_tenants/script.py.mako
Normal file
24
backend/alembic_tenants/script.py.mako
Normal file
@@ -0,0 +1,24 @@
|
||||
"""${message}
|
||||
|
||||
Revision ID: ${up_revision}
|
||||
Revises: ${down_revision | comma,n}
|
||||
Create Date: ${create_date}
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
${imports if imports else ""}
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = ${repr(up_revision)}
|
||||
down_revision = ${repr(down_revision)}
|
||||
branch_labels = ${repr(branch_labels)}
|
||||
depends_on = ${repr(depends_on)}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
${upgrades if upgrades else "pass"}
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
${downgrades if downgrades else "pass"}
|
||||
@@ -0,0 +1,24 @@
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "14a83a331951"
|
||||
down_revision = None
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"user_tenant_mapping",
|
||||
sa.Column("email", sa.String(), nullable=False),
|
||||
sa.Column("tenant_id", sa.String(), nullable=False),
|
||||
sa.UniqueConstraint("email", "tenant_id", name="uq_user_tenant"),
|
||||
sa.UniqueConstraint("email", name="uq_email"),
|
||||
schema="public",
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("user_tenant_mapping", schema="public")
|
||||
@@ -1,7 +1,7 @@
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.access.models import DocumentAccess
|
||||
from danswer.access.utils import prefix_user
|
||||
from danswer.access.utils import prefix_user_email
|
||||
from danswer.configs.constants import PUBLIC_DOC_PAT
|
||||
from danswer.db.document import get_access_info_for_document
|
||||
from danswer.db.document import get_access_info_for_documents
|
||||
@@ -18,10 +18,13 @@ def _get_access_for_document(
|
||||
document_id=document_id,
|
||||
)
|
||||
|
||||
if not info:
|
||||
return DocumentAccess.build(user_ids=[], user_groups=[], is_public=False)
|
||||
|
||||
return DocumentAccess.build(user_ids=info[1], user_groups=[], is_public=info[2])
|
||||
return DocumentAccess.build(
|
||||
user_emails=info[1] if info and info[1] else [],
|
||||
user_groups=[],
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
is_public=info[2] if info else False,
|
||||
)
|
||||
|
||||
|
||||
def get_access_for_document(
|
||||
@@ -34,6 +37,16 @@ def get_access_for_document(
|
||||
return versioned_get_access_for_document_fn(document_id, db_session) # type: ignore
|
||||
|
||||
|
||||
def get_null_document_access() -> DocumentAccess:
|
||||
return DocumentAccess(
|
||||
user_emails=set(),
|
||||
user_groups=set(),
|
||||
is_public=False,
|
||||
external_user_emails=set(),
|
||||
external_user_group_ids=set(),
|
||||
)
|
||||
|
||||
|
||||
def _get_access_for_documents(
|
||||
document_ids: list[str],
|
||||
db_session: Session,
|
||||
@@ -42,13 +55,27 @@ def _get_access_for_documents(
|
||||
db_session=db_session,
|
||||
document_ids=document_ids,
|
||||
)
|
||||
return {
|
||||
document_id: DocumentAccess.build(
|
||||
user_ids=user_ids, user_groups=[], is_public=is_public
|
||||
doc_access = {
|
||||
document_id: DocumentAccess(
|
||||
user_emails=set([email for email in user_emails if email]),
|
||||
# MIT version will wipe all groups and external groups on update
|
||||
user_groups=set(),
|
||||
is_public=is_public,
|
||||
external_user_emails=set(),
|
||||
external_user_group_ids=set(),
|
||||
)
|
||||
for document_id, user_ids, is_public in document_access_info
|
||||
for document_id, user_emails, is_public in document_access_info
|
||||
}
|
||||
|
||||
# Sometimes the document has not be indexed by the indexing job yet, in those cases
|
||||
# the document does not exist and so we use least permissive. Specifically the EE version
|
||||
# checks the MIT version permissions and creates a superset. This ensures that this flow
|
||||
# does not fail even if the Document has not yet been indexed.
|
||||
for doc_id in document_ids:
|
||||
if doc_id not in doc_access:
|
||||
doc_access[doc_id] = get_null_document_access()
|
||||
return doc_access
|
||||
|
||||
|
||||
def get_access_for_documents(
|
||||
document_ids: list[str],
|
||||
@@ -70,7 +97,7 @@ def _get_acl_for_user(user: User | None, db_session: Session) -> set[str]:
|
||||
matches one entry in the returned set.
|
||||
"""
|
||||
if user:
|
||||
return {prefix_user(str(user.id)), PUBLIC_DOC_PAT}
|
||||
return {prefix_user_email(user.email), PUBLIC_DOC_PAT}
|
||||
return {PUBLIC_DOC_PAT}
|
||||
|
||||
|
||||
|
||||
@@ -1,30 +1,72 @@
|
||||
from dataclasses import dataclass
|
||||
from uuid import UUID
|
||||
|
||||
from danswer.access.utils import prefix_user
|
||||
from danswer.access.utils import prefix_external_group
|
||||
from danswer.access.utils import prefix_user_email
|
||||
from danswer.access.utils import prefix_user_group
|
||||
from danswer.configs.constants import PUBLIC_DOC_PAT
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DocumentAccess:
|
||||
user_ids: set[str] # stringified UUIDs
|
||||
user_groups: set[str] # names of user groups associated with this document
|
||||
class ExternalAccess:
|
||||
# Emails of external users with access to the doc externally
|
||||
external_user_emails: set[str]
|
||||
# Names or external IDs of groups with access to the doc
|
||||
external_user_group_ids: set[str]
|
||||
# Whether the document is public in the external system or Danswer
|
||||
is_public: bool
|
||||
|
||||
def to_acl(self) -> list[str]:
|
||||
return (
|
||||
[prefix_user(user_id) for user_id in self.user_ids]
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DocumentAccess(ExternalAccess):
|
||||
# User emails for Danswer users, None indicates admin
|
||||
user_emails: set[str | None]
|
||||
# Names of user groups associated with this document
|
||||
user_groups: set[str]
|
||||
|
||||
def to_acl(self) -> set[str]:
|
||||
return set(
|
||||
[
|
||||
prefix_user_email(user_email)
|
||||
for user_email in self.user_emails
|
||||
if user_email
|
||||
]
|
||||
+ [prefix_user_group(group_name) for group_name in self.user_groups]
|
||||
+ [
|
||||
prefix_user_email(user_email)
|
||||
for user_email in self.external_user_emails
|
||||
]
|
||||
+ [
|
||||
# The group names are already prefixed by the source type
|
||||
# This adds an additional prefix of "external_group:"
|
||||
prefix_external_group(group_name)
|
||||
for group_name in self.external_user_group_ids
|
||||
]
|
||||
+ ([PUBLIC_DOC_PAT] if self.is_public else [])
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls, user_ids: list[UUID | None], user_groups: list[str], is_public: bool
|
||||
cls,
|
||||
user_emails: list[str | None],
|
||||
user_groups: list[str],
|
||||
external_user_emails: list[str],
|
||||
external_user_group_ids: list[str],
|
||||
is_public: bool,
|
||||
) -> "DocumentAccess":
|
||||
return cls(
|
||||
user_ids={str(user_id) for user_id in user_ids if user_id},
|
||||
external_user_emails={
|
||||
prefix_user_email(external_email)
|
||||
for external_email in external_user_emails
|
||||
},
|
||||
external_user_group_ids={
|
||||
prefix_external_group(external_group_id)
|
||||
for external_group_id in external_user_group_ids
|
||||
},
|
||||
user_emails={
|
||||
prefix_user_email(user_email)
|
||||
for user_email in user_emails
|
||||
if user_email
|
||||
},
|
||||
user_groups=set(user_groups),
|
||||
is_public=is_public,
|
||||
)
|
||||
|
||||
@@ -1,10 +1,24 @@
|
||||
def prefix_user(user_id: str) -> str:
|
||||
"""Prefixes a user ID to eliminate collision with group names.
|
||||
This assumes that groups are prefixed with a different prefix."""
|
||||
return f"user_id:{user_id}"
|
||||
from danswer.configs.constants import DocumentSource
|
||||
|
||||
|
||||
def prefix_user_email(user_email: str) -> str:
|
||||
"""Prefixes a user email to eliminate collision with group names.
|
||||
This applies to both a Danswer user and an External user, this is to make the query time
|
||||
more efficient"""
|
||||
return f"user_email:{user_email}"
|
||||
|
||||
|
||||
def prefix_user_group(user_group_name: str) -> str:
|
||||
"""Prefixes a user group name to eliminate collision with user IDs.
|
||||
"""Prefixes a user group name to eliminate collision with user emails.
|
||||
This assumes that user ids are prefixed with a different prefix."""
|
||||
return f"group:{user_group_name}"
|
||||
|
||||
|
||||
def prefix_external_group(ext_group_name: str) -> str:
|
||||
"""Prefixes an external group name to eliminate collision with user emails / Danswer groups."""
|
||||
return f"external_group:{ext_group_name}"
|
||||
|
||||
|
||||
def prefix_group_w_source(ext_group_name: str, source: DocumentSource) -> str:
|
||||
"""External groups may collide across sources, every source needs its own prefix."""
|
||||
return f"{source.value.upper()}_{ext_group_name}"
|
||||
|
||||
@@ -1,20 +1,20 @@
|
||||
from typing import cast
|
||||
|
||||
from danswer.configs.constants import KV_USER_STORE_KEY
|
||||
from danswer.dynamic_configs.factory import get_dynamic_config_store
|
||||
from danswer.dynamic_configs.interface import ConfigNotFoundError
|
||||
from danswer.dynamic_configs.interface import JSON_ro
|
||||
from danswer.key_value_store.factory import get_kv_store
|
||||
from danswer.key_value_store.interface import JSON_ro
|
||||
from danswer.key_value_store.interface import KvKeyNotFoundError
|
||||
|
||||
|
||||
def get_invited_users() -> list[str]:
|
||||
try:
|
||||
store = get_dynamic_config_store()
|
||||
store = get_kv_store()
|
||||
return cast(list, store.load(KV_USER_STORE_KEY))
|
||||
except ConfigNotFoundError:
|
||||
except KvKeyNotFoundError:
|
||||
return list()
|
||||
|
||||
|
||||
def write_invited_users(emails: list[str]) -> int:
|
||||
store = get_dynamic_config_store()
|
||||
store = get_kv_store()
|
||||
store.store(KV_USER_STORE_KEY, cast(JSON_ro, emails))
|
||||
return len(emails)
|
||||
|
||||
@@ -4,29 +4,29 @@ from typing import cast
|
||||
|
||||
from danswer.auth.schemas import UserRole
|
||||
from danswer.configs.constants import KV_NO_AUTH_USER_PREFERENCES_KEY
|
||||
from danswer.dynamic_configs.store import ConfigNotFoundError
|
||||
from danswer.dynamic_configs.store import DynamicConfigStore
|
||||
from danswer.key_value_store.store import KeyValueStore
|
||||
from danswer.key_value_store.store import KvKeyNotFoundError
|
||||
from danswer.server.manage.models import UserInfo
|
||||
from danswer.server.manage.models import UserPreferences
|
||||
|
||||
|
||||
def set_no_auth_user_preferences(
|
||||
store: DynamicConfigStore, preferences: UserPreferences
|
||||
store: KeyValueStore, preferences: UserPreferences
|
||||
) -> None:
|
||||
store.store(KV_NO_AUTH_USER_PREFERENCES_KEY, preferences.model_dump())
|
||||
|
||||
|
||||
def load_no_auth_user_preferences(store: DynamicConfigStore) -> UserPreferences:
|
||||
def load_no_auth_user_preferences(store: KeyValueStore) -> UserPreferences:
|
||||
try:
|
||||
preferences_data = cast(
|
||||
Mapping[str, Any], store.load(KV_NO_AUTH_USER_PREFERENCES_KEY)
|
||||
)
|
||||
return UserPreferences(**preferences_data)
|
||||
except ConfigNotFoundError:
|
||||
except KvKeyNotFoundError:
|
||||
return UserPreferences(chosen_assistants=None, default_model=None)
|
||||
|
||||
|
||||
def fetch_no_auth_user(store: DynamicConfigStore) -> UserInfo:
|
||||
def fetch_no_auth_user(store: KeyValueStore) -> UserInfo:
|
||||
return UserInfo(
|
||||
id="__no_auth_user__",
|
||||
email="anonymous@danswer.ai",
|
||||
|
||||
@@ -34,6 +34,7 @@ class UserRead(schemas.BaseUser[uuid.UUID]):
|
||||
class UserCreate(schemas.BaseUserCreate):
|
||||
role: UserRole = UserRole.BASIC
|
||||
has_web_login: bool | None = True
|
||||
tenant_id: str | None = None
|
||||
|
||||
|
||||
class UserUpdate(schemas.BaseUserUpdate):
|
||||
|
||||
@@ -5,17 +5,23 @@ from datetime import datetime
|
||||
from datetime import timezone
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from email.mime.text import MIMEText
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
|
||||
import jwt
|
||||
from email_validator import EmailNotValidError
|
||||
from email_validator import EmailUndeliverableError
|
||||
from email_validator import validate_email
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Query
|
||||
from fastapi import Request
|
||||
from fastapi import Response
|
||||
from fastapi import status
|
||||
from fastapi.responses import RedirectResponse
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from fastapi_users import BaseUserManager
|
||||
from fastapi_users import exceptions
|
||||
@@ -25,11 +31,25 @@ from fastapi_users import schemas
|
||||
from fastapi_users import UUIDIDMixin
|
||||
from fastapi_users.authentication import AuthenticationBackend
|
||||
from fastapi_users.authentication import CookieTransport
|
||||
from fastapi_users.authentication import JWTStrategy
|
||||
from fastapi_users.authentication import Strategy
|
||||
from fastapi_users.authentication.strategy.db import AccessTokenDatabase
|
||||
from fastapi_users.authentication.strategy.db import DatabaseStrategy
|
||||
from fastapi_users.exceptions import UserAlreadyExists
|
||||
from fastapi_users.jwt import decode_jwt
|
||||
from fastapi_users.jwt import generate_jwt
|
||||
from fastapi_users.jwt import SecretType
|
||||
from fastapi_users.manager import UserManagerDependency
|
||||
from fastapi_users.openapi import OpenAPIResponseType
|
||||
from fastapi_users.router.common import ErrorCode
|
||||
from fastapi_users.router.common import ErrorModel
|
||||
from fastapi_users_db_sqlalchemy import SQLAlchemyUserDatabase
|
||||
from httpx_oauth.integrations.fastapi import OAuth2AuthorizeCallback
|
||||
from httpx_oauth.oauth2 import BaseOAuth2
|
||||
from httpx_oauth.oauth2 import OAuth2Token
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import attributes
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.invited_users import get_invited_users
|
||||
@@ -39,7 +59,9 @@ from danswer.auth.schemas import UserUpdate
|
||||
from danswer.configs.app_configs import AUTH_TYPE
|
||||
from danswer.configs.app_configs import DISABLE_AUTH
|
||||
from danswer.configs.app_configs import EMAIL_FROM
|
||||
from danswer.configs.app_configs import MULTI_TENANT
|
||||
from danswer.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
|
||||
from danswer.configs.app_configs import SECRET_JWT_KEY
|
||||
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
|
||||
from danswer.configs.app_configs import SMTP_PASS
|
||||
from danswer.configs.app_configs import SMTP_PORT
|
||||
@@ -57,15 +79,21 @@ from danswer.db.auth import get_access_token_db
|
||||
from danswer.db.auth import get_default_admin_user_emails
|
||||
from danswer.db.auth import get_user_count
|
||||
from danswer.db.auth import get_user_db
|
||||
from danswer.db.auth import SQLAlchemyUserAdminDB
|
||||
from danswer.db.engine import get_async_session_with_tenant
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.models import AccessToken
|
||||
from danswer.db.models import OAuthAccount
|
||||
from danswer.db.models import User
|
||||
from danswer.db.models import UserTenantMapping
|
||||
from danswer.db.users import get_user_by_email
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.telemetry import optional_telemetry
|
||||
from danswer.utils.telemetry import RecordType
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
from shared_configs.configs import current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -115,7 +143,10 @@ def verify_email_is_invited(email: str) -> None:
|
||||
if not email:
|
||||
raise PermissionError("Email must be specified")
|
||||
|
||||
email_info = validate_email(email) # can raise EmailNotValidError
|
||||
try:
|
||||
email_info = validate_email(email)
|
||||
except EmailUndeliverableError:
|
||||
raise PermissionError("Email is not valid")
|
||||
|
||||
for email_whitelist in whitelist:
|
||||
try:
|
||||
@@ -133,8 +164,8 @@ def verify_email_is_invited(email: str) -> None:
|
||||
raise PermissionError("User not on allowed user whitelist")
|
||||
|
||||
|
||||
def verify_email_in_whitelist(email: str) -> None:
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
def verify_email_in_whitelist(email: str, tenant_id: str | None = None) -> None:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
if not get_user_by_email(email, db_session):
|
||||
verify_email_is_invited(email)
|
||||
|
||||
@@ -154,6 +185,20 @@ def verify_email_domain(email: str) -> None:
|
||||
)
|
||||
|
||||
|
||||
def get_tenant_id_for_email(email: str) -> str:
|
||||
if not MULTI_TENANT:
|
||||
return "public"
|
||||
# Implement logic to get tenant_id from the mapping table
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
result = db_session.execute(
|
||||
select(UserTenantMapping.tenant_id).where(UserTenantMapping.email == email)
|
||||
)
|
||||
tenant_id = result.scalar_one_or_none()
|
||||
if tenant_id is None:
|
||||
raise exceptions.UserNotExists()
|
||||
return tenant_id
|
||||
|
||||
|
||||
def send_user_verification_email(
|
||||
user_email: str,
|
||||
token: str,
|
||||
@@ -188,35 +233,83 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
safe: bool = False,
|
||||
request: Optional[Request] = None,
|
||||
) -> User:
|
||||
verify_email_is_invited(user_create.email)
|
||||
verify_email_domain(user_create.email)
|
||||
if hasattr(user_create, "role"):
|
||||
user_count = await get_user_count()
|
||||
if user_count == 0 or user_create.email in get_default_admin_user_emails():
|
||||
user_create.role = UserRole.ADMIN
|
||||
else:
|
||||
user_create.role = UserRole.BASIC
|
||||
user = None
|
||||
try:
|
||||
user = await super().create(user_create, safe=safe, request=request) # type: ignore
|
||||
except exceptions.UserAlreadyExists:
|
||||
user = await self.get_by_email(user_create.email)
|
||||
# Handle case where user has used product outside of web and is now creating an account through web
|
||||
if (
|
||||
not user.has_web_login
|
||||
and hasattr(user_create, "has_web_login")
|
||||
and user_create.has_web_login
|
||||
):
|
||||
user_update = UserUpdate(
|
||||
password=user_create.password,
|
||||
has_web_login=True,
|
||||
role=user_create.role,
|
||||
is_verified=user_create.is_verified,
|
||||
)
|
||||
user = await self.update(user_update, user)
|
||||
else:
|
||||
raise exceptions.UserAlreadyExists()
|
||||
return user
|
||||
tenant_id = (
|
||||
get_tenant_id_for_email(user_create.email) if MULTI_TENANT else "public"
|
||||
)
|
||||
except exceptions.UserNotExists:
|
||||
raise HTTPException(status_code=401, detail="User not found")
|
||||
|
||||
if not tenant_id:
|
||||
raise HTTPException(
|
||||
status_code=401, detail="User does not belong to an organization"
|
||||
)
|
||||
|
||||
async with get_async_session_with_tenant(tenant_id) as db_session:
|
||||
token = current_tenant_id.set(tenant_id)
|
||||
|
||||
verify_email_is_invited(user_create.email)
|
||||
verify_email_domain(user_create.email)
|
||||
if MULTI_TENANT:
|
||||
tenant_user_db = SQLAlchemyUserAdminDB(db_session, User, OAuthAccount)
|
||||
self.user_db = tenant_user_db
|
||||
self.database = tenant_user_db
|
||||
|
||||
if hasattr(user_create, "role"):
|
||||
user_count = await get_user_count()
|
||||
if (
|
||||
user_count == 0
|
||||
or user_create.email in get_default_admin_user_emails()
|
||||
):
|
||||
user_create.role = UserRole.ADMIN
|
||||
else:
|
||||
user_create.role = UserRole.BASIC
|
||||
user = None
|
||||
try:
|
||||
user = await super().create(user_create, safe=safe, request=request) # type: ignore
|
||||
except exceptions.UserAlreadyExists:
|
||||
user = await self.get_by_email(user_create.email)
|
||||
# Handle case where user has used product outside of web and is now creating an account through web
|
||||
if (
|
||||
not user.has_web_login
|
||||
and hasattr(user_create, "has_web_login")
|
||||
and user_create.has_web_login
|
||||
):
|
||||
user_update = UserUpdate(
|
||||
password=user_create.password,
|
||||
has_web_login=True,
|
||||
role=user_create.role,
|
||||
is_verified=user_create.is_verified,
|
||||
)
|
||||
user = await self.update(user_update, user)
|
||||
else:
|
||||
raise exceptions.UserAlreadyExists()
|
||||
|
||||
current_tenant_id.reset(token)
|
||||
return user
|
||||
|
||||
async def on_after_login(
|
||||
self,
|
||||
user: User,
|
||||
request: Request | None = None,
|
||||
response: Response | None = None,
|
||||
) -> None:
|
||||
if response is None or not MULTI_TENANT:
|
||||
return
|
||||
|
||||
tenant_id = get_tenant_id_for_email(user.email)
|
||||
|
||||
tenant_token = jwt.encode(
|
||||
{"tenant_id": tenant_id}, SECRET_JWT_KEY, algorithm="HS256"
|
||||
)
|
||||
|
||||
response.set_cookie(
|
||||
key="tenant_details",
|
||||
value=tenant_token,
|
||||
httponly=True,
|
||||
secure=WEB_DOMAIN.startswith("https"),
|
||||
samesite="lax",
|
||||
)
|
||||
|
||||
async def oauth_callback(
|
||||
self: "BaseUserManager[models.UOAP, models.ID]",
|
||||
@@ -231,45 +324,111 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
associate_by_email: bool = False,
|
||||
is_verified_by_default: bool = False,
|
||||
) -> models.UOAP:
|
||||
verify_email_in_whitelist(account_email)
|
||||
verify_email_domain(account_email)
|
||||
|
||||
user = await super().oauth_callback( # type: ignore
|
||||
oauth_name=oauth_name,
|
||||
access_token=access_token,
|
||||
account_id=account_id,
|
||||
account_email=account_email,
|
||||
expires_at=expires_at,
|
||||
refresh_token=refresh_token,
|
||||
request=request,
|
||||
associate_by_email=associate_by_email,
|
||||
is_verified_by_default=is_verified_by_default,
|
||||
)
|
||||
|
||||
# NOTE: Most IdPs have very short expiry times, and we don't want to force the user to
|
||||
# re-authenticate that frequently, so by default this is disabled
|
||||
if expires_at and TRACK_EXTERNAL_IDP_EXPIRY:
|
||||
oidc_expiry = datetime.fromtimestamp(expires_at, tz=timezone.utc)
|
||||
await self.user_db.update(user, update_dict={"oidc_expiry": oidc_expiry})
|
||||
|
||||
# this is needed if an organization goes from `TRACK_EXTERNAL_IDP_EXPIRY=true` to `false`
|
||||
# otherwise, the oidc expiry will always be old, and the user will never be able to login
|
||||
if user.oidc_expiry and not TRACK_EXTERNAL_IDP_EXPIRY:
|
||||
await self.user_db.update(user, update_dict={"oidc_expiry": None})
|
||||
|
||||
# Handle case where user has used product outside of web and is now creating an account through web
|
||||
if not user.has_web_login:
|
||||
await self.user_db.update(
|
||||
user,
|
||||
update_dict={
|
||||
"is_verified": is_verified_by_default,
|
||||
"has_web_login": True,
|
||||
},
|
||||
# Get tenant_id from mapping table
|
||||
try:
|
||||
tenant_id = (
|
||||
get_tenant_id_for_email(account_email) if MULTI_TENANT else "public"
|
||||
)
|
||||
user.is_verified = is_verified_by_default
|
||||
user.has_web_login = True
|
||||
except exceptions.UserNotExists:
|
||||
raise HTTPException(status_code=401, detail="User not found")
|
||||
|
||||
return user
|
||||
if not tenant_id:
|
||||
raise HTTPException(status_code=401, detail="User not found")
|
||||
|
||||
token = None
|
||||
async with get_async_session_with_tenant(tenant_id) as db_session:
|
||||
token = current_tenant_id.set(tenant_id)
|
||||
|
||||
verify_email_in_whitelist(account_email, tenant_id)
|
||||
verify_email_domain(account_email)
|
||||
if MULTI_TENANT:
|
||||
tenant_user_db = SQLAlchemyUserAdminDB(db_session, User, OAuthAccount)
|
||||
self.user_db = tenant_user_db
|
||||
self.database = tenant_user_db # type: ignore
|
||||
|
||||
oauth_account_dict = {
|
||||
"oauth_name": oauth_name,
|
||||
"access_token": access_token,
|
||||
"account_id": account_id,
|
||||
"account_email": account_email,
|
||||
"expires_at": expires_at,
|
||||
"refresh_token": refresh_token,
|
||||
}
|
||||
|
||||
try:
|
||||
# Attempt to get user by OAuth account
|
||||
user = await self.get_by_oauth_account(oauth_name, account_id)
|
||||
|
||||
except exceptions.UserNotExists:
|
||||
try:
|
||||
# Attempt to get user by email
|
||||
user = await self.get_by_email(account_email)
|
||||
if not associate_by_email:
|
||||
raise exceptions.UserAlreadyExists()
|
||||
|
||||
user = await self.user_db.add_oauth_account(
|
||||
user, oauth_account_dict
|
||||
)
|
||||
|
||||
# If user not found by OAuth account or email, create a new user
|
||||
except exceptions.UserNotExists:
|
||||
password = self.password_helper.generate()
|
||||
user_dict = {
|
||||
"email": account_email,
|
||||
"hashed_password": self.password_helper.hash(password),
|
||||
"is_verified": is_verified_by_default,
|
||||
}
|
||||
|
||||
user = await self.user_db.create(user_dict)
|
||||
user = await self.user_db.add_oauth_account(
|
||||
user, oauth_account_dict
|
||||
)
|
||||
await self.on_after_register(user, request)
|
||||
|
||||
else:
|
||||
for existing_oauth_account in user.oauth_accounts:
|
||||
if (
|
||||
existing_oauth_account.account_id == account_id
|
||||
and existing_oauth_account.oauth_name == oauth_name
|
||||
):
|
||||
user = await self.user_db.update_oauth_account(
|
||||
user, existing_oauth_account, oauth_account_dict
|
||||
)
|
||||
|
||||
# NOTE: Most IdPs have very short expiry times, and we don't want to force the user to
|
||||
# re-authenticate that frequently, so by default this is disabled
|
||||
|
||||
if expires_at and TRACK_EXTERNAL_IDP_EXPIRY:
|
||||
oidc_expiry = datetime.fromtimestamp(expires_at, tz=timezone.utc)
|
||||
await self.user_db.update(
|
||||
user, update_dict={"oidc_expiry": oidc_expiry}
|
||||
)
|
||||
|
||||
# Handle case where user has used product outside of web and is now creating an account through web
|
||||
if not user.has_web_login: # type: ignore
|
||||
await self.user_db.update(
|
||||
user,
|
||||
{
|
||||
"is_verified": is_verified_by_default,
|
||||
"has_web_login": True,
|
||||
},
|
||||
)
|
||||
user.is_verified = is_verified_by_default
|
||||
user.has_web_login = True # type: ignore
|
||||
|
||||
# this is needed if an organization goes from `TRACK_EXTERNAL_IDP_EXPIRY=true` to `false`
|
||||
# otherwise, the oidc expiry will always be old, and the user will never be able to login
|
||||
if (
|
||||
user.oidc_expiry is not None # type: ignore
|
||||
and not TRACK_EXTERNAL_IDP_EXPIRY
|
||||
):
|
||||
await self.user_db.update(user, {"oidc_expiry": None})
|
||||
user.oidc_expiry = None # type: ignore
|
||||
|
||||
if token:
|
||||
current_tenant_id.reset(token)
|
||||
|
||||
return user
|
||||
|
||||
async def on_after_register(
|
||||
self, user: User, request: Optional[Request] = None
|
||||
@@ -300,18 +459,50 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
async def authenticate(
|
||||
self, credentials: OAuth2PasswordRequestForm
|
||||
) -> Optional[User]:
|
||||
user = await super().authenticate(credentials)
|
||||
if user is None:
|
||||
email = credentials.username
|
||||
|
||||
# Get tenant_id from mapping table
|
||||
tenant_id = get_tenant_id_for_email(email)
|
||||
if not tenant_id:
|
||||
# User not found in mapping
|
||||
self.password_helper.hash(credentials.password)
|
||||
return None
|
||||
|
||||
# Create a tenant-specific session
|
||||
async with get_async_session_with_tenant(tenant_id) as tenant_session:
|
||||
tenant_user_db: SQLAlchemyUserDatabase = SQLAlchemyUserDatabase(
|
||||
tenant_session, User
|
||||
)
|
||||
self.user_db = tenant_user_db
|
||||
|
||||
# Proceed with authentication
|
||||
try:
|
||||
user = await self.get_by_email(credentials.username)
|
||||
if not user.has_web_login:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="NO_WEB_LOGIN_AND_HAS_NO_PASSWORD",
|
||||
)
|
||||
user = await self.get_by_email(email)
|
||||
|
||||
except exceptions.UserNotExists:
|
||||
pass
|
||||
return user
|
||||
self.password_helper.hash(credentials.password)
|
||||
return None
|
||||
|
||||
has_web_login = attributes.get_attribute(user, "has_web_login")
|
||||
|
||||
if not has_web_login:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="NO_WEB_LOGIN_AND_HAS_NO_PASSWORD",
|
||||
)
|
||||
|
||||
verified, updated_password_hash = self.password_helper.verify_and_update(
|
||||
credentials.password, user.hashed_password
|
||||
)
|
||||
if not verified:
|
||||
return None
|
||||
|
||||
if updated_password_hash is not None:
|
||||
await self.user_db.update(
|
||||
user, {"hashed_password": updated_password_hash}
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
async def get_user_manager(
|
||||
@@ -326,21 +517,26 @@ cookie_transport = CookieTransport(
|
||||
)
|
||||
|
||||
|
||||
def get_jwt_strategy() -> JWTStrategy:
|
||||
return JWTStrategy(
|
||||
secret=USER_AUTH_SECRET,
|
||||
lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS,
|
||||
)
|
||||
|
||||
|
||||
def get_database_strategy(
|
||||
access_token_db: AccessTokenDatabase[AccessToken] = Depends(get_access_token_db),
|
||||
) -> DatabaseStrategy:
|
||||
strategy = DatabaseStrategy(
|
||||
return DatabaseStrategy(
|
||||
access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS # type: ignore
|
||||
)
|
||||
|
||||
return strategy
|
||||
|
||||
|
||||
auth_backend = AuthenticationBackend(
|
||||
name="database",
|
||||
name="jwt" if MULTI_TENANT else "database",
|
||||
transport=cookie_transport,
|
||||
get_strategy=get_database_strategy,
|
||||
)
|
||||
get_strategy=get_jwt_strategy if MULTI_TENANT else get_database_strategy, # type: ignore
|
||||
) # type: ignore
|
||||
|
||||
|
||||
class FastAPIUserWithLogoutRouter(FastAPIUsers[models.UP, models.ID]):
|
||||
@@ -354,9 +550,11 @@ class FastAPIUserWithLogoutRouter(FastAPIUsers[models.UP, models.ID]):
|
||||
This way the login router does not need to be included
|
||||
"""
|
||||
router = APIRouter()
|
||||
|
||||
get_current_user_token = self.authenticator.current_user_token(
|
||||
active=True, verified=requires_verification
|
||||
)
|
||||
|
||||
logout_responses: OpenAPIResponseType = {
|
||||
**{
|
||||
status.HTTP_401_UNAUTHORIZED: {
|
||||
@@ -403,8 +601,8 @@ async def optional_user_(
|
||||
|
||||
async def optional_user(
|
||||
request: Request,
|
||||
user: User | None = Depends(optional_fastapi_current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
user: User | None = Depends(optional_fastapi_current_user),
|
||||
) -> User | None:
|
||||
versioned_fetch_user = fetch_versioned_implementation(
|
||||
"danswer.auth.users", "optional_user_"
|
||||
@@ -495,3 +693,186 @@ async def current_admin_user(user: User | None = Depends(current_user)) -> User
|
||||
def get_default_admin_user_emails_() -> list[str]:
|
||||
# No default seeding available for Danswer MIT
|
||||
return []
|
||||
|
||||
|
||||
STATE_TOKEN_AUDIENCE = "fastapi-users:oauth-state"
|
||||
|
||||
|
||||
class OAuth2AuthorizeResponse(BaseModel):
|
||||
authorization_url: str
|
||||
|
||||
|
||||
def generate_state_token(
|
||||
data: Dict[str, str], secret: SecretType, lifetime_seconds: int = 3600
|
||||
) -> str:
|
||||
data["aud"] = STATE_TOKEN_AUDIENCE
|
||||
|
||||
return generate_jwt(data, secret, lifetime_seconds)
|
||||
|
||||
|
||||
# refer to https://github.com/fastapi-users/fastapi-users/blob/42ddc241b965475390e2bce887b084152ae1a2cd/fastapi_users/fastapi_users.py#L91
|
||||
|
||||
|
||||
def create_danswer_oauth_router(
|
||||
oauth_client: BaseOAuth2,
|
||||
backend: AuthenticationBackend,
|
||||
state_secret: SecretType,
|
||||
redirect_url: Optional[str] = None,
|
||||
associate_by_email: bool = False,
|
||||
is_verified_by_default: bool = False,
|
||||
) -> APIRouter:
|
||||
return get_oauth_router(
|
||||
oauth_client,
|
||||
backend,
|
||||
get_user_manager,
|
||||
state_secret,
|
||||
redirect_url,
|
||||
associate_by_email,
|
||||
is_verified_by_default,
|
||||
)
|
||||
|
||||
|
||||
def get_oauth_router(
|
||||
oauth_client: BaseOAuth2,
|
||||
backend: AuthenticationBackend,
|
||||
get_user_manager: UserManagerDependency[models.UP, models.ID],
|
||||
state_secret: SecretType,
|
||||
redirect_url: Optional[str] = None,
|
||||
associate_by_email: bool = False,
|
||||
is_verified_by_default: bool = False,
|
||||
) -> APIRouter:
|
||||
"""Generate a router with the OAuth routes."""
|
||||
router = APIRouter()
|
||||
callback_route_name = f"oauth:{oauth_client.name}.{backend.name}.callback"
|
||||
|
||||
if redirect_url is not None:
|
||||
oauth2_authorize_callback = OAuth2AuthorizeCallback(
|
||||
oauth_client,
|
||||
redirect_url=redirect_url,
|
||||
)
|
||||
else:
|
||||
oauth2_authorize_callback = OAuth2AuthorizeCallback(
|
||||
oauth_client,
|
||||
route_name=callback_route_name,
|
||||
)
|
||||
|
||||
@router.get(
|
||||
"/authorize",
|
||||
name=f"oauth:{oauth_client.name}.{backend.name}.authorize",
|
||||
response_model=OAuth2AuthorizeResponse,
|
||||
)
|
||||
async def authorize(
|
||||
request: Request, scopes: List[str] = Query(None)
|
||||
) -> OAuth2AuthorizeResponse:
|
||||
if redirect_url is not None:
|
||||
authorize_redirect_url = redirect_url
|
||||
else:
|
||||
authorize_redirect_url = str(request.url_for(callback_route_name))
|
||||
|
||||
next_url = request.query_params.get("next", "/")
|
||||
state_data: Dict[str, str] = {"next_url": next_url}
|
||||
state = generate_state_token(state_data, state_secret)
|
||||
authorization_url = await oauth_client.get_authorization_url(
|
||||
authorize_redirect_url,
|
||||
state,
|
||||
scopes,
|
||||
)
|
||||
|
||||
return OAuth2AuthorizeResponse(authorization_url=authorization_url)
|
||||
|
||||
@router.get(
|
||||
"/callback",
|
||||
name=callback_route_name,
|
||||
description="The response varies based on the authentication backend used.",
|
||||
responses={
|
||||
status.HTTP_400_BAD_REQUEST: {
|
||||
"model": ErrorModel,
|
||||
"content": {
|
||||
"application/json": {
|
||||
"examples": {
|
||||
"INVALID_STATE_TOKEN": {
|
||||
"summary": "Invalid state token.",
|
||||
"value": None,
|
||||
},
|
||||
ErrorCode.LOGIN_BAD_CREDENTIALS: {
|
||||
"summary": "User is inactive.",
|
||||
"value": {"detail": ErrorCode.LOGIN_BAD_CREDENTIALS},
|
||||
},
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
async def callback(
|
||||
request: Request,
|
||||
access_token_state: Tuple[OAuth2Token, str] = Depends(
|
||||
oauth2_authorize_callback
|
||||
),
|
||||
user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager),
|
||||
strategy: Strategy[models.UP, models.ID] = Depends(backend.get_strategy),
|
||||
) -> RedirectResponse:
|
||||
token, state = access_token_state
|
||||
account_id, account_email = await oauth_client.get_id_email(
|
||||
token["access_token"]
|
||||
)
|
||||
|
||||
if account_email is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ErrorCode.OAUTH_NOT_AVAILABLE_EMAIL,
|
||||
)
|
||||
|
||||
try:
|
||||
state_data = decode_jwt(state, state_secret, [STATE_TOKEN_AUDIENCE])
|
||||
except jwt.DecodeError:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
next_url = state_data.get("next_url", "/")
|
||||
|
||||
# Authenticate user
|
||||
try:
|
||||
user = await user_manager.oauth_callback(
|
||||
oauth_client.name,
|
||||
token["access_token"],
|
||||
account_id,
|
||||
account_email,
|
||||
token.get("expires_at"),
|
||||
token.get("refresh_token"),
|
||||
request,
|
||||
associate_by_email=associate_by_email,
|
||||
is_verified_by_default=is_verified_by_default,
|
||||
)
|
||||
except UserAlreadyExists:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ErrorCode.OAUTH_USER_ALREADY_EXISTS,
|
||||
)
|
||||
|
||||
if not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ErrorCode.LOGIN_BAD_CREDENTIALS,
|
||||
)
|
||||
|
||||
# Login user
|
||||
response = await backend.login(strategy, user)
|
||||
await user_manager.on_after_login(user, request, response)
|
||||
|
||||
# Prepare redirect response
|
||||
redirect_response = RedirectResponse(next_url, status_code=302)
|
||||
|
||||
# Copy headers and other attributes from 'response' to 'redirect_response'
|
||||
for header_name, header_value in response.headers.items():
|
||||
redirect_response.headers[header_name] = header_value
|
||||
|
||||
if hasattr(response, "body"):
|
||||
redirect_response.body = response.body
|
||||
if hasattr(response, "status_code"):
|
||||
redirect_response.status_code = response.status_code
|
||||
if hasattr(response, "media_type"):
|
||||
redirect_response.media_type = response.media_type
|
||||
|
||||
return redirect_response
|
||||
|
||||
return router
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -15,11 +15,13 @@ from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.document import construct_document_select_for_connector_credential_pair
|
||||
from danswer.db.document import (
|
||||
construct_document_select_for_connector_credential_pair_by_needs_sync,
|
||||
)
|
||||
from danswer.db.document_set import construct_document_select_by_docset
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
from danswer.utils.variable_functionality import global_version
|
||||
|
||||
|
||||
class RedisObjectHelper(ABC):
|
||||
@@ -27,8 +29,8 @@ class RedisObjectHelper(ABC):
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def __init__(self, id: int):
|
||||
self._id: int = id
|
||||
def __init__(self, id: str):
|
||||
self._id: str = id
|
||||
|
||||
@property
|
||||
def task_id_prefix(self) -> str:
|
||||
@@ -45,7 +47,7 @@ class RedisObjectHelper(ABC):
|
||||
return f"{self.TASKSET_PREFIX}_{self._id}"
|
||||
|
||||
@staticmethod
|
||||
def get_id_from_fence_key(key: str) -> int | None:
|
||||
def get_id_from_fence_key(key: str) -> str | None:
|
||||
"""
|
||||
Extracts the object ID from a fence key in the format `PREFIX_fence_X`.
|
||||
|
||||
@@ -59,15 +61,11 @@ class RedisObjectHelper(ABC):
|
||||
if len(parts) != 3:
|
||||
return None
|
||||
|
||||
try:
|
||||
object_id = int(parts[2])
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
object_id = parts[2]
|
||||
return object_id
|
||||
|
||||
@staticmethod
|
||||
def get_id_from_task_id(task_id: str) -> int | None:
|
||||
def get_id_from_task_id(task_id: str) -> str | None:
|
||||
"""
|
||||
Extracts the object ID from a task ID string.
|
||||
|
||||
@@ -91,11 +89,7 @@ class RedisObjectHelper(ABC):
|
||||
if len(parts) != 3:
|
||||
return None
|
||||
|
||||
try:
|
||||
object_id = int(parts[1])
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
object_id = parts[1]
|
||||
return object_id
|
||||
|
||||
@abstractmethod
|
||||
@@ -105,6 +99,7 @@ class RedisObjectHelper(ABC):
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
pass
|
||||
|
||||
@@ -114,17 +109,21 @@ class RedisDocumentSet(RedisObjectHelper):
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def __init__(self, id: int) -> None:
|
||||
super().__init__(str(id))
|
||||
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
last_lock_time = time.monotonic()
|
||||
|
||||
async_results = []
|
||||
stmt = construct_document_select_by_docset(self._id, current_only=False)
|
||||
stmt = construct_document_select_by_docset(int(self._id), current_only=False)
|
||||
for doc in db_session.scalars(stmt).yield_per(1):
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (
|
||||
@@ -134,7 +133,7 @@ class RedisDocumentSet(RedisObjectHelper):
|
||||
last_lock_time = current_time
|
||||
|
||||
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# we prefix the task id so it's easier to keep track of who created the task
|
||||
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
|
||||
@@ -144,7 +143,7 @@ class RedisDocumentSet(RedisObjectHelper):
|
||||
|
||||
result = celery_app.send_task(
|
||||
"vespa_metadata_sync_task",
|
||||
kwargs=dict(document_id=doc.id),
|
||||
kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
|
||||
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.LOW,
|
||||
@@ -160,17 +159,24 @@ class RedisUserGroup(RedisObjectHelper):
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def __init__(self, id: int) -> None:
|
||||
super().__init__(str(id))
|
||||
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
last_lock_time = time.monotonic()
|
||||
|
||||
async_results = []
|
||||
|
||||
if not global_version.is_ee_version():
|
||||
return 0
|
||||
|
||||
try:
|
||||
construct_document_select_by_usergroup = fetch_versioned_implementation(
|
||||
"danswer.db.user_group",
|
||||
@@ -179,7 +185,7 @@ class RedisUserGroup(RedisObjectHelper):
|
||||
except ModuleNotFoundError:
|
||||
return 0
|
||||
|
||||
stmt = construct_document_select_by_usergroup(self._id)
|
||||
stmt = construct_document_select_by_usergroup(int(self._id))
|
||||
for doc in db_session.scalars(stmt).yield_per(1):
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (
|
||||
@@ -189,7 +195,7 @@ class RedisUserGroup(RedisObjectHelper):
|
||||
last_lock_time = current_time
|
||||
|
||||
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# we prefix the task id so it's easier to keep track of who created the task
|
||||
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
|
||||
@@ -199,7 +205,7 @@ class RedisUserGroup(RedisObjectHelper):
|
||||
|
||||
result = celery_app.send_task(
|
||||
"vespa_metadata_sync_task",
|
||||
kwargs=dict(document_id=doc.id),
|
||||
kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
|
||||
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.LOW,
|
||||
@@ -211,10 +217,19 @@ class RedisUserGroup(RedisObjectHelper):
|
||||
|
||||
|
||||
class RedisConnectorCredentialPair(RedisObjectHelper):
|
||||
"""This class is used to scan documents by cc_pair in the db and collect them into
|
||||
a unified set for syncing.
|
||||
|
||||
It differs from the other redis helpers in that the taskset used spans
|
||||
all connectors and is not per connector."""
|
||||
|
||||
PREFIX = "connectorsync"
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def __init__(self, id: int) -> None:
|
||||
super().__init__(str(id))
|
||||
|
||||
@classmethod
|
||||
def get_fence_key(cls) -> str:
|
||||
return RedisConnectorCredentialPair.FENCE_PREFIX
|
||||
@@ -236,17 +251,78 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
last_lock_time = time.monotonic()
|
||||
|
||||
async_results = []
|
||||
cc_pair = get_connector_credential_pair_from_id(self._id, db_session)
|
||||
cc_pair = get_connector_credential_pair_from_id(int(self._id), db_session)
|
||||
if not cc_pair:
|
||||
return None
|
||||
|
||||
stmt = construct_document_select_for_connector_credential_pair_by_needs_sync(
|
||||
cc_pair.connector_id, cc_pair.credential_id
|
||||
)
|
||||
for doc in db_session.scalars(stmt).yield_per(1):
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
|
||||
):
|
||||
lock.reacquire()
|
||||
last_lock_time = current_time
|
||||
|
||||
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# we prefix the task id so it's easier to keep track of who created the task
|
||||
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
|
||||
|
||||
# add to the tracking taskset in redis BEFORE creating the celery task.
|
||||
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
|
||||
redis_client.sadd(
|
||||
RedisConnectorCredentialPair.get_taskset_key(), custom_task_id
|
||||
)
|
||||
|
||||
# Priority on sync's triggered by new indexing should be medium
|
||||
result = celery_app.send_task(
|
||||
"vespa_metadata_sync_task",
|
||||
kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
|
||||
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.MEDIUM,
|
||||
)
|
||||
|
||||
async_results.append(result)
|
||||
|
||||
return len(async_results)
|
||||
|
||||
|
||||
class RedisConnectorDeletion(RedisObjectHelper):
|
||||
PREFIX = "connectordeletion"
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def __init__(self, id: int) -> None:
|
||||
super().__init__(str(id))
|
||||
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
last_lock_time = time.monotonic()
|
||||
|
||||
async_results = []
|
||||
cc_pair = get_connector_credential_pair_from_id(int(self._id), db_session)
|
||||
if not cc_pair:
|
||||
return None
|
||||
|
||||
stmt = construct_document_select_for_connector_credential_pair(
|
||||
cc_pair.connector_id, cc_pair.credential_id
|
||||
)
|
||||
for doc in db_session.scalars(stmt).yield_per(1):
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (
|
||||
@@ -263,15 +339,18 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
|
||||
|
||||
# add to the tracking taskset in redis BEFORE creating the celery task.
|
||||
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
|
||||
redis_client.sadd(
|
||||
RedisConnectorCredentialPair.get_taskset_key(), custom_task_id
|
||||
)
|
||||
redis_client.sadd(self.taskset_key, custom_task_id)
|
||||
|
||||
# Priority on sync's triggered by new indexing should be medium
|
||||
result = celery_app.send_task(
|
||||
"vespa_metadata_sync_task",
|
||||
kwargs=dict(document_id=doc.id),
|
||||
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
|
||||
"document_by_cc_pair_cleanup_task",
|
||||
kwargs=dict(
|
||||
document_id=doc.id,
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
tenant_id=tenant_id,
|
||||
),
|
||||
queue=DanswerCeleryQueues.CONNECTOR_DELETION,
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.MEDIUM,
|
||||
)
|
||||
@@ -281,6 +360,185 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
|
||||
return len(async_results)
|
||||
|
||||
|
||||
class RedisConnectorPruning(RedisObjectHelper):
|
||||
"""Celery will kick off a long running generator task to crawl the connector and
|
||||
find any missing docs, which will each then get a new cleanup task. The progress of
|
||||
those tasks will then be monitored to completion.
|
||||
|
||||
Example rough happy path order:
|
||||
Check connectorpruning_fence_1
|
||||
Send generator task with id connectorpruning+generator_1_{uuid}
|
||||
|
||||
generator runs connector with callbacks that increment connectorpruning_generator_progress_1
|
||||
generator creates many subtasks with id connectorpruning+sub_1_{uuid}
|
||||
in taskset connectorpruning_taskset_1
|
||||
on completion, generator sets connectorpruning_generator_complete_1
|
||||
|
||||
celery postrun removes subtasks from taskset
|
||||
monitor beat task cleans up when taskset reaches 0 items
|
||||
"""
|
||||
|
||||
PREFIX = "connectorpruning"
|
||||
FENCE_PREFIX = PREFIX + "_fence" # a fence for the entire pruning process
|
||||
GENERATOR_TASK_PREFIX = PREFIX + "+generator"
|
||||
|
||||
TASKSET_PREFIX = PREFIX + "_taskset" # stores a list of prune tasks id's
|
||||
SUBTASK_PREFIX = PREFIX + "+sub"
|
||||
|
||||
GENERATOR_PROGRESS_PREFIX = (
|
||||
PREFIX + "_generator_progress"
|
||||
) # a signal that contains generator progress
|
||||
GENERATOR_COMPLETE_PREFIX = (
|
||||
PREFIX + "_generator_complete"
|
||||
) # a signal that the generator has finished
|
||||
|
||||
def __init__(self, id: int) -> None:
|
||||
super().__init__(str(id))
|
||||
self.documents_to_prune: set[str] = set()
|
||||
|
||||
@property
|
||||
def generator_task_id_prefix(self) -> str:
|
||||
return f"{self.GENERATOR_TASK_PREFIX}_{self._id}"
|
||||
|
||||
@property
|
||||
def generator_progress_key(self) -> str:
|
||||
# example: connectorpruning_generator_progress_1
|
||||
return f"{self.GENERATOR_PROGRESS_PREFIX}_{self._id}"
|
||||
|
||||
@property
|
||||
def generator_complete_key(self) -> str:
|
||||
# example: connectorpruning_generator_complete_1
|
||||
return f"{self.GENERATOR_COMPLETE_PREFIX}_{self._id}"
|
||||
|
||||
@property
|
||||
def subtask_id_prefix(self) -> str:
|
||||
return f"{self.SUBTASK_PREFIX}_{self._id}"
|
||||
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock | None,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
last_lock_time = time.monotonic()
|
||||
|
||||
async_results = []
|
||||
cc_pair = get_connector_credential_pair_from_id(int(self._id), db_session)
|
||||
if not cc_pair:
|
||||
return None
|
||||
|
||||
for doc_id in self.documents_to_prune:
|
||||
current_time = time.monotonic()
|
||||
if lock and current_time - last_lock_time >= (
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
|
||||
):
|
||||
lock.reacquire()
|
||||
last_lock_time = current_time
|
||||
|
||||
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# we prefix the task id so it's easier to keep track of who created the task
|
||||
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
custom_task_id = f"{self.subtask_id_prefix}_{uuid4()}"
|
||||
|
||||
# add to the tracking taskset in redis BEFORE creating the celery task.
|
||||
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
|
||||
redis_client.sadd(self.taskset_key, custom_task_id)
|
||||
|
||||
# Priority on sync's triggered by new indexing should be medium
|
||||
result = celery_app.send_task(
|
||||
"document_by_cc_pair_cleanup_task",
|
||||
kwargs=dict(
|
||||
document_id=doc_id,
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
tenant_id=tenant_id,
|
||||
),
|
||||
queue=DanswerCeleryQueues.CONNECTOR_DELETION,
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.MEDIUM,
|
||||
)
|
||||
|
||||
async_results.append(result)
|
||||
|
||||
return len(async_results)
|
||||
|
||||
def is_pruning(self, db_session: Session, redis_client: Redis) -> bool:
|
||||
"""A single example of a helper method being refactored into the redis helper"""
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
cc_pair_id=int(self._id), db_session=db_session
|
||||
)
|
||||
if not cc_pair:
|
||||
raise ValueError(f"cc_pair_id {self._id} does not exist.")
|
||||
|
||||
if redis_client.exists(self.fence_key):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class RedisConnectorIndexing(RedisObjectHelper):
|
||||
"""Celery will kick off a long running indexing task to crawl the connector and
|
||||
find any new or updated docs docs, which will each then get a new sync task or be
|
||||
indexed inline.
|
||||
|
||||
ID should be a concatenation of cc_pair_id and search_setting_id, delimited by "/".
|
||||
e.g. "2/5"
|
||||
"""
|
||||
|
||||
PREFIX = "connectorindexing"
|
||||
FENCE_PREFIX = PREFIX + "_fence" # a fence for the entire indexing process
|
||||
GENERATOR_TASK_PREFIX = PREFIX + "+generator"
|
||||
|
||||
TASKSET_PREFIX = PREFIX + "_taskset" # stores a list of prune tasks id's
|
||||
SUBTASK_PREFIX = PREFIX + "+sub"
|
||||
|
||||
GENERATOR_LOCK_PREFIX = "da_lock:indexing"
|
||||
GENERATOR_PROGRESS_PREFIX = (
|
||||
PREFIX + "_generator_progress"
|
||||
) # a signal that contains generator progress
|
||||
GENERATOR_COMPLETE_PREFIX = (
|
||||
PREFIX + "_generator_complete"
|
||||
) # a signal that the generator has finished
|
||||
|
||||
def __init__(self, cc_pair_id: int, search_settings_id: int) -> None:
|
||||
super().__init__(f"{cc_pair_id}/{search_settings_id}")
|
||||
|
||||
@property
|
||||
def generator_lock_key(self) -> str:
|
||||
return f"{self.GENERATOR_LOCK_PREFIX}_{self._id}"
|
||||
|
||||
@property
|
||||
def generator_task_id_prefix(self) -> str:
|
||||
return f"{self.GENERATOR_TASK_PREFIX}_{self._id}"
|
||||
|
||||
@property
|
||||
def generator_progress_key(self) -> str:
|
||||
# example: connectorpruning_generator_progress_1
|
||||
return f"{self.GENERATOR_PROGRESS_PREFIX}_{self._id}"
|
||||
|
||||
@property
|
||||
def generator_complete_key(self) -> str:
|
||||
# example: connectorpruning_generator_complete_1
|
||||
return f"{self.GENERATOR_COMPLETE_PREFIX}_{self._id}"
|
||||
|
||||
@property
|
||||
def subtask_id_prefix(self) -> str:
|
||||
return f"{self.SUBTASK_PREFIX}_{self._id}"
|
||||
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock | None,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
return None
|
||||
|
||||
|
||||
def celery_get_queue_length(queue: str, r: Redis) -> int:
|
||||
"""This is a redis specific way to get the length of a celery queue.
|
||||
It is priority aware and knows how to count across the multiple redis lists
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.task_utils import name_cc_cleanup_task
|
||||
from danswer.background.task_utils import name_cc_prune_task
|
||||
from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
|
||||
from danswer.background.celery.celery_redis import RedisConnectorDeletion
|
||||
from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
|
||||
from danswer.configs.app_configs import MULTI_TENANT
|
||||
from danswer.configs.constants import TENANT_ID_PREFIX
|
||||
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||
rate_limit_builder,
|
||||
)
|
||||
@@ -15,35 +18,53 @@ from danswer.connectors.interfaces import IdConnector
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
|
||||
from danswer.db.engine import get_db_current_time
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.models import Connector
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.models import Credential
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.enums import TaskStatus
|
||||
from danswer.db.models import TaskQueueState
|
||||
from danswer.db.tasks import check_task_is_live_and_not_timed_out
|
||||
from danswer.db.tasks import get_latest_task
|
||||
from danswer.db.tasks import get_latest_task_by_type
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.server.documents.models import DeletionAttemptSnapshot
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _get_deletion_status(
|
||||
connector_id: int, credential_id: int, db_session: Session
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
db_session: Session,
|
||||
tenant_id: str | None = None,
|
||||
) -> TaskQueueState | None:
|
||||
cleanup_task_name = name_cc_cleanup_task(
|
||||
connector_id=connector_id, credential_id=credential_id
|
||||
"""We no longer store TaskQueueState in the DB for a deletion attempt.
|
||||
This function populates TaskQueueState by just checking redis.
|
||||
"""
|
||||
cc_pair = get_connector_credential_pair(
|
||||
connector_id=connector_id, credential_id=credential_id, db_session=db_session
|
||||
)
|
||||
if not cc_pair:
|
||||
return None
|
||||
|
||||
rcd = RedisConnectorDeletion(cc_pair.id)
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
if not r.exists(rcd.fence_key):
|
||||
return None
|
||||
|
||||
return TaskQueueState(
|
||||
task_id="", task_name=rcd.fence_key, status=TaskStatus.STARTED
|
||||
)
|
||||
return get_latest_task(task_name=cleanup_task_name, db_session=db_session)
|
||||
|
||||
|
||||
def get_deletion_attempt_snapshot(
|
||||
connector_id: int, credential_id: int, db_session: Session
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
db_session: Session,
|
||||
tenant_id: str | None = None,
|
||||
) -> DeletionAttemptSnapshot | None:
|
||||
deletion_task = _get_deletion_status(connector_id, credential_id, db_session)
|
||||
deletion_task = _get_deletion_status(
|
||||
connector_id, credential_id, db_session, tenant_id
|
||||
)
|
||||
if not deletion_task:
|
||||
return None
|
||||
|
||||
@@ -54,78 +75,19 @@ def get_deletion_attempt_snapshot(
|
||||
)
|
||||
|
||||
|
||||
def should_kick_off_deletion_of_cc_pair(
|
||||
cc_pair: ConnectorCredentialPair, db_session: Session
|
||||
) -> bool:
|
||||
if cc_pair.status != ConnectorCredentialPairStatus.DELETING:
|
||||
return False
|
||||
|
||||
if check_deletion_attempt_is_allowed(cc_pair, db_session):
|
||||
return False
|
||||
|
||||
deletion_task = _get_deletion_status(
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
if deletion_task and check_task_is_live_and_not_timed_out(
|
||||
deletion_task,
|
||||
db_session,
|
||||
# 1 hour timeout
|
||||
timeout=60 * 60,
|
||||
):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def should_prune_cc_pair(
|
||||
connector: Connector, credential: Credential, db_session: Session
|
||||
) -> bool:
|
||||
if not connector.prune_freq:
|
||||
return False
|
||||
|
||||
pruning_task_name = name_cc_prune_task(
|
||||
connector_id=connector.id, credential_id=credential.id
|
||||
)
|
||||
last_pruning_task = get_latest_task(pruning_task_name, db_session)
|
||||
current_db_time = get_db_current_time(db_session)
|
||||
|
||||
if not last_pruning_task:
|
||||
time_since_initialization = current_db_time - connector.time_created
|
||||
if time_since_initialization.total_seconds() >= connector.prune_freq:
|
||||
return True
|
||||
return False
|
||||
|
||||
if not ALLOW_SIMULTANEOUS_PRUNING:
|
||||
pruning_type_task_name = name_cc_prune_task()
|
||||
last_pruning_type_task = get_latest_task_by_type(
|
||||
pruning_type_task_name, db_session
|
||||
)
|
||||
|
||||
if last_pruning_type_task and check_task_is_live_and_not_timed_out(
|
||||
last_pruning_type_task, db_session
|
||||
):
|
||||
return False
|
||||
|
||||
if check_task_is_live_and_not_timed_out(last_pruning_task, db_session):
|
||||
return False
|
||||
|
||||
if not last_pruning_task.start_time:
|
||||
return False
|
||||
|
||||
time_since_last_pruning = current_db_time - last_pruning_task.start_time
|
||||
return time_since_last_pruning.total_seconds() >= connector.prune_freq
|
||||
|
||||
|
||||
def document_batch_to_ids(doc_batch: list[Document]) -> set[str]:
|
||||
return {doc.id for doc in doc_batch}
|
||||
|
||||
|
||||
def extract_ids_from_runnable_connector(runnable_connector: BaseConnector) -> set[str]:
|
||||
def extract_ids_from_runnable_connector(
|
||||
runnable_connector: BaseConnector,
|
||||
progress_callback: Callable[[int], None] | None = None,
|
||||
) -> set[str]:
|
||||
"""
|
||||
If the PruneConnector hasnt been implemented for the given connector, just pull
|
||||
all docs using the load_from_state and grab out the IDs
|
||||
all docs using the load_from_state and grab out the IDs.
|
||||
|
||||
Optionally, a callback can be passed to handle the length of each document batch.
|
||||
"""
|
||||
all_connector_doc_ids: set[str] = set()
|
||||
|
||||
@@ -148,6 +110,56 @@ def extract_ids_from_runnable_connector(runnable_connector: BaseConnector) -> se
|
||||
max_calls=MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE, period=60
|
||||
)(document_batch_to_ids)
|
||||
for doc_batch in doc_batch_generator:
|
||||
if progress_callback:
|
||||
progress_callback(len(doc_batch))
|
||||
all_connector_doc_ids.update(doc_batch_processing_func(doc_batch))
|
||||
|
||||
return all_connector_doc_ids
|
||||
|
||||
|
||||
def celery_is_listening_to_queue(worker: Any, name: str) -> bool:
|
||||
"""Checks to see if we're listening to the named queue"""
|
||||
|
||||
# how to get a list of queues this worker is listening to
|
||||
# https://stackoverflow.com/questions/29790523/how-to-determine-which-queues-a-celery-worker-is-consuming-at-runtime
|
||||
queue_names = list(worker.app.amqp.queues.consume_from.keys())
|
||||
for queue_name in queue_names:
|
||||
if queue_name == name:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def celery_is_worker_primary(worker: Any) -> bool:
|
||||
"""There are multiple approaches that could be taken to determine if a celery worker
|
||||
is 'primary', as defined by us. But the way we do it is to check the hostname set
|
||||
for the celery worker, which can be done either in celeryconfig.py or on the
|
||||
command line with '--hostname'."""
|
||||
hostname = worker.hostname
|
||||
if hostname.startswith("primary"):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def get_all_tenant_ids() -> list[str] | list[None]:
|
||||
if not MULTI_TENANT:
|
||||
return [None]
|
||||
with get_session_with_tenant(tenant_id="public") as session:
|
||||
result = session.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT schema_name
|
||||
FROM information_schema.schemata
|
||||
WHERE schema_name NOT IN ('pg_catalog', 'information_schema', 'public')"""
|
||||
)
|
||||
)
|
||||
tenant_ids = [row[0] for row in result]
|
||||
|
||||
valid_tenants = [
|
||||
tenant
|
||||
for tenant in tenant_ids
|
||||
if tenant is None or tenant.startswith(TENANT_ID_PREFIX)
|
||||
]
|
||||
|
||||
return valid_tenants
|
||||
|
||||
@@ -1,5 +1,11 @@
|
||||
# docs: https://docs.celeryq.dev/en/stable/userguide/configuration.html
|
||||
import urllib.parse
|
||||
|
||||
from danswer.configs.app_configs import CELERY_BROKER_POOL_LIMIT
|
||||
from danswer.configs.app_configs import CELERY_RESULT_EXPIRES
|
||||
from danswer.configs.app_configs import REDIS_DB_NUMBER_CELERY
|
||||
from danswer.configs.app_configs import REDIS_DB_NUMBER_CELERY_RESULT_BACKEND
|
||||
from danswer.configs.app_configs import REDIS_HEALTH_CHECK_INTERVAL
|
||||
from danswer.configs.app_configs import REDIS_HOST
|
||||
from danswer.configs.app_configs import REDIS_PASSWORD
|
||||
from danswer.configs.app_configs import REDIS_PORT
|
||||
@@ -7,12 +13,13 @@ from danswer.configs.app_configs import REDIS_SSL
|
||||
from danswer.configs.app_configs import REDIS_SSL_CA_CERTS
|
||||
from danswer.configs.app_configs import REDIS_SSL_CERT_REQS
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import REDIS_SOCKET_KEEPALIVE_OPTIONS
|
||||
|
||||
CELERY_SEPARATOR = ":"
|
||||
|
||||
CELERY_PASSWORD_PART = ""
|
||||
if REDIS_PASSWORD:
|
||||
CELERY_PASSWORD_PART = f":{REDIS_PASSWORD}@"
|
||||
CELERY_PASSWORD_PART = ":" + urllib.parse.quote(REDIS_PASSWORD, safe="") + "@"
|
||||
|
||||
REDIS_SCHEME = "redis"
|
||||
|
||||
@@ -27,18 +34,71 @@ if REDIS_SSL:
|
||||
# example celery_broker_url: "redis://:password@localhost:6379/15"
|
||||
broker_url = f"{REDIS_SCHEME}://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY}{SSL_QUERY_PARAMS}"
|
||||
|
||||
result_backend = f"{REDIS_SCHEME}://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY}{SSL_QUERY_PARAMS}"
|
||||
result_backend = f"{REDIS_SCHEME}://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY_RESULT_BACKEND}{SSL_QUERY_PARAMS}"
|
||||
|
||||
# NOTE: prefetch 4 is significantly faster than prefetch 1 for small tasks
|
||||
# however, prefetching is bad when tasks are lengthy as those tasks
|
||||
# can stall other tasks.
|
||||
worker_prefetch_multiplier = 4
|
||||
|
||||
# Leaving this to the default of True may cause double logging since both our own app
|
||||
# and celery think they are controlling the logger.
|
||||
# TODO: Configure celery's logger entirely manually and set this to False
|
||||
# worker_hijack_root_logger = False
|
||||
|
||||
broker_connection_retry_on_startup = True
|
||||
broker_pool_limit = CELERY_BROKER_POOL_LIMIT
|
||||
|
||||
# redis broker settings
|
||||
# https://docs.celeryq.dev/projects/kombu/en/stable/reference/kombu.transport.redis.html
|
||||
broker_transport_options = {
|
||||
"priority_steps": list(range(len(DanswerCeleryPriority))),
|
||||
"sep": CELERY_SEPARATOR,
|
||||
"queue_order_strategy": "priority",
|
||||
"retry_on_timeout": True,
|
||||
"health_check_interval": REDIS_HEALTH_CHECK_INTERVAL,
|
||||
"socket_keepalive": True,
|
||||
"socket_keepalive_options": REDIS_SOCKET_KEEPALIVE_OPTIONS,
|
||||
}
|
||||
|
||||
# redis backend settings
|
||||
# https://docs.celeryq.dev/en/stable/userguide/configuration.html#redis-backend-settings
|
||||
|
||||
# there doesn't appear to be a way to set socket_keepalive_options on the redis result backend
|
||||
redis_socket_keepalive = True
|
||||
redis_retry_on_timeout = True
|
||||
redis_backend_health_check_interval = REDIS_HEALTH_CHECK_INTERVAL
|
||||
|
||||
|
||||
task_default_priority = DanswerCeleryPriority.MEDIUM
|
||||
task_acks_late = True
|
||||
|
||||
# It's possible we don't even need celery's result backend, in which case all of the optimization below
|
||||
# might be irrelevant
|
||||
result_expires = CELERY_RESULT_EXPIRES # 86400 seconds is the default
|
||||
|
||||
# Option 0: Defaults (json serializer, no compression)
|
||||
# about 1.5 KB per queued task. 1KB in queue, 400B for result, 100 as a child entry in generator result
|
||||
|
||||
# Option 1: Reduces generator task result sizes by roughly 20%
|
||||
# task_compression = "bzip2"
|
||||
# task_serializer = "pickle"
|
||||
# result_compression = "bzip2"
|
||||
# result_serializer = "pickle"
|
||||
# accept_content=["pickle"]
|
||||
|
||||
# Option 2: this significantly reduces the size of the result for generator tasks since the list of children
|
||||
# can be large. small tasks change very little
|
||||
# def pickle_bz2_encoder(data):
|
||||
# return bz2.compress(pickle.dumps(data))
|
||||
|
||||
# def pickle_bz2_decoder(data):
|
||||
# return pickle.loads(bz2.decompress(data))
|
||||
|
||||
# from kombu import serialization # To register custom serialization with Celery/Kombu
|
||||
|
||||
# serialization.register('pickle-bzip2', pickle_bz2_encoder, pickle_bz2_decoder, 'application/x-pickle-bz2', 'binary')
|
||||
|
||||
# task_serializer = "pickle-bzip2"
|
||||
# result_serializer = "pickle-bzip2"
|
||||
# accept_content=["pickle", "pickle-bzip2"]
|
||||
|
||||
@@ -0,0 +1,113 @@
|
||||
import redis
|
||||
from celery import shared_task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from redis import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm.exc import ObjectDeletedError
|
||||
|
||||
from danswer.background.celery.celery_app import celery_app
|
||||
from danswer.background.celery.celery_app import task_logger
|
||||
from danswer.background.celery.celery_redis import RedisConnectorDeletion
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
|
||||
|
||||
@shared_task(
|
||||
name="check_for_connector_deletion_task",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
trail=False,
|
||||
)
|
||||
def check_for_connector_deletion_task(*, tenant_id: str | None) -> None:
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat = r.lock(
|
||||
DanswerRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
try:
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pairs = get_connector_credential_pairs(db_session)
|
||||
for cc_pair in cc_pairs:
|
||||
try_generate_document_cc_pair_cleanup_tasks(
|
||||
cc_pair, db_session, r, lock_beat, tenant_id
|
||||
)
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception("Unexpected exception")
|
||||
finally:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
|
||||
|
||||
def try_generate_document_cc_pair_cleanup_tasks(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
lock_beat: redis.lock.Lock,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
|
||||
Note that syncing can still be required even if the number of sync tasks generated is zero.
|
||||
Returns None if no syncing is required.
|
||||
"""
|
||||
|
||||
lock_beat.reacquire()
|
||||
|
||||
rcd = RedisConnectorDeletion(cc_pair.id)
|
||||
|
||||
# don't generate sync tasks if tasks are still pending
|
||||
if r.exists(rcd.fence_key):
|
||||
return None
|
||||
|
||||
# we need to refresh the state of the object inside the fence
|
||||
# to avoid a race condition with db.commit/fence deletion
|
||||
# at the end of this taskset
|
||||
try:
|
||||
db_session.refresh(cc_pair)
|
||||
except ObjectDeletedError:
|
||||
return None
|
||||
|
||||
if cc_pair.status != ConnectorCredentialPairStatus.DELETING:
|
||||
return None
|
||||
|
||||
# add tasks to celery and build up the task set to monitor in redis
|
||||
r.delete(rcd.taskset_key)
|
||||
|
||||
# Add all documents that need to be updated into the queue
|
||||
task_logger.info(
|
||||
f"RedisConnectorDeletion.generate_tasks starting. cc_pair_id={cc_pair.id}"
|
||||
)
|
||||
tasks_generated = rcd.generate_tasks(
|
||||
celery_app, db_session, r, lock_beat, tenant_id
|
||||
)
|
||||
if tasks_generated is None:
|
||||
return None
|
||||
|
||||
# Currently we are allowing the sync to proceed with 0 tasks.
|
||||
# It's possible for sets/groups to be generated initially with no entries
|
||||
# and they still need to be marked as up to date.
|
||||
# if tasks_generated == 0:
|
||||
# return 0
|
||||
|
||||
task_logger.info(
|
||||
f"RedisConnectorDeletion.generate_tasks finished. "
|
||||
f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}"
|
||||
)
|
||||
|
||||
# set this only after all tasks have been added
|
||||
r.set(rcd.fence_key, tasks_generated)
|
||||
return tasks_generated
|
||||
455
backend/danswer/background/celery/tasks/indexing/tasks.py
Normal file
455
backend/danswer/background/celery/tasks/indexing/tasks.py
Normal file
@@ -0,0 +1,455 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from http import HTTPStatus
|
||||
from time import sleep
|
||||
from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
from celery import shared_task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from redis import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.celery_app import celery_app
|
||||
from danswer.background.celery.celery_app import task_logger
|
||||
from danswer.background.celery.celery_redis import RedisConnectorIndexing
|
||||
from danswer.background.celery.tasks.shared.tasks import RedisConnectorIndexingFenceData
|
||||
from danswer.background.indexing.job_client import SimpleJobClient
|
||||
from danswer.background.indexing.run_indexing import run_indexing_entrypoint
|
||||
from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
|
||||
from danswer.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.db.connector_credential_pair import fetch_connector_credential_pairs
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.engine import get_db_current_time
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.enums import IndexingStatus
|
||||
from danswer.db.enums import IndexModelStatus
|
||||
from danswer.db.index_attempt import create_index_attempt
|
||||
from danswer.db.index_attempt import get_index_attempt
|
||||
from danswer.db.index_attempt import get_last_attempt_for_cc_pair
|
||||
from danswer.db.index_attempt import mark_attempt_failed
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.models import IndexAttempt
|
||||
from danswer.db.models import SearchSettings
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.db.search_settings import get_secondary_search_settings
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import global_version
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@shared_task(
|
||||
name="check_for_indexing",
|
||||
soft_time_limit=300,
|
||||
)
|
||||
def check_for_indexing(*, tenant_id: str | None) -> int | None:
|
||||
tasks_created = 0
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat = r.lock(
|
||||
DanswerRedisLocks.CHECK_INDEXING_BEAT_LOCK,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
try:
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
task_logger.info(f"Lock acquired for tenant (Y): {tenant_id}")
|
||||
return None
|
||||
else:
|
||||
task_logger.info(f"Lock acquired for tenant (N): {tenant_id}")
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
# Get the primary search settings
|
||||
primary_search_settings = get_current_search_settings(db_session)
|
||||
search_settings = [primary_search_settings]
|
||||
|
||||
# Check for secondary search settings
|
||||
secondary_search_settings = get_secondary_search_settings(db_session)
|
||||
if secondary_search_settings is not None:
|
||||
# If secondary settings exist, add them to the list
|
||||
search_settings.append(secondary_search_settings)
|
||||
|
||||
cc_pairs = fetch_connector_credential_pairs(db_session)
|
||||
for cc_pair in cc_pairs:
|
||||
for search_settings_instance in search_settings:
|
||||
rci = RedisConnectorIndexing(
|
||||
cc_pair.id, search_settings_instance.id
|
||||
)
|
||||
if r.exists(rci.fence_key):
|
||||
continue
|
||||
|
||||
last_attempt = get_last_attempt_for_cc_pair(
|
||||
cc_pair.id, search_settings_instance.id, db_session
|
||||
)
|
||||
if not _should_index(
|
||||
cc_pair=cc_pair,
|
||||
last_index=last_attempt,
|
||||
search_settings_instance=search_settings_instance,
|
||||
secondary_index_building=len(search_settings) > 1,
|
||||
db_session=db_session,
|
||||
):
|
||||
continue
|
||||
|
||||
# using a task queue and only allowing one task per cc_pair/search_setting
|
||||
# prevents us from starving out certain attempts
|
||||
attempt_id = try_creating_indexing_task(
|
||||
cc_pair,
|
||||
search_settings_instance,
|
||||
False,
|
||||
db_session,
|
||||
r,
|
||||
tenant_id,
|
||||
)
|
||||
if attempt_id:
|
||||
task_logger.info(
|
||||
f"Indexing queued: cc_pair_id={cc_pair.id} index_attempt_id={attempt_id}"
|
||||
)
|
||||
tasks_created += 1
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception("Unexpected exception")
|
||||
finally:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
|
||||
return tasks_created
|
||||
|
||||
|
||||
def _should_index(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
last_index: IndexAttempt | None,
|
||||
search_settings_instance: SearchSettings,
|
||||
secondary_index_building: bool,
|
||||
db_session: Session,
|
||||
) -> bool:
|
||||
"""Checks various global settings and past indexing attempts to determine if
|
||||
we should try to start indexing the cc pair / search setting combination.
|
||||
|
||||
Note that tactical checks such as preventing overlap with a currently running task
|
||||
are not handled here.
|
||||
|
||||
Return True if we should try to index, False if not.
|
||||
"""
|
||||
connector = cc_pair.connector
|
||||
|
||||
# uncomment for debugging
|
||||
# task_logger.info(f"_should_index: "
|
||||
# f"cc_pair={cc_pair.id} "
|
||||
# f"connector={cc_pair.connector_id} "
|
||||
# f"refresh_freq={connector.refresh_freq}")
|
||||
|
||||
# don't kick off indexing for `NOT_APPLICABLE` sources
|
||||
if connector.source == DocumentSource.NOT_APPLICABLE:
|
||||
return False
|
||||
|
||||
# User can still manually create single indexing attempts via the UI for the
|
||||
# currently in use index
|
||||
if DISABLE_INDEX_UPDATE_ON_SWAP:
|
||||
if (
|
||||
search_settings_instance.status == IndexModelStatus.PRESENT
|
||||
and secondary_index_building
|
||||
):
|
||||
return False
|
||||
|
||||
# When switching over models, always index at least once
|
||||
if search_settings_instance.status == IndexModelStatus.FUTURE:
|
||||
if last_index:
|
||||
# No new index if the last index attempt succeeded
|
||||
# Once is enough. The model will never be able to swap otherwise.
|
||||
if last_index.status == IndexingStatus.SUCCESS:
|
||||
return False
|
||||
|
||||
# No new index if the last index attempt is waiting to start
|
||||
if last_index.status == IndexingStatus.NOT_STARTED:
|
||||
return False
|
||||
|
||||
# No new index if the last index attempt is running
|
||||
if last_index.status == IndexingStatus.IN_PROGRESS:
|
||||
return False
|
||||
else:
|
||||
if (
|
||||
connector.id == 0 or connector.source == DocumentSource.INGESTION_API
|
||||
): # Ingestion API
|
||||
return False
|
||||
return True
|
||||
|
||||
# If the connector is paused or is the ingestion API, don't index
|
||||
# NOTE: during an embedding model switch over, the following logic
|
||||
# is bypassed by the above check for a future model
|
||||
if (
|
||||
not cc_pair.status.is_active()
|
||||
or connector.id == 0
|
||||
or connector.source == DocumentSource.INGESTION_API
|
||||
):
|
||||
return False
|
||||
|
||||
# if no attempt has ever occurred, we should index regardless of refresh_freq
|
||||
if not last_index:
|
||||
return True
|
||||
|
||||
if connector.refresh_freq is None:
|
||||
return False
|
||||
|
||||
current_db_time = get_db_current_time(db_session)
|
||||
time_since_index = current_db_time - last_index.time_updated
|
||||
if time_since_index.total_seconds() < connector.refresh_freq:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def try_creating_indexing_task(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
search_settings: SearchSettings,
|
||||
reindex: bool,
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
"""Checks for any conditions that should block the indexing task from being
|
||||
created, then creates the task.
|
||||
|
||||
Does not check for scheduling related conditions as this function
|
||||
is used to trigger indexing immediately.
|
||||
"""
|
||||
|
||||
LOCK_TIMEOUT = 30
|
||||
|
||||
# we need to serialize any attempt to trigger indexing since it can be triggered
|
||||
# either via celery beat or manually (API call)
|
||||
lock = r.lock(
|
||||
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_creating_indexing_task",
|
||||
timeout=LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
acquired = lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2)
|
||||
if not acquired:
|
||||
return None
|
||||
|
||||
try:
|
||||
rci = RedisConnectorIndexing(cc_pair.id, search_settings.id)
|
||||
|
||||
# skip if already indexing
|
||||
if r.exists(rci.fence_key):
|
||||
return None
|
||||
|
||||
# skip indexing if the cc_pair is deleting
|
||||
db_session.refresh(cc_pair)
|
||||
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
|
||||
return None
|
||||
|
||||
# add a long running generator task to the queue
|
||||
r.delete(rci.generator_complete_key)
|
||||
r.delete(rci.taskset_key)
|
||||
|
||||
custom_task_id = f"{rci.generator_task_id_prefix}_{uuid4()}"
|
||||
|
||||
# create the index attempt ... just for tracking purposes
|
||||
index_attempt_id = create_index_attempt(
|
||||
cc_pair.id,
|
||||
search_settings.id,
|
||||
from_beginning=reindex,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
result = celery_app.send_task(
|
||||
"connector_indexing_proxy_task",
|
||||
kwargs=dict(
|
||||
index_attempt_id=index_attempt_id,
|
||||
cc_pair_id=cc_pair.id,
|
||||
search_settings_id=search_settings.id,
|
||||
tenant_id=tenant_id,
|
||||
),
|
||||
queue=DanswerCeleryQueues.CONNECTOR_INDEXING,
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.MEDIUM,
|
||||
)
|
||||
if not result:
|
||||
return None
|
||||
|
||||
# set this only after all tasks have been added
|
||||
fence_value = RedisConnectorIndexingFenceData(
|
||||
index_attempt_id=index_attempt_id,
|
||||
started=None,
|
||||
submitted=datetime.now(timezone.utc),
|
||||
celery_task_id=result.id,
|
||||
)
|
||||
r.set(rci.fence_key, fence_value.model_dump_json())
|
||||
except Exception:
|
||||
task_logger.exception("Unexpected exception")
|
||||
return None
|
||||
finally:
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
|
||||
return index_attempt_id
|
||||
|
||||
|
||||
@shared_task(name="connector_indexing_proxy_task", acks_late=False, track_started=True)
|
||||
def connector_indexing_proxy_task(
|
||||
index_attempt_id: int,
|
||||
cc_pair_id: int,
|
||||
search_settings_id: int,
|
||||
tenant_id: str | None,
|
||||
) -> None:
|
||||
"""celery tasks are forked, but forking is unstable. This proxies work to a spawned task."""
|
||||
|
||||
client = SimpleJobClient()
|
||||
|
||||
job = client.submit(
|
||||
connector_indexing_task,
|
||||
index_attempt_id,
|
||||
cc_pair_id,
|
||||
search_settings_id,
|
||||
tenant_id,
|
||||
global_version.is_ee_version(),
|
||||
pure=False,
|
||||
)
|
||||
|
||||
if not job:
|
||||
return
|
||||
|
||||
while True:
|
||||
sleep(10)
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
index_attempt = get_index_attempt(
|
||||
db_session=db_session, index_attempt_id=index_attempt_id
|
||||
)
|
||||
|
||||
# do nothing for ongoing jobs that haven't been stopped
|
||||
if not job.done():
|
||||
if not index_attempt:
|
||||
continue
|
||||
|
||||
if not index_attempt.is_finished():
|
||||
continue
|
||||
|
||||
if job.status == "error":
|
||||
logger.error(job.exception())
|
||||
|
||||
job.release()
|
||||
break
|
||||
|
||||
return
|
||||
|
||||
|
||||
def connector_indexing_task(
|
||||
index_attempt_id: int,
|
||||
cc_pair_id: int,
|
||||
search_settings_id: int,
|
||||
tenant_id: str | None,
|
||||
is_ee: bool,
|
||||
) -> int | None:
|
||||
"""Indexing task. For a cc pair, this task pulls all document IDs from the source
|
||||
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
|
||||
from the most recently pulled document ID list
|
||||
|
||||
acks_late must be set to False. Otherwise, celery's visibility timeout will
|
||||
cause any task that runs longer than the timeout to be redispatched by the broker.
|
||||
There appears to be no good workaround for this, so we need to handle redispatching
|
||||
manually.
|
||||
|
||||
Returns None if the task did not run (possibly due to a conflict).
|
||||
Otherwise, returns an int >= 0 representing the number of indexed docs.
|
||||
"""
|
||||
|
||||
attempt = None
|
||||
n_final_progress = 0
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
rci = RedisConnectorIndexing(cc_pair_id, search_settings_id)
|
||||
|
||||
lock = r.lock(
|
||||
rci.generator_lock_key,
|
||||
timeout=CELERY_INDEXING_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
acquired = lock.acquire(blocking=False)
|
||||
if not acquired:
|
||||
task_logger.warning(
|
||||
f"Indexing task already running, exiting...: "
|
||||
f"cc_pair_id={cc_pair_id} search_settings_id={search_settings_id}"
|
||||
)
|
||||
# r.set(rci.generator_complete_key, HTTPStatus.CONFLICT.value)
|
||||
return None
|
||||
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
attempt = get_index_attempt(db_session, index_attempt_id)
|
||||
if not attempt:
|
||||
raise ValueError(
|
||||
f"Index attempt not found: index_attempt_id={index_attempt_id}"
|
||||
)
|
||||
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
cc_pair_id=cc_pair_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
if not cc_pair:
|
||||
raise ValueError(f"cc_pair not found: cc_pair_id={cc_pair_id}")
|
||||
|
||||
if not cc_pair.connector:
|
||||
raise ValueError(
|
||||
f"Connector not found: connector_id={cc_pair.connector_id}"
|
||||
)
|
||||
|
||||
if not cc_pair.credential:
|
||||
raise ValueError(
|
||||
f"Credential not found: credential_id={cc_pair.credential_id}"
|
||||
)
|
||||
|
||||
rci = RedisConnectorIndexing(cc_pair_id, search_settings_id)
|
||||
|
||||
# Define the callback function
|
||||
def redis_increment_callback(amount: int) -> None:
|
||||
lock.reacquire()
|
||||
r.incrby(rci.generator_progress_key, amount)
|
||||
|
||||
run_indexing_entrypoint(
|
||||
index_attempt_id,
|
||||
tenant_id,
|
||||
cc_pair_id,
|
||||
is_ee,
|
||||
progress_callback=redis_increment_callback,
|
||||
)
|
||||
|
||||
# get back the total number of indexed docs and return it
|
||||
generator_progress_value = r.get(rci.generator_progress_key)
|
||||
if generator_progress_value is not None:
|
||||
try:
|
||||
n_final_progress = int(cast(int, generator_progress_value))
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
r.set(rci.generator_complete_key, HTTPStatus.OK.value)
|
||||
except Exception as e:
|
||||
task_logger.exception(f"Failed to run indexing for cc_pair_id={cc_pair_id}.")
|
||||
if attempt:
|
||||
mark_attempt_failed(attempt, db_session, failure_reason=str(e))
|
||||
|
||||
r.delete(rci.generator_lock_key)
|
||||
r.delete(rci.generator_progress_key)
|
||||
r.delete(rci.taskset_key)
|
||||
r.delete(rci.fence_key)
|
||||
raise e
|
||||
finally:
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
|
||||
return n_final_progress
|
||||
137
backend/danswer/background/celery/tasks/periodic/tasks.py
Normal file
137
backend/danswer/background/celery/tasks/periodic/tasks.py
Normal file
@@ -0,0 +1,137 @@
|
||||
#####
|
||||
# Periodic Tasks
|
||||
#####
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from celery import shared_task
|
||||
from celery.contrib.abortable import AbortableTask # type: ignore
|
||||
from celery.exceptions import TaskRevokedError
|
||||
from sqlalchemy import inspect
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.celery_app import task_logger
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.configs.constants import PostgresAdvisoryLocks
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
|
||||
|
||||
@shared_task(
|
||||
name="kombu_message_cleanup_task",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
base=AbortableTask,
|
||||
)
|
||||
def kombu_message_cleanup_task(self: Any, tenant_id: str | None) -> int:
|
||||
"""Runs periodically to clean up the kombu_message table"""
|
||||
|
||||
# we will select messages older than this amount to clean up
|
||||
KOMBU_MESSAGE_CLEANUP_AGE = 7 # days
|
||||
KOMBU_MESSAGE_CLEANUP_PAGE_LIMIT = 1000
|
||||
|
||||
ctx = {}
|
||||
ctx["last_processed_id"] = 0
|
||||
ctx["deleted"] = 0
|
||||
ctx["cleanup_age"] = KOMBU_MESSAGE_CLEANUP_AGE
|
||||
ctx["page_limit"] = KOMBU_MESSAGE_CLEANUP_PAGE_LIMIT
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
# Exit the task if we can't take the advisory lock
|
||||
result = db_session.execute(
|
||||
text("SELECT pg_try_advisory_lock(:id)"),
|
||||
{"id": PostgresAdvisoryLocks.KOMBU_MESSAGE_CLEANUP_LOCK_ID.value},
|
||||
).scalar()
|
||||
if not result:
|
||||
return 0
|
||||
|
||||
while True:
|
||||
if self.is_aborted():
|
||||
raise TaskRevokedError("kombu_message_cleanup_task was aborted.")
|
||||
|
||||
b = kombu_message_cleanup_task_helper(ctx, db_session)
|
||||
if not b:
|
||||
break
|
||||
|
||||
db_session.commit()
|
||||
|
||||
if ctx["deleted"] > 0:
|
||||
task_logger.info(
|
||||
f"Deleted {ctx['deleted']} orphaned messages from kombu_message."
|
||||
)
|
||||
|
||||
return ctx["deleted"]
|
||||
|
||||
|
||||
def kombu_message_cleanup_task_helper(ctx: dict, db_session: Session) -> bool:
|
||||
"""
|
||||
Helper function to clean up old messages from the `kombu_message` table that are no longer relevant.
|
||||
|
||||
This function retrieves messages from the `kombu_message` table that are no longer visible and
|
||||
older than a specified interval. It checks if the corresponding task_id exists in the
|
||||
`celery_taskmeta` table. If the task_id does not exist, the message is deleted.
|
||||
|
||||
Args:
|
||||
ctx (dict): A context dictionary containing configuration parameters such as:
|
||||
- 'cleanup_age' (int): The age in days after which messages are considered old.
|
||||
- 'page_limit' (int): The maximum number of messages to process in one batch.
|
||||
- 'last_processed_id' (int): The ID of the last processed message to handle pagination.
|
||||
- 'deleted' (int): A counter to track the number of deleted messages.
|
||||
db_session (Session): The SQLAlchemy database session for executing queries.
|
||||
|
||||
Returns:
|
||||
bool: Returns True if there are more rows to process, False if not.
|
||||
"""
|
||||
|
||||
inspector = inspect(db_session.bind)
|
||||
if not inspector:
|
||||
return False
|
||||
|
||||
# With the move to redis as celery's broker and backend, kombu tables may not even exist.
|
||||
# We can fail silently.
|
||||
if not inspector.has_table("kombu_message"):
|
||||
return False
|
||||
|
||||
query = text(
|
||||
"""
|
||||
SELECT id, timestamp, payload
|
||||
FROM kombu_message WHERE visible = 'false'
|
||||
AND timestamp < CURRENT_TIMESTAMP - INTERVAL :interval_days
|
||||
AND id > :last_processed_id
|
||||
ORDER BY id
|
||||
LIMIT :page_limit
|
||||
"""
|
||||
)
|
||||
kombu_messages = db_session.execute(
|
||||
query,
|
||||
{
|
||||
"interval_days": f"{ctx['cleanup_age']} days",
|
||||
"page_limit": ctx["page_limit"],
|
||||
"last_processed_id": ctx["last_processed_id"],
|
||||
},
|
||||
).fetchall()
|
||||
|
||||
if len(kombu_messages) == 0:
|
||||
return False
|
||||
|
||||
for msg in kombu_messages:
|
||||
payload = json.loads(msg[2])
|
||||
task_id = payload["headers"]["id"]
|
||||
|
||||
# Check if task_id exists in celery_taskmeta
|
||||
task_exists = db_session.execute(
|
||||
text("SELECT 1 FROM celery_taskmeta WHERE task_id = :task_id"),
|
||||
{"task_id": task_id},
|
||||
).fetchone()
|
||||
|
||||
# If task_id does not exist, delete the message
|
||||
if not task_exists:
|
||||
result = db_session.execute(
|
||||
text("DELETE FROM kombu_message WHERE id = :message_id"),
|
||||
{"message_id": msg[0]},
|
||||
)
|
||||
if result.rowcount > 0: # type: ignore
|
||||
ctx["deleted"] += 1
|
||||
|
||||
ctx["last_processed_id"] = msg[0]
|
||||
|
||||
return True
|
||||
301
backend/danswer/background/celery/tasks/pruning/tasks.py
Normal file
301
backend/danswer/background/celery/tasks/pruning/tasks.py
Normal file
@@ -0,0 +1,301 @@
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from uuid import uuid4
|
||||
|
||||
from celery import shared_task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from redis import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.celery_app import celery_app
|
||||
from danswer.background.celery.celery_app import task_logger
|
||||
from danswer.background.celery.celery_redis import RedisConnectorPruning
|
||||
from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector
|
||||
from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.configs.constants import CELERY_PRUNING_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.connectors.factory import instantiate_connector
|
||||
from danswer.connectors.models import InputType
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from danswer.db.document import get_documents_for_connector_credential_pair
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@shared_task(
|
||||
name="check_for_pruning",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
)
|
||||
def check_for_pruning(*, tenant_id: str | None) -> None:
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat = r.lock(
|
||||
DanswerRedisLocks.CHECK_PRUNE_BEAT_LOCK,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
try:
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pairs = get_connector_credential_pairs(db_session)
|
||||
for cc_pair in cc_pairs:
|
||||
lock_beat.reacquire()
|
||||
if not is_pruning_due(cc_pair, db_session, r):
|
||||
continue
|
||||
|
||||
tasks_created = try_creating_prune_generator_task(
|
||||
cc_pair, db_session, r, tenant_id
|
||||
)
|
||||
if not tasks_created:
|
||||
continue
|
||||
|
||||
task_logger.info(f"Pruning queued: cc_pair_id={cc_pair.id}")
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception("Unexpected exception")
|
||||
finally:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
|
||||
|
||||
def is_pruning_due(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
) -> bool:
|
||||
"""Returns an int if pruning is triggered.
|
||||
The int represents the number of prune tasks generated (in this case, only one
|
||||
because the task is a long running generator task.)
|
||||
Returns None if no pruning is triggered (due to not being needed or
|
||||
other reasons such as simultaneous pruning restrictions.
|
||||
|
||||
Checks for scheduling related conditions, then delegates the rest of the checks to
|
||||
try_creating_prune_generator_task.
|
||||
"""
|
||||
|
||||
# skip pruning if no prune frequency is set
|
||||
# pruning can still be forced via the API which will run a pruning task directly
|
||||
if not cc_pair.connector.prune_freq:
|
||||
return False
|
||||
|
||||
# skip pruning if not active
|
||||
if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
|
||||
return False
|
||||
|
||||
# skip pruning if the next scheduled prune time hasn't been reached yet
|
||||
last_pruned = cc_pair.last_pruned
|
||||
if not last_pruned:
|
||||
if not cc_pair.last_successful_index_time:
|
||||
# if we've never indexed, we can't prune
|
||||
return False
|
||||
|
||||
# if never pruned, use the last time the connector indexed successfully
|
||||
last_pruned = cc_pair.last_successful_index_time
|
||||
|
||||
next_prune = last_pruned + timedelta(seconds=cc_pair.connector.prune_freq)
|
||||
if datetime.now(timezone.utc) < next_prune:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def try_creating_prune_generator_task(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
"""Checks for any conditions that should block the pruning generator task from being
|
||||
created, then creates the task.
|
||||
|
||||
Does not check for scheduling related conditions as this function
|
||||
is used to trigger prunes immediately, e.g. via the web ui.
|
||||
"""
|
||||
|
||||
if not ALLOW_SIMULTANEOUS_PRUNING:
|
||||
for key in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
|
||||
return None
|
||||
|
||||
LOCK_TIMEOUT = 30
|
||||
|
||||
# we need to serialize starting pruning since it can be triggered either via
|
||||
# celery beat or manually (API call)
|
||||
lock = r.lock(
|
||||
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_creating_prune_generator_task",
|
||||
timeout=LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
acquired = lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2)
|
||||
if not acquired:
|
||||
return None
|
||||
|
||||
try:
|
||||
rcp = RedisConnectorPruning(cc_pair.id)
|
||||
|
||||
# skip pruning if already pruning
|
||||
if r.exists(rcp.fence_key):
|
||||
return None
|
||||
|
||||
# skip pruning if the cc_pair is deleting
|
||||
db_session.refresh(cc_pair)
|
||||
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
|
||||
return None
|
||||
|
||||
# add a long running generator task to the queue
|
||||
r.delete(rcp.generator_complete_key)
|
||||
r.delete(rcp.taskset_key)
|
||||
|
||||
custom_task_id = f"{rcp.generator_task_id_prefix}_{uuid4()}"
|
||||
|
||||
celery_app.send_task(
|
||||
"connector_pruning_generator_task",
|
||||
kwargs=dict(
|
||||
cc_pair_id=cc_pair.id,
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
tenant_id=tenant_id,
|
||||
),
|
||||
queue=DanswerCeleryQueues.CONNECTOR_PRUNING,
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.LOW,
|
||||
)
|
||||
|
||||
# set this only after all tasks have been added
|
||||
r.set(rcp.fence_key, 1)
|
||||
except Exception:
|
||||
task_logger.exception("Unexpected exception")
|
||||
return None
|
||||
finally:
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
|
||||
return 1
|
||||
|
||||
|
||||
@shared_task(
|
||||
name="connector_pruning_generator_task",
|
||||
acks_late=False,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
track_started=True,
|
||||
trail=False,
|
||||
)
|
||||
def connector_pruning_generator_task(
|
||||
cc_pair_id: int, connector_id: int, credential_id: int, tenant_id: str | None
|
||||
) -> None:
|
||||
"""connector pruning task. For a cc pair, this task pulls all document IDs from the source
|
||||
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
|
||||
from the most recently pulled document ID list"""
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
rcp = RedisConnectorPruning(cc_pair_id)
|
||||
|
||||
lock = r.lock(
|
||||
DanswerRedisLocks.PRUNING_LOCK_PREFIX + f"_{rcp._id}",
|
||||
timeout=CELERY_PRUNING_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
acquired = lock.acquire(blocking=False)
|
||||
if not acquired:
|
||||
task_logger.warning(
|
||||
f"Pruning task already running, exiting...: cc_pair_id={cc_pair_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pair = get_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
|
||||
if not cc_pair:
|
||||
task_logger.warning(
|
||||
f"cc_pair not found for {connector_id} {credential_id}"
|
||||
)
|
||||
return
|
||||
|
||||
# Define the callback function
|
||||
def redis_increment_callback(amount: int) -> None:
|
||||
lock.reacquire()
|
||||
r.incrby(rcp.generator_progress_key, amount)
|
||||
|
||||
runnable_connector = instantiate_connector(
|
||||
db_session,
|
||||
cc_pair.connector.source,
|
||||
InputType.PRUNE,
|
||||
cc_pair.connector.connector_specific_config,
|
||||
cc_pair.credential,
|
||||
)
|
||||
|
||||
# a list of docs in the source
|
||||
all_connector_doc_ids: set[str] = extract_ids_from_runnable_connector(
|
||||
runnable_connector, redis_increment_callback
|
||||
)
|
||||
|
||||
# a list of docs in our local index
|
||||
all_indexed_document_ids = {
|
||||
doc.id
|
||||
for doc in get_documents_for_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
}
|
||||
|
||||
# generate list of docs to remove (no longer in the source)
|
||||
doc_ids_to_remove = list(all_indexed_document_ids - all_connector_doc_ids)
|
||||
|
||||
task_logger.info(
|
||||
f"Pruning set collected: "
|
||||
f"cc_pair_id={cc_pair.id} "
|
||||
f"docs_to_remove={len(doc_ids_to_remove)} "
|
||||
f"doc_source={cc_pair.connector.source}"
|
||||
)
|
||||
|
||||
rcp.documents_to_prune = set(doc_ids_to_remove)
|
||||
|
||||
task_logger.info(
|
||||
f"RedisConnectorPruning.generate_tasks starting. cc_pair_id={cc_pair.id}"
|
||||
)
|
||||
tasks_generated = rcp.generate_tasks(
|
||||
celery_app, db_session, r, None, tenant_id
|
||||
)
|
||||
if tasks_generated is None:
|
||||
return None
|
||||
|
||||
task_logger.info(
|
||||
f"RedisConnectorPruning.generate_tasks finished. "
|
||||
f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}"
|
||||
)
|
||||
|
||||
r.set(rcp.generator_complete_key, tasks_generated)
|
||||
except Exception as e:
|
||||
task_logger.exception(f"Failed to run pruning for connector id {connector_id}.")
|
||||
|
||||
r.delete(rcp.generator_progress_key)
|
||||
r.delete(rcp.taskset_key)
|
||||
r.delete(rcp.fence_key)
|
||||
raise e
|
||||
finally:
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
144
backend/danswer/background/celery/tasks/shared/tasks.py
Normal file
144
backend/danswer/background/celery/tasks/shared/tasks.py
Normal file
@@ -0,0 +1,144 @@
|
||||
from datetime import datetime
|
||||
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from pydantic import BaseModel
|
||||
|
||||
from danswer.access.access import get_access_for_document
|
||||
from danswer.background.celery.celery_app import task_logger
|
||||
from danswer.db.document import delete_document_by_connector_credential_pair__no_commit
|
||||
from danswer.db.document import delete_documents_complete__no_commit
|
||||
from danswer.db.document import get_document
|
||||
from danswer.db.document import get_document_connector_count
|
||||
from danswer.db.document import mark_document_as_synced
|
||||
from danswer.db.document_set import fetch_document_sets_for_document
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.document_index.document_index_utils import get_both_index_names
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.document_index.interfaces import VespaDocumentFields
|
||||
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
|
||||
|
||||
|
||||
class RedisConnectorIndexingFenceData(BaseModel):
|
||||
index_attempt_id: int
|
||||
started: datetime | None
|
||||
submitted: datetime
|
||||
celery_task_id: str
|
||||
|
||||
|
||||
@shared_task(
|
||||
name="document_by_cc_pair_cleanup_task",
|
||||
bind=True,
|
||||
soft_time_limit=45,
|
||||
time_limit=60,
|
||||
max_retries=3,
|
||||
)
|
||||
def document_by_cc_pair_cleanup_task(
|
||||
self: Task,
|
||||
document_id: str,
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
tenant_id: str | None,
|
||||
) -> bool:
|
||||
"""A lightweight subtask used to clean up document to cc pair relationships.
|
||||
Created by connection deletion and connector pruning parent tasks."""
|
||||
|
||||
"""
|
||||
To delete a connector / credential pair:
|
||||
(1) find all documents associated with connector / credential pair where there
|
||||
this the is only connector / credential pair that has indexed it
|
||||
(2) delete all documents from document stores
|
||||
(3) delete all entries from postgres
|
||||
(4) find all documents associated with connector / credential pair where there
|
||||
are multiple connector / credential pairs that have indexed it
|
||||
(5) update document store entries to remove access associated with the
|
||||
connector / credential pair from the access list
|
||||
(6) delete all relevant entries from postgres
|
||||
"""
|
||||
task_logger.info(f"document_id={document_id}")
|
||||
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
action = "skip"
|
||||
chunks_affected = 0
|
||||
|
||||
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
|
||||
document_index = get_default_document_index(
|
||||
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
|
||||
)
|
||||
|
||||
count = get_document_connector_count(db_session, document_id)
|
||||
if count == 1:
|
||||
# count == 1 means this is the only remaining cc_pair reference to the doc
|
||||
# delete it from vespa and the db
|
||||
action = "delete"
|
||||
|
||||
chunks_affected = document_index.delete_single(document_id)
|
||||
delete_documents_complete__no_commit(
|
||||
db_session=db_session,
|
||||
document_ids=[document_id],
|
||||
)
|
||||
elif count > 1:
|
||||
action = "update"
|
||||
|
||||
# count > 1 means the document still has cc_pair references
|
||||
doc = get_document(document_id, db_session)
|
||||
if not doc:
|
||||
return False
|
||||
|
||||
# the below functions do not include cc_pairs being deleted.
|
||||
# i.e. they will correctly omit access for the current cc_pair
|
||||
doc_access = get_access_for_document(
|
||||
document_id=document_id, db_session=db_session
|
||||
)
|
||||
|
||||
doc_sets = fetch_document_sets_for_document(document_id, db_session)
|
||||
update_doc_sets: set[str] = set(doc_sets)
|
||||
|
||||
fields = VespaDocumentFields(
|
||||
document_sets=update_doc_sets,
|
||||
access=doc_access,
|
||||
boost=doc.boost,
|
||||
hidden=doc.hidden,
|
||||
)
|
||||
|
||||
# update Vespa. OK if doc doesn't exist. Raises exception otherwise.
|
||||
chunks_affected = document_index.update_single(
|
||||
document_id, fields=fields
|
||||
)
|
||||
|
||||
# there are still other cc_pair references to the doc, so just resync to Vespa
|
||||
delete_document_by_connector_credential_pair__no_commit(
|
||||
db_session=db_session,
|
||||
document_id=document_id,
|
||||
connector_credential_pair_identifier=ConnectorCredentialPairIdentifier(
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
),
|
||||
)
|
||||
|
||||
mark_document_as_synced(document_id, db_session)
|
||||
else:
|
||||
pass
|
||||
|
||||
task_logger.info(
|
||||
f"tenant_id={tenant_id} "
|
||||
f"document_id={document_id} "
|
||||
f"action={action} "
|
||||
f"refcount={count} "
|
||||
f"chunks={chunks_affected}"
|
||||
)
|
||||
db_session.commit()
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
f"SoftTimeLimitExceeded exception. tenant_id={tenant_id} doc_id={document_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
task_logger.exception("Unexpected exception")
|
||||
|
||||
# Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
|
||||
countdown = 2 ** (self.request.retries + 4)
|
||||
self.retry(exc=e, countdown=countdown)
|
||||
|
||||
return True
|
||||
806
backend/danswer/background/celery/tasks/vespa/tasks.py
Normal file
806
backend/danswer/background/celery/tasks/vespa/tasks.py
Normal file
@@ -0,0 +1,806 @@
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from http import HTTPStatus
|
||||
from typing import cast
|
||||
|
||||
import redis
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from celery.result import AsyncResult
|
||||
from celery.states import READY_STATES
|
||||
from redis import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.access.access import get_access_for_document
|
||||
from danswer.background.celery.celery_app import celery_app
|
||||
from danswer.background.celery.celery_app import task_logger
|
||||
from danswer.background.celery.celery_redis import celery_get_queue_length
|
||||
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
|
||||
from danswer.background.celery.celery_redis import RedisConnectorDeletion
|
||||
from danswer.background.celery.celery_redis import RedisConnectorIndexing
|
||||
from danswer.background.celery.celery_redis import RedisConnectorPruning
|
||||
from danswer.background.celery.celery_redis import RedisDocumentSet
|
||||
from danswer.background.celery.celery_redis import RedisUserGroup
|
||||
from danswer.background.celery.tasks.shared.tasks import RedisConnectorIndexingFenceData
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.db.connector import fetch_connector_by_id
|
||||
from danswer.db.connector import mark_ccpair_as_pruned
|
||||
from danswer.db.connector_credential_pair import add_deletion_failure_message
|
||||
from danswer.db.connector_credential_pair import (
|
||||
delete_connector_credential_pair__no_commit,
|
||||
)
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from danswer.db.document import count_documents_by_needs_sync
|
||||
from danswer.db.document import get_document
|
||||
from danswer.db.document import get_document_ids_for_connector_credential_pair
|
||||
from danswer.db.document import mark_document_as_synced
|
||||
from danswer.db.document_set import delete_document_set
|
||||
from danswer.db.document_set import delete_document_set_cc_pair_relationship__no_commit
|
||||
from danswer.db.document_set import fetch_document_sets
|
||||
from danswer.db.document_set import fetch_document_sets_for_document
|
||||
from danswer.db.document_set import get_document_set_by_id
|
||||
from danswer.db.document_set import mark_document_set_as_synced
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.enums import IndexingStatus
|
||||
from danswer.db.index_attempt import delete_index_attempts
|
||||
from danswer.db.index_attempt import get_all_index_attempts_by_status
|
||||
from danswer.db.index_attempt import get_index_attempt
|
||||
from danswer.db.index_attempt import mark_attempt_failed
|
||||
from danswer.db.models import DocumentSet
|
||||
from danswer.db.models import IndexAttempt
|
||||
from danswer.db.models import UserGroup
|
||||
from danswer.document_index.document_index_utils import get_both_index_names
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.document_index.interfaces import VespaDocumentFields
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
from danswer.utils.variable_functionality import (
|
||||
fetch_versioned_implementation_with_fallback,
|
||||
)
|
||||
from danswer.utils.variable_functionality import global_version
|
||||
from danswer.utils.variable_functionality import noop_fallback
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
# celery auto associates tasks created inside another task,
|
||||
# which bloats the result metadata considerably. trail=False prevents this.
|
||||
@shared_task(
|
||||
name="check_for_vespa_sync_task",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
trail=False,
|
||||
)
|
||||
def check_for_vespa_sync_task(*, tenant_id: str | None) -> None:
|
||||
"""Runs periodically to check if any document needs syncing.
|
||||
Generates sets of tasks for Celery if syncing is needed."""
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat = r.lock(
|
||||
DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
try:
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
try_generate_stale_document_sync_tasks(db_session, r, lock_beat, tenant_id)
|
||||
|
||||
# check if any document sets are not synced
|
||||
document_set_info = fetch_document_sets(
|
||||
user_id=None, db_session=db_session, include_outdated=True
|
||||
)
|
||||
for document_set, _ in document_set_info:
|
||||
try_generate_document_set_sync_tasks(
|
||||
document_set, db_session, r, lock_beat, tenant_id
|
||||
)
|
||||
|
||||
# check if any user groups are not synced
|
||||
if global_version.is_ee_version():
|
||||
try:
|
||||
fetch_user_groups = fetch_versioned_implementation(
|
||||
"danswer.db.user_group", "fetch_user_groups"
|
||||
)
|
||||
|
||||
user_groups = fetch_user_groups(
|
||||
db_session=db_session, only_up_to_date=False
|
||||
)
|
||||
for usergroup in user_groups:
|
||||
try_generate_user_group_sync_tasks(
|
||||
usergroup, db_session, r, lock_beat, tenant_id
|
||||
)
|
||||
except ModuleNotFoundError:
|
||||
# Always exceptions on the MIT version, which is expected
|
||||
# We shouldn't actually get here if the ee version check works
|
||||
pass
|
||||
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception("Unexpected exception")
|
||||
finally:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
|
||||
|
||||
def try_generate_stale_document_sync_tasks(
|
||||
db_session: Session, r: Redis, lock_beat: redis.lock.Lock, tenant_id: str | None
|
||||
) -> int | None:
|
||||
# the fence is up, do nothing
|
||||
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
|
||||
return None
|
||||
|
||||
r.delete(RedisConnectorCredentialPair.get_taskset_key()) # delete the taskset
|
||||
|
||||
# add tasks to celery and build up the task set to monitor in redis
|
||||
stale_doc_count = count_documents_by_needs_sync(db_session)
|
||||
if stale_doc_count == 0:
|
||||
return None
|
||||
|
||||
task_logger.info(
|
||||
f"Stale documents found (at least {stale_doc_count}). Generating sync tasks by cc pair."
|
||||
)
|
||||
|
||||
task_logger.info("RedisConnector.generate_tasks starting by cc_pair.")
|
||||
|
||||
# rkuo: we could technically sync all stale docs in one big pass.
|
||||
# but I feel it's more understandable to group the docs by cc_pair
|
||||
total_tasks_generated = 0
|
||||
cc_pairs = get_connector_credential_pairs(db_session)
|
||||
for cc_pair in cc_pairs:
|
||||
rc = RedisConnectorCredentialPair(cc_pair.id)
|
||||
tasks_generated = rc.generate_tasks(
|
||||
celery_app, db_session, r, lock_beat, tenant_id
|
||||
)
|
||||
|
||||
if tasks_generated is None:
|
||||
continue
|
||||
|
||||
if tasks_generated == 0:
|
||||
continue
|
||||
|
||||
task_logger.info(
|
||||
f"RedisConnector.generate_tasks finished for single cc_pair. "
|
||||
f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}"
|
||||
)
|
||||
|
||||
total_tasks_generated += tasks_generated
|
||||
|
||||
task_logger.info(
|
||||
f"RedisConnector.generate_tasks finished for all cc_pairs. total_tasks_generated={total_tasks_generated}"
|
||||
)
|
||||
|
||||
r.set(RedisConnectorCredentialPair.get_fence_key(), total_tasks_generated)
|
||||
return total_tasks_generated
|
||||
|
||||
|
||||
def try_generate_document_set_sync_tasks(
|
||||
document_set: DocumentSet,
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
lock_beat: redis.lock.Lock,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
lock_beat.reacquire()
|
||||
|
||||
rds = RedisDocumentSet(document_set.id)
|
||||
|
||||
# don't generate document set sync tasks if tasks are still pending
|
||||
if r.exists(rds.fence_key):
|
||||
return None
|
||||
|
||||
# don't generate sync tasks if we're up to date
|
||||
# race condition with the monitor/cleanup function if we use a cached result!
|
||||
db_session.refresh(document_set)
|
||||
if document_set.is_up_to_date:
|
||||
return None
|
||||
|
||||
# add tasks to celery and build up the task set to monitor in redis
|
||||
r.delete(rds.taskset_key)
|
||||
|
||||
task_logger.info(
|
||||
f"RedisDocumentSet.generate_tasks starting. document_set_id={document_set.id}"
|
||||
)
|
||||
|
||||
# Add all documents that need to be updated into the queue
|
||||
tasks_generated = rds.generate_tasks(
|
||||
celery_app, db_session, r, lock_beat, tenant_id
|
||||
)
|
||||
if tasks_generated is None:
|
||||
return None
|
||||
|
||||
# Currently we are allowing the sync to proceed with 0 tasks.
|
||||
# It's possible for sets/groups to be generated initially with no entries
|
||||
# and they still need to be marked as up to date.
|
||||
# if tasks_generated == 0:
|
||||
# return 0
|
||||
|
||||
task_logger.info(
|
||||
f"RedisDocumentSet.generate_tasks finished. "
|
||||
f"document_set_id={document_set.id} tasks_generated={tasks_generated}"
|
||||
)
|
||||
|
||||
# set this only after all tasks have been added
|
||||
r.set(rds.fence_key, tasks_generated)
|
||||
return tasks_generated
|
||||
|
||||
|
||||
def try_generate_user_group_sync_tasks(
|
||||
usergroup: UserGroup,
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
lock_beat: redis.lock.Lock,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
lock_beat.reacquire()
|
||||
|
||||
rug = RedisUserGroup(usergroup.id)
|
||||
|
||||
# don't generate sync tasks if tasks are still pending
|
||||
if r.exists(rug.fence_key):
|
||||
return None
|
||||
|
||||
# race condition with the monitor/cleanup function if we use a cached result!
|
||||
db_session.refresh(usergroup)
|
||||
if usergroup.is_up_to_date:
|
||||
return None
|
||||
|
||||
# add tasks to celery and build up the task set to monitor in redis
|
||||
r.delete(rug.taskset_key)
|
||||
|
||||
# Add all documents that need to be updated into the queue
|
||||
task_logger.info(
|
||||
f"RedisUserGroup.generate_tasks starting. usergroup_id={usergroup.id}"
|
||||
)
|
||||
tasks_generated = rug.generate_tasks(
|
||||
celery_app, db_session, r, lock_beat, tenant_id
|
||||
)
|
||||
if tasks_generated is None:
|
||||
return None
|
||||
|
||||
# Currently we are allowing the sync to proceed with 0 tasks.
|
||||
# It's possible for sets/groups to be generated initially with no entries
|
||||
# and they still need to be marked as up to date.
|
||||
# if tasks_generated == 0:
|
||||
# return 0
|
||||
|
||||
task_logger.info(
|
||||
f"RedisUserGroup.generate_tasks finished. "
|
||||
f"usergroup_id={usergroup.id} tasks_generated={tasks_generated}"
|
||||
)
|
||||
|
||||
# set this only after all tasks have been added
|
||||
r.set(rug.fence_key, tasks_generated)
|
||||
return tasks_generated
|
||||
|
||||
|
||||
def monitor_connector_taskset(r: Redis) -> None:
|
||||
fence_value = r.get(RedisConnectorCredentialPair.get_fence_key())
|
||||
if fence_value is None:
|
||||
return
|
||||
|
||||
try:
|
||||
initial_count = int(cast(int, fence_value))
|
||||
except ValueError:
|
||||
task_logger.error("The value is not an integer.")
|
||||
return
|
||||
|
||||
count = r.scard(RedisConnectorCredentialPair.get_taskset_key())
|
||||
task_logger.info(
|
||||
f"Stale document sync progress: remaining={count} initial={initial_count}"
|
||||
)
|
||||
if count == 0:
|
||||
r.delete(RedisConnectorCredentialPair.get_taskset_key())
|
||||
r.delete(RedisConnectorCredentialPair.get_fence_key())
|
||||
task_logger.info(f"Successfully synced stale documents. count={initial_count}")
|
||||
|
||||
|
||||
def monitor_document_set_taskset(
|
||||
key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
document_set_id_str = RedisDocumentSet.get_id_from_fence_key(fence_key)
|
||||
if document_set_id_str is None:
|
||||
task_logger.warning(f"could not parse document set id from {fence_key}")
|
||||
return
|
||||
|
||||
document_set_id = int(document_set_id_str)
|
||||
|
||||
rds = RedisDocumentSet(document_set_id)
|
||||
|
||||
fence_value = r.get(rds.fence_key)
|
||||
if fence_value is None:
|
||||
return
|
||||
|
||||
try:
|
||||
initial_count = int(cast(int, fence_value))
|
||||
except ValueError:
|
||||
task_logger.error("The value is not an integer.")
|
||||
return
|
||||
|
||||
count = cast(int, r.scard(rds.taskset_key))
|
||||
task_logger.info(
|
||||
f"Document set sync progress: document_set_id={document_set_id} "
|
||||
f"remaining={count} initial={initial_count}"
|
||||
)
|
||||
if count > 0:
|
||||
return
|
||||
|
||||
document_set = cast(
|
||||
DocumentSet,
|
||||
get_document_set_by_id(db_session=db_session, document_set_id=document_set_id),
|
||||
) # casting since we "know" a document set with this ID exists
|
||||
if document_set:
|
||||
if not document_set.connector_credential_pairs:
|
||||
# if there are no connectors, then delete the document set.
|
||||
delete_document_set(document_set_row=document_set, db_session=db_session)
|
||||
task_logger.info(
|
||||
f"Successfully deleted document set with ID: '{document_set_id}'!"
|
||||
)
|
||||
else:
|
||||
mark_document_set_as_synced(document_set_id, db_session)
|
||||
task_logger.info(
|
||||
f"Successfully synced document set with ID: '{document_set_id}'!"
|
||||
)
|
||||
|
||||
r.delete(rds.taskset_key)
|
||||
r.delete(rds.fence_key)
|
||||
|
||||
|
||||
def monitor_connector_deletion_taskset(
|
||||
key_bytes: bytes, r: Redis, tenant_id: str | None
|
||||
) -> None:
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
cc_pair_id_str = RedisConnectorDeletion.get_id_from_fence_key(fence_key)
|
||||
if cc_pair_id_str is None:
|
||||
task_logger.warning(f"could not parse cc_pair_id from {fence_key}")
|
||||
return
|
||||
|
||||
cc_pair_id = int(cc_pair_id_str)
|
||||
|
||||
rcd = RedisConnectorDeletion(cc_pair_id)
|
||||
|
||||
fence_value = r.get(rcd.fence_key)
|
||||
if fence_value is None:
|
||||
return
|
||||
|
||||
try:
|
||||
initial_count = int(cast(int, fence_value))
|
||||
except ValueError:
|
||||
task_logger.error("The value is not an integer.")
|
||||
return
|
||||
|
||||
count = cast(int, r.scard(rcd.taskset_key))
|
||||
task_logger.info(
|
||||
f"Connector deletion progress: cc_pair={cc_pair_id} remaining={count} initial={initial_count}"
|
||||
)
|
||||
if count > 0:
|
||||
return
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
|
||||
if not cc_pair:
|
||||
task_logger.warning(
|
||||
f"Connector deletion - cc_pair not found: cc_pair={cc_pair_id}"
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
doc_ids = get_document_ids_for_connector_credential_pair(
|
||||
db_session, cc_pair.connector_id, cc_pair.credential_id
|
||||
)
|
||||
if len(doc_ids) > 0:
|
||||
# if this happens, documents somehow got added while deletion was in progress. Likely a bug
|
||||
# gating off pruning and indexing work before deletion starts
|
||||
task_logger.warning(
|
||||
f"Connector deletion - documents still found after taskset completion: "
|
||||
f"cc_pair={cc_pair_id} num={len(doc_ids)}"
|
||||
)
|
||||
|
||||
# clean up the rest of the related Postgres entities
|
||||
# index attempts
|
||||
delete_index_attempts(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
|
||||
# document sets
|
||||
delete_document_set_cc_pair_relationship__no_commit(
|
||||
db_session=db_session,
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
)
|
||||
|
||||
# user groups
|
||||
cleanup_user_groups = fetch_versioned_implementation_with_fallback(
|
||||
"danswer.db.user_group",
|
||||
"delete_user_group_cc_pair_relationship__no_commit",
|
||||
noop_fallback,
|
||||
)
|
||||
cleanup_user_groups(
|
||||
cc_pair_id=cc_pair_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# finally, delete the cc-pair
|
||||
delete_connector_credential_pair__no_commit(
|
||||
db_session=db_session,
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
)
|
||||
# if there are no credentials left, delete the connector
|
||||
connector = fetch_connector_by_id(
|
||||
db_session=db_session,
|
||||
connector_id=cc_pair.connector_id,
|
||||
)
|
||||
if not connector or not len(connector.credentials):
|
||||
task_logger.info(
|
||||
"Found no credentials left for connector, deleting connector"
|
||||
)
|
||||
db_session.delete(connector)
|
||||
db_session.commit()
|
||||
except Exception as e:
|
||||
db_session.rollback()
|
||||
stack_trace = traceback.format_exc()
|
||||
error_message = f"Error: {str(e)}\n\nStack Trace:\n{stack_trace}"
|
||||
add_deletion_failure_message(db_session, cc_pair_id, error_message)
|
||||
task_logger.exception(
|
||||
f"Failed to run connector_deletion. "
|
||||
f"cc_pair={cc_pair_id} connector={cc_pair.connector_id} credential={cc_pair.credential_id}"
|
||||
)
|
||||
raise e
|
||||
|
||||
task_logger.info(
|
||||
f"Successfully deleted cc_pair: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"connector={cc_pair.connector_id} "
|
||||
f"credential={cc_pair.credential_id} "
|
||||
f"docs_deleted={initial_count}"
|
||||
)
|
||||
|
||||
r.delete(rcd.taskset_key)
|
||||
r.delete(rcd.fence_key)
|
||||
|
||||
|
||||
def monitor_ccpair_pruning_taskset(
|
||||
key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
cc_pair_id_str = RedisConnectorPruning.get_id_from_fence_key(fence_key)
|
||||
if cc_pair_id_str is None:
|
||||
task_logger.warning(
|
||||
f"monitor_ccpair_pruning_taskset: could not parse cc_pair_id from {fence_key}"
|
||||
)
|
||||
return
|
||||
|
||||
cc_pair_id = int(cc_pair_id_str)
|
||||
|
||||
rcp = RedisConnectorPruning(cc_pair_id)
|
||||
|
||||
fence_value = r.get(rcp.fence_key)
|
||||
if fence_value is None:
|
||||
return
|
||||
|
||||
generator_value = r.get(rcp.generator_complete_key)
|
||||
if generator_value is None:
|
||||
return
|
||||
|
||||
try:
|
||||
initial_count = int(cast(int, generator_value))
|
||||
except ValueError:
|
||||
task_logger.error("The value is not an integer.")
|
||||
return
|
||||
|
||||
count = cast(int, r.scard(rcp.taskset_key))
|
||||
task_logger.info(
|
||||
f"Connector pruning progress: cc_pair_id={cc_pair_id} remaining={count} initial={initial_count}"
|
||||
)
|
||||
if count > 0:
|
||||
return
|
||||
|
||||
mark_ccpair_as_pruned(int(cc_pair_id), db_session)
|
||||
task_logger.info(
|
||||
f"Successfully pruned connector credential pair. cc_pair_id={cc_pair_id}"
|
||||
)
|
||||
|
||||
r.delete(rcp.taskset_key)
|
||||
r.delete(rcp.generator_progress_key)
|
||||
r.delete(rcp.generator_complete_key)
|
||||
r.delete(rcp.fence_key)
|
||||
|
||||
|
||||
def monitor_ccpair_indexing_taskset(
|
||||
key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
# if the fence doesn't exist, there's nothing to do
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
composite_id = RedisConnectorIndexing.get_id_from_fence_key(fence_key)
|
||||
if composite_id is None:
|
||||
task_logger.warning(
|
||||
f"monitor_ccpair_indexing_taskset: could not parse composite_id from {fence_key}"
|
||||
)
|
||||
return
|
||||
|
||||
# parse out metadata and initialize the helper class with it
|
||||
parts = composite_id.split("/")
|
||||
if len(parts) != 2:
|
||||
return
|
||||
|
||||
cc_pair_id = int(parts[0])
|
||||
search_settings_id = int(parts[1])
|
||||
|
||||
rci = RedisConnectorIndexing(cc_pair_id, search_settings_id)
|
||||
|
||||
# read related data and evaluate/print task progress
|
||||
fence_value = cast(bytes, r.get(rci.fence_key))
|
||||
if fence_value is None:
|
||||
return
|
||||
|
||||
try:
|
||||
fence_json = fence_value.decode("utf-8")
|
||||
fence_data = RedisConnectorIndexingFenceData.model_validate_json(
|
||||
cast(str, fence_json)
|
||||
)
|
||||
except ValueError:
|
||||
task_logger.exception(
|
||||
"monitor_ccpair_indexing_taskset: fence_data not decodeable."
|
||||
)
|
||||
raise
|
||||
|
||||
elapsed_submitted = datetime.now(timezone.utc) - fence_data.submitted
|
||||
|
||||
generator_progress_value = r.get(rci.generator_progress_key)
|
||||
if generator_progress_value is not None:
|
||||
try:
|
||||
progress_count = int(cast(int, generator_progress_value))
|
||||
|
||||
task_logger.info(
|
||||
f"Connector indexing progress: cc_pair_id={cc_pair_id} "
|
||||
f"search_settings_id={search_settings_id} "
|
||||
f"progress={progress_count} "
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
|
||||
)
|
||||
except ValueError:
|
||||
task_logger.error(
|
||||
"monitor_ccpair_indexing_taskset: generator_progress_value is not an integer."
|
||||
)
|
||||
|
||||
# Read result state BEFORE generator_complete_key to avoid a race condition
|
||||
result: AsyncResult = AsyncResult(fence_data.celery_task_id)
|
||||
result_state = result.state
|
||||
|
||||
generator_complete_value = r.get(rci.generator_complete_key)
|
||||
if generator_complete_value is None:
|
||||
if result_state in READY_STATES:
|
||||
# IF the task state is READY, THEN generator_complete should be set
|
||||
# if it isn't, then the worker crashed
|
||||
task_logger.info(
|
||||
f"Connector indexing aborted: "
|
||||
f"cc_pair_id={cc_pair_id} "
|
||||
f"search_settings_id={search_settings_id} "
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
|
||||
)
|
||||
|
||||
index_attempt = get_index_attempt(db_session, fence_data.index_attempt_id)
|
||||
if index_attempt:
|
||||
mark_attempt_failed(
|
||||
index_attempt=index_attempt,
|
||||
db_session=db_session,
|
||||
failure_reason="Connector indexing aborted or exceptioned.",
|
||||
)
|
||||
|
||||
r.delete(rci.generator_lock_key)
|
||||
r.delete(rci.taskset_key)
|
||||
r.delete(rci.generator_progress_key)
|
||||
r.delete(rci.generator_complete_key)
|
||||
r.delete(rci.fence_key)
|
||||
return
|
||||
|
||||
status_enum = HTTPStatus.INTERNAL_SERVER_ERROR
|
||||
try:
|
||||
status_value = int(cast(int, generator_complete_value))
|
||||
status_enum = HTTPStatus(status_value)
|
||||
except ValueError:
|
||||
task_logger.error(
|
||||
f"monitor_ccpair_indexing_taskset: "
|
||||
f"generator_complete_value=f{generator_complete_value} could not be parsed."
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"Connector indexing finished: cc_pair_id={cc_pair_id} "
|
||||
f"search_settings_id={search_settings_id} "
|
||||
f"status={status_enum.name} "
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
|
||||
)
|
||||
|
||||
r.delete(rci.generator_lock_key)
|
||||
r.delete(rci.taskset_key)
|
||||
r.delete(rci.generator_progress_key)
|
||||
r.delete(rci.generator_complete_key)
|
||||
r.delete(rci.fence_key)
|
||||
|
||||
|
||||
@shared_task(name="monitor_vespa_sync", soft_time_limit=300, bind=True)
|
||||
def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
||||
"""This is a celery beat task that monitors and finalizes metadata sync tasksets.
|
||||
It scans for fence values and then gets the counts of any associated tasksets.
|
||||
If the count is 0, that means all tasks finished and we should clean up.
|
||||
|
||||
This task lock timeout is CELERY_METADATA_SYNC_BEAT_LOCK_TIMEOUT seconds, so don't
|
||||
do anything too expensive in this function!
|
||||
|
||||
Returns True if the task actually did work, False
|
||||
"""
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat: redis.lock.Lock = r.lock(
|
||||
DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
try:
|
||||
# prevent overlapping tasks
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return False
|
||||
|
||||
# print current queue lengths
|
||||
r_celery = self.app.broker_connection().channel().client # type: ignore
|
||||
n_celery = celery_get_queue_length("celery", r)
|
||||
n_indexing = celery_get_queue_length(
|
||||
DanswerCeleryQueues.CONNECTOR_INDEXING, r_celery
|
||||
)
|
||||
n_sync = celery_get_queue_length(
|
||||
DanswerCeleryQueues.VESPA_METADATA_SYNC, r_celery
|
||||
)
|
||||
n_deletion = celery_get_queue_length(
|
||||
DanswerCeleryQueues.CONNECTOR_DELETION, r_celery
|
||||
)
|
||||
n_pruning = celery_get_queue_length(
|
||||
DanswerCeleryQueues.CONNECTOR_PRUNING, r_celery
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"Queue lengths: celery={n_celery} "
|
||||
f"indexing={n_indexing} "
|
||||
f"sync={n_sync} "
|
||||
f"deletion={n_deletion} "
|
||||
f"pruning={n_pruning}"
|
||||
)
|
||||
|
||||
lock_beat.reacquire()
|
||||
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
|
||||
monitor_connector_taskset(r)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
|
||||
monitor_connector_deletion_taskset(key_bytes, r, tenant_id)
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
|
||||
monitor_document_set_taskset(key_bytes, r, db_session)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
|
||||
monitor_usergroup_taskset = (
|
||||
fetch_versioned_implementation_with_fallback(
|
||||
"danswer.background.celery.tasks.vespa.tasks",
|
||||
"monitor_usergroup_taskset",
|
||||
noop_fallback,
|
||||
)
|
||||
)
|
||||
monitor_usergroup_taskset(key_bytes, r, db_session)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
|
||||
monitor_ccpair_pruning_taskset(key_bytes, r, db_session)
|
||||
|
||||
# do some cleanup before clearing fences
|
||||
# check the db for any outstanding index attempts
|
||||
attempts: list[IndexAttempt] = []
|
||||
attempts.extend(
|
||||
get_all_index_attempts_by_status(IndexingStatus.NOT_STARTED, db_session)
|
||||
)
|
||||
attempts.extend(
|
||||
get_all_index_attempts_by_status(IndexingStatus.IN_PROGRESS, db_session)
|
||||
)
|
||||
|
||||
for a in attempts:
|
||||
# if attempts exist in the db but we don't detect them in redis, mark them as failed
|
||||
rci = RedisConnectorIndexing(
|
||||
a.connector_credential_pair_id, a.search_settings_id
|
||||
)
|
||||
failure_reason = f"Unknown index attempt {a.id}. Might be left over from a process restart."
|
||||
if not r.exists(rci.fence_key):
|
||||
mark_attempt_failed(a, db_session, failure_reason=failure_reason)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(RedisConnectorIndexing.FENCE_PREFIX + "*"):
|
||||
monitor_ccpair_indexing_taskset(key_bytes, r, db_session)
|
||||
|
||||
# uncomment for debugging if needed
|
||||
# r_celery = celery_app.broker_connection().channel().client
|
||||
# length = celery_get_queue_length(DanswerCeleryQueues.VESPA_METADATA_SYNC, r_celery)
|
||||
# task_logger.warning(f"queue={DanswerCeleryQueues.VESPA_METADATA_SYNC} length={length}")
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
)
|
||||
finally:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
|
||||
return True
|
||||
|
||||
|
||||
@shared_task(
|
||||
name="vespa_metadata_sync_task",
|
||||
bind=True,
|
||||
soft_time_limit=45,
|
||||
time_limit=60,
|
||||
max_retries=3,
|
||||
)
|
||||
def vespa_metadata_sync_task(
|
||||
self: Task, document_id: str, tenant_id: str | None
|
||||
) -> bool:
|
||||
task_logger.info(f"document_id={document_id}")
|
||||
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
|
||||
document_index = get_default_document_index(
|
||||
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
|
||||
)
|
||||
|
||||
doc = get_document(document_id, db_session)
|
||||
if not doc:
|
||||
return False
|
||||
|
||||
# document set sync
|
||||
doc_sets = fetch_document_sets_for_document(document_id, db_session)
|
||||
update_doc_sets: set[str] = set(doc_sets)
|
||||
|
||||
# User group sync
|
||||
doc_access = get_access_for_document(
|
||||
document_id=document_id, db_session=db_session
|
||||
)
|
||||
|
||||
fields = VespaDocumentFields(
|
||||
document_sets=update_doc_sets,
|
||||
access=doc_access,
|
||||
boost=doc.boost,
|
||||
hidden=doc.hidden,
|
||||
)
|
||||
|
||||
# update Vespa. OK if doc doesn't exist. Raises exception otherwise.
|
||||
chunks_affected = document_index.update_single(document_id, fields=fields)
|
||||
|
||||
# update db last. Worst case = we crash right before this and
|
||||
# the sync might repeat again later
|
||||
mark_document_as_synced(document_id, db_session)
|
||||
|
||||
task_logger.info(
|
||||
f"document_id={document_id} action=sync chunks={chunks_affected}"
|
||||
)
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(f"SoftTimeLimitExceeded exception. doc_id={document_id}")
|
||||
except Exception as e:
|
||||
task_logger.exception("Unexpected exception")
|
||||
|
||||
# Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
|
||||
countdown = 2 ** (self.request.retries + 4)
|
||||
self.retry(exc=e, countdown=countdown)
|
||||
|
||||
return True
|
||||
@@ -1,195 +0,0 @@
|
||||
"""
|
||||
To delete a connector / credential pair:
|
||||
(1) find all documents associated with connector / credential pair where there
|
||||
this the is only connector / credential pair that has indexed it
|
||||
(2) delete all documents from document stores
|
||||
(3) delete all entries from postgres
|
||||
(4) find all documents associated with connector / credential pair where there
|
||||
are multiple connector / credential pairs that have indexed it
|
||||
(5) update document store entries to remove access associated with the
|
||||
connector / credential pair from the access list
|
||||
(6) delete all relevant entries from postgres
|
||||
"""
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.access.access import get_access_for_documents
|
||||
from danswer.db.connector import fetch_connector_by_id
|
||||
from danswer.db.connector_credential_pair import (
|
||||
delete_connector_credential_pair__no_commit,
|
||||
)
|
||||
from danswer.db.document import delete_document_by_connector_credential_pair__no_commit
|
||||
from danswer.db.document import delete_documents_complete__no_commit
|
||||
from danswer.db.document import get_document_connector_cnts
|
||||
from danswer.db.document import get_documents_for_connector_credential_pair
|
||||
from danswer.db.document import prepare_to_modify_documents
|
||||
from danswer.db.document_set import delete_document_set_cc_pair_relationship__no_commit
|
||||
from danswer.db.document_set import fetch_document_sets_for_documents
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.index_attempt import delete_index_attempts
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.document_index.interfaces import DocumentIndex
|
||||
from danswer.document_index.interfaces import UpdateRequest
|
||||
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import (
|
||||
fetch_versioned_implementation_with_fallback,
|
||||
)
|
||||
from danswer.utils.variable_functionality import noop_fallback
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_DELETION_BATCH_SIZE = 1000
|
||||
|
||||
|
||||
def delete_connector_credential_pair_batch(
|
||||
document_ids: list[str],
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
document_index: DocumentIndex,
|
||||
) -> None:
|
||||
"""
|
||||
Removes a batch of documents ids from a cc-pair. If no other cc-pair uses a document anymore
|
||||
it gets permanently deleted.
|
||||
"""
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
# acquire lock for all documents in this batch so that indexing can't
|
||||
# override the deletion
|
||||
with prepare_to_modify_documents(
|
||||
db_session=db_session, document_ids=document_ids
|
||||
):
|
||||
document_connector_cnts = get_document_connector_cnts(
|
||||
db_session=db_session, document_ids=document_ids
|
||||
)
|
||||
|
||||
# figure out which docs need to be completely deleted
|
||||
document_ids_to_delete = [
|
||||
document_id for document_id, cnt in document_connector_cnts if cnt == 1
|
||||
]
|
||||
logger.debug(f"Deleting documents: {document_ids_to_delete}")
|
||||
|
||||
document_index.delete(doc_ids=document_ids_to_delete)
|
||||
|
||||
delete_documents_complete__no_commit(
|
||||
db_session=db_session,
|
||||
document_ids=document_ids_to_delete,
|
||||
)
|
||||
|
||||
# figure out which docs need to be updated
|
||||
document_ids_to_update = [
|
||||
document_id for document_id, cnt in document_connector_cnts if cnt > 1
|
||||
]
|
||||
|
||||
# maps document id to list of document set names
|
||||
new_doc_sets_for_documents: dict[str, set[str]] = {
|
||||
document_id_and_document_set_names_tuple[0]: set(
|
||||
document_id_and_document_set_names_tuple[1]
|
||||
)
|
||||
for document_id_and_document_set_names_tuple in fetch_document_sets_for_documents(
|
||||
db_session=db_session,
|
||||
document_ids=document_ids_to_update,
|
||||
)
|
||||
}
|
||||
|
||||
# determine future ACLs for documents in batch
|
||||
access_for_documents = get_access_for_documents(
|
||||
document_ids=document_ids_to_update,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# update Vespa
|
||||
logger.debug(f"Updating documents: {document_ids_to_update}")
|
||||
update_requests = [
|
||||
UpdateRequest(
|
||||
document_ids=[document_id],
|
||||
access=access,
|
||||
document_sets=new_doc_sets_for_documents[document_id],
|
||||
)
|
||||
for document_id, access in access_for_documents.items()
|
||||
]
|
||||
document_index.update(update_requests=update_requests)
|
||||
|
||||
# clean up Postgres
|
||||
delete_document_by_connector_credential_pair__no_commit(
|
||||
db_session=db_session,
|
||||
document_ids=document_ids_to_update,
|
||||
connector_credential_pair_identifier=ConnectorCredentialPairIdentifier(
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
),
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def delete_connector_credential_pair(
|
||||
db_session: Session,
|
||||
document_index: DocumentIndex,
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> int:
|
||||
connector_id = cc_pair.connector_id
|
||||
credential_id = cc_pair.credential_id
|
||||
|
||||
num_docs_deleted = 0
|
||||
while True:
|
||||
documents = get_documents_for_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
limit=_DELETION_BATCH_SIZE,
|
||||
)
|
||||
if not documents:
|
||||
break
|
||||
|
||||
delete_connector_credential_pair_batch(
|
||||
document_ids=[document.id for document in documents],
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
document_index=document_index,
|
||||
)
|
||||
num_docs_deleted += len(documents)
|
||||
|
||||
# clean up the rest of the related Postgres entities
|
||||
# index attempts
|
||||
delete_index_attempts(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair.id,
|
||||
)
|
||||
|
||||
# document sets
|
||||
delete_document_set_cc_pair_relationship__no_commit(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
|
||||
# user groups
|
||||
cleanup_user_groups = fetch_versioned_implementation_with_fallback(
|
||||
"danswer.db.user_group",
|
||||
"delete_user_group_cc_pair_relationship__no_commit",
|
||||
noop_fallback,
|
||||
)
|
||||
cleanup_user_groups(
|
||||
cc_pair_id=cc_pair.id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# finally, delete the cc-pair
|
||||
delete_connector_credential_pair__no_commit(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
# if there are no credentials left, delete the connector
|
||||
connector = fetch_connector_by_id(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
)
|
||||
if not connector or not len(connector.credentials):
|
||||
logger.info("Found no credentials left for connector, deleting connector")
|
||||
db_session.delete(connector)
|
||||
db_session.commit()
|
||||
|
||||
logger.notice(
|
||||
"Successfully deleted connector_credential_pair with connector_id:"
|
||||
f" '{connector_id}' and credential_id: '{credential_id}'. Deleted {num_docs_deleted} docs."
|
||||
)
|
||||
return num_docs_deleted
|
||||
@@ -1,5 +1,6 @@
|
||||
import time
|
||||
import traceback
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
@@ -14,21 +15,22 @@ from danswer.configs.app_configs import POLL_CONNECTOR_OFFSET
|
||||
from danswer.connectors.connector_runner import ConnectorRunner
|
||||
from danswer.connectors.factory import instantiate_connector
|
||||
from danswer.connectors.models import IndexAttemptMetadata
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.connector_credential_pair import get_last_successful_attempt_time
|
||||
from danswer.db.connector_credential_pair import update_connector_credential_pair
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.index_attempt import get_index_attempt
|
||||
from danswer.db.index_attempt import mark_attempt_failed
|
||||
from danswer.db.index_attempt import mark_attempt_in_progress
|
||||
from danswer.db.index_attempt import mark_attempt_partially_succeeded
|
||||
from danswer.db.index_attempt import mark_attempt_succeeded
|
||||
from danswer.db.index_attempt import transition_attempt_to_in_progress
|
||||
from danswer.db.index_attempt import update_docs_indexed
|
||||
from danswer.db.models import IndexAttempt
|
||||
from danswer.db.models import IndexingStatus
|
||||
from danswer.db.models import IndexModelStatus
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.indexing.embedder import DefaultIndexingEmbedder
|
||||
from danswer.indexing.indexing_heartbeat import IndexingHeartbeat
|
||||
from danswer.indexing.indexing_pipeline import build_indexing_pipeline
|
||||
from danswer.utils.logger import IndexAttemptSingleton
|
||||
from danswer.utils.logger import setup_logger
|
||||
@@ -44,11 +46,12 @@ def _get_connector_runner(
|
||||
attempt: IndexAttempt,
|
||||
start_time: datetime,
|
||||
end_time: datetime,
|
||||
tenant_id: str | None,
|
||||
) -> ConnectorRunner:
|
||||
"""
|
||||
NOTE: `start_time` and `end_time` are only used for poll connectors
|
||||
|
||||
Returns an interator of document batches and whether the returned documents
|
||||
Returns an iterator of document batches and whether the returned documents
|
||||
are the complete list of existing documents of the connector. If the task
|
||||
of type LOAD_STATE, the list will be considered complete and otherwise incomplete.
|
||||
"""
|
||||
@@ -56,22 +59,28 @@ def _get_connector_runner(
|
||||
|
||||
try:
|
||||
runnable_connector = instantiate_connector(
|
||||
attempt.connector_credential_pair.connector.source,
|
||||
task,
|
||||
attempt.connector_credential_pair.connector.connector_specific_config,
|
||||
attempt.connector_credential_pair.credential,
|
||||
db_session,
|
||||
db_session=db_session,
|
||||
source=attempt.connector_credential_pair.connector.source,
|
||||
input_type=task,
|
||||
connector_specific_config=attempt.connector_credential_pair.connector.connector_specific_config,
|
||||
credential=attempt.connector_credential_pair.credential,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Unable to instantiate connector due to {e}")
|
||||
# since we failed to even instantiate the connector, we pause the CCPair since
|
||||
# it will never succeed
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=attempt.connector_credential_pair.connector.id,
|
||||
credential_id=attempt.connector_credential_pair.credential.id,
|
||||
status=ConnectorCredentialPairStatus.PAUSED,
|
||||
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
attempt.connector_credential_pair.id, db_session
|
||||
)
|
||||
if cc_pair and cc_pair.status == ConnectorCredentialPairStatus.ACTIVE:
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=attempt.connector_credential_pair.connector.id,
|
||||
credential_id=attempt.connector_credential_pair.credential.id,
|
||||
status=ConnectorCredentialPairStatus.PAUSED,
|
||||
)
|
||||
raise e
|
||||
|
||||
return ConnectorRunner(
|
||||
@@ -82,11 +91,16 @@ def _get_connector_runner(
|
||||
def _run_indexing(
|
||||
db_session: Session,
|
||||
index_attempt: IndexAttempt,
|
||||
tenant_id: str | None,
|
||||
progress_callback: Callable[[int], None] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
1. Get documents which are either new or updated from specified application
|
||||
2. Embed and index these documents into the chosen datastore (vespa)
|
||||
3. Updates Postgres to record the indexed documents + the outcome of this run
|
||||
|
||||
TODO: do not change index attempt statuses here ... instead, set signals in redis
|
||||
and allow the monitor function to clean them up
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
@@ -103,16 +117,26 @@ def _run_indexing(
|
||||
)
|
||||
|
||||
embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
|
||||
search_settings=search_settings
|
||||
search_settings=search_settings,
|
||||
heartbeat=IndexingHeartbeat(
|
||||
index_attempt_id=index_attempt.id,
|
||||
db_session=db_session,
|
||||
# let the world know we're still making progress after
|
||||
# every 10 batches
|
||||
freq=10,
|
||||
),
|
||||
)
|
||||
|
||||
indexing_pipeline = build_indexing_pipeline(
|
||||
attempt_id=index_attempt.id,
|
||||
embedder=embedding_model,
|
||||
document_index=document_index,
|
||||
ignore_time_skip=index_attempt.from_beginning
|
||||
or (search_settings.status == IndexModelStatus.FUTURE),
|
||||
ignore_time_skip=(
|
||||
index_attempt.from_beginning
|
||||
or (search_settings.status == IndexModelStatus.FUTURE)
|
||||
),
|
||||
db_session=db_session,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
db_cc_pair = index_attempt.connector_credential_pair
|
||||
@@ -169,6 +193,7 @@ def _run_indexing(
|
||||
attempt=index_attempt,
|
||||
start_time=window_start,
|
||||
end_time=window_end,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
all_connector_doc_ids: set[str] = set()
|
||||
@@ -181,7 +206,7 @@ def _run_indexing(
|
||||
# index being built. We want to populate it even for paused connectors
|
||||
# Often paused connectors are sources that aren't updated frequently but the
|
||||
# contents still need to be initially pulled.
|
||||
db_session.refresh(db_connector)
|
||||
db_session.refresh(db_cc_pair)
|
||||
if (
|
||||
(
|
||||
db_cc_pair.status == ConnectorCredentialPairStatus.PAUSED
|
||||
@@ -196,7 +221,9 @@ def _run_indexing(
|
||||
db_session.refresh(index_attempt)
|
||||
if index_attempt.status != IndexingStatus.IN_PROGRESS:
|
||||
# Likely due to user manually disabling it or model swap
|
||||
raise RuntimeError("Index Attempt was canceled")
|
||||
raise RuntimeError(
|
||||
f"Index Attempt was canceled, status is {index_attempt.status}"
|
||||
)
|
||||
|
||||
batch_description = []
|
||||
for doc in doc_batch:
|
||||
@@ -216,6 +243,8 @@ def _run_indexing(
|
||||
logger.debug(f"Indexing batch of documents: {batch_description}")
|
||||
|
||||
index_attempt_md.batch_num = batch_num + 1 # use 1-index for this
|
||||
|
||||
# real work happens here!
|
||||
new_docs, total_batch_chunks = indexing_pipeline(
|
||||
document_batch=doc_batch,
|
||||
index_attempt_metadata=index_attempt_md,
|
||||
@@ -234,6 +263,9 @@ def _run_indexing(
|
||||
# be inaccurate
|
||||
db_session.commit()
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(len(doc_batch))
|
||||
|
||||
# This new value is updated every batch, so UI can refresh per batch update
|
||||
update_docs_indexed(
|
||||
db_session=db_session,
|
||||
@@ -357,40 +389,13 @@ def _run_indexing(
|
||||
)
|
||||
|
||||
|
||||
def _prepare_index_attempt(db_session: Session, index_attempt_id: int) -> IndexAttempt:
|
||||
# make sure that the index attempt can't change in between checking the
|
||||
# status and marking it as in_progress. This setting will be discarded
|
||||
# after the next commit:
|
||||
# https://docs.sqlalchemy.org/en/20/orm/session_transaction.html#setting-isolation-for-individual-transactions
|
||||
db_session.connection(execution_options={"isolation_level": "SERIALIZABLE"}) # type: ignore
|
||||
|
||||
attempt = get_index_attempt(
|
||||
db_session=db_session,
|
||||
index_attempt_id=index_attempt_id,
|
||||
)
|
||||
|
||||
if attempt is None:
|
||||
raise RuntimeError(f"Unable to find IndexAttempt for ID '{index_attempt_id}'")
|
||||
|
||||
if attempt.status != IndexingStatus.NOT_STARTED:
|
||||
raise RuntimeError(
|
||||
f"Indexing attempt with ID '{index_attempt_id}' is not in NOT_STARTED status. "
|
||||
f"Current status is '{attempt.status}'."
|
||||
)
|
||||
|
||||
# only commit once, to make sure this all happens in a single transaction
|
||||
mark_attempt_in_progress(attempt, db_session)
|
||||
|
||||
return attempt
|
||||
|
||||
|
||||
def run_indexing_entrypoint(
|
||||
index_attempt_id: int, connector_credential_pair_id: int, is_ee: bool = False
|
||||
index_attempt_id: int,
|
||||
tenant_id: str | None,
|
||||
connector_credential_pair_id: int,
|
||||
is_ee: bool = False,
|
||||
progress_callback: Callable[[int], None] | None = None,
|
||||
) -> None:
|
||||
"""Entrypoint for indexing run when using dask distributed.
|
||||
Wraps the actual logic in a `try` block so that we can catch any exceptions
|
||||
and mark the attempt as failed."""
|
||||
|
||||
try:
|
||||
if is_ee:
|
||||
global_version.set_ee()
|
||||
@@ -400,26 +405,29 @@ def run_indexing_entrypoint(
|
||||
IndexAttemptSingleton.set_cc_and_index_id(
|
||||
index_attempt_id, connector_credential_pair_id
|
||||
)
|
||||
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
# make sure that it is valid to run this indexing attempt + mark it
|
||||
# as in progress
|
||||
attempt = _prepare_index_attempt(db_session, index_attempt_id)
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
attempt = transition_attempt_to_in_progress(index_attempt_id, db_session)
|
||||
|
||||
logger.info(
|
||||
f"Indexing starting: "
|
||||
f"connector='{attempt.connector_credential_pair.connector.name}' "
|
||||
f"Indexing starting for tenant {tenant_id}: "
|
||||
if tenant_id is not None
|
||||
else ""
|
||||
+ f"connector='{attempt.connector_credential_pair.connector.name}' "
|
||||
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
|
||||
f"credentials='{attempt.connector_credential_pair.connector_id}'"
|
||||
)
|
||||
|
||||
_run_indexing(db_session, attempt)
|
||||
_run_indexing(db_session, attempt, tenant_id, progress_callback)
|
||||
|
||||
logger.info(
|
||||
f"Indexing finished: "
|
||||
f"connector='{attempt.connector_credential_pair.connector.name}' "
|
||||
f"Indexing finished for tenant {tenant_id}: "
|
||||
if tenant_id is not None
|
||||
else ""
|
||||
+ f"connector='{attempt.connector_credential_pair.connector.name}' "
|
||||
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
|
||||
f"credentials='{attempt.connector_credential_pair.connector_id}'"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Indexing job with ID '{index_attempt_id}' failed due to {e}")
|
||||
logger.exception(
|
||||
f"Indexing job with ID '{index_attempt_id}' for tenant {tenant_id} failed due to {e}"
|
||||
)
|
||||
|
||||
@@ -14,14 +14,6 @@ from danswer.db.tasks import mark_task_start
|
||||
from danswer.db.tasks import register_task
|
||||
|
||||
|
||||
def name_cc_cleanup_task(connector_id: int, credential_id: int) -> str:
|
||||
return f"cleanup_connector_credential_pair_{connector_id}_{credential_id}"
|
||||
|
||||
|
||||
def name_document_set_sync_task(document_set_id: int) -> str:
|
||||
return f"sync_doc_set_{document_set_id}"
|
||||
|
||||
|
||||
def name_cc_prune_task(
|
||||
connector_id: int | None = None, credential_id: int | None = None
|
||||
) -> str:
|
||||
|
||||
@@ -1,475 +1,494 @@
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
import dask
|
||||
from dask.distributed import Client
|
||||
from dask.distributed import Future
|
||||
from distributed import LocalCluster
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.indexing.dask_utils import ResourceLogger
|
||||
from danswer.background.indexing.job_client import SimpleJob
|
||||
from danswer.background.indexing.job_client import SimpleJobClient
|
||||
from danswer.background.indexing.run_indexing import run_indexing_entrypoint
|
||||
from danswer.configs.app_configs import CLEANUP_INDEXING_JOBS_TIMEOUT
|
||||
from danswer.configs.app_configs import DASK_JOB_CLIENT_ENABLED
|
||||
from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
|
||||
from danswer.configs.app_configs import NUM_INDEXING_WORKERS
|
||||
from danswer.configs.app_configs import NUM_SECONDARY_INDEXING_WORKERS
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import POSTGRES_INDEXER_APP_NAME
|
||||
from danswer.db.connector import fetch_connectors
|
||||
from danswer.db.connector_credential_pair import fetch_connector_credential_pairs
|
||||
from danswer.db.engine import get_db_current_time
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.engine import init_sqlalchemy_engine
|
||||
from danswer.db.index_attempt import create_index_attempt
|
||||
from danswer.db.index_attempt import get_index_attempt
|
||||
from danswer.db.index_attempt import get_inprogress_index_attempts
|
||||
from danswer.db.index_attempt import get_last_attempt_for_cc_pair
|
||||
from danswer.db.index_attempt import get_not_started_index_attempts
|
||||
from danswer.db.index_attempt import mark_attempt_failed
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.models import IndexAttempt
|
||||
from danswer.db.models import IndexingStatus
|
||||
from danswer.db.models import IndexModelStatus
|
||||
from danswer.db.models import SearchSettings
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.db.search_settings import get_secondary_search_settings
|
||||
from danswer.db.swap_index import check_index_swap
|
||||
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import global_version
|
||||
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
|
||||
from shared_configs.configs import LOG_LEVEL
|
||||
from shared_configs.configs import MODEL_SERVER_PORT
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# If the indexing dies, it's most likely due to resource constraints,
|
||||
# restarting just delays the eventual failure, not useful to the user
|
||||
dask.config.set({"distributed.scheduler.allowed-failures": 0})
|
||||
|
||||
_UNEXPECTED_STATE_FAILURE_REASON = (
|
||||
"Stopped mid run, likely due to the background process being killed"
|
||||
)
|
||||
|
||||
|
||||
def _should_create_new_indexing(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
last_index: IndexAttempt | None,
|
||||
search_settings_instance: SearchSettings,
|
||||
secondary_index_building: bool,
|
||||
db_session: Session,
|
||||
) -> bool:
|
||||
connector = cc_pair.connector
|
||||
|
||||
# don't kick off indexing for `NOT_APPLICABLE` sources
|
||||
if connector.source == DocumentSource.NOT_APPLICABLE:
|
||||
return False
|
||||
|
||||
# User can still manually create single indexing attempts via the UI for the
|
||||
# currently in use index
|
||||
if DISABLE_INDEX_UPDATE_ON_SWAP:
|
||||
if (
|
||||
search_settings_instance.status == IndexModelStatus.PRESENT
|
||||
and secondary_index_building
|
||||
):
|
||||
return False
|
||||
|
||||
# When switching over models, always index at least once
|
||||
if search_settings_instance.status == IndexModelStatus.FUTURE:
|
||||
if last_index:
|
||||
# No new index if the last index attempt succeeded
|
||||
# Once is enough. The model will never be able to swap otherwise.
|
||||
if last_index.status == IndexingStatus.SUCCESS:
|
||||
return False
|
||||
|
||||
# No new index if the last index attempt is waiting to start
|
||||
if last_index.status == IndexingStatus.NOT_STARTED:
|
||||
return False
|
||||
|
||||
# No new index if the last index attempt is running
|
||||
if last_index.status == IndexingStatus.IN_PROGRESS:
|
||||
return False
|
||||
else:
|
||||
if connector.id == 0: # Ingestion API
|
||||
return False
|
||||
return True
|
||||
|
||||
# If the connector is paused or is the ingestion API, don't index
|
||||
# NOTE: during an embedding model switch over, the following logic
|
||||
# is bypassed by the above check for a future model
|
||||
if not cc_pair.status.is_active() or connector.id == 0:
|
||||
return False
|
||||
|
||||
if not last_index:
|
||||
return True
|
||||
|
||||
if connector.refresh_freq is None:
|
||||
return False
|
||||
|
||||
# Only one scheduled/ongoing job per connector at a time
|
||||
# this prevents cases where
|
||||
# (1) the "latest" index_attempt is scheduled so we show
|
||||
# that in the UI despite another index_attempt being in-progress
|
||||
# (2) multiple scheduled index_attempts at a time
|
||||
if (
|
||||
last_index.status == IndexingStatus.NOT_STARTED
|
||||
or last_index.status == IndexingStatus.IN_PROGRESS
|
||||
):
|
||||
return False
|
||||
|
||||
current_db_time = get_db_current_time(db_session)
|
||||
time_since_index = current_db_time - last_index.time_updated
|
||||
return time_since_index.total_seconds() >= connector.refresh_freq
|
||||
|
||||
|
||||
def _mark_run_failed(
|
||||
db_session: Session, index_attempt: IndexAttempt, failure_reason: str
|
||||
) -> None:
|
||||
"""Marks the `index_attempt` row as failed + updates the `
|
||||
connector_credential_pair` to reflect that the run failed"""
|
||||
logger.warning(
|
||||
f"Marking in-progress attempt 'connector: {index_attempt.connector_credential_pair.connector_id}, "
|
||||
f"credential: {index_attempt.connector_credential_pair.credential_id}' as failed due to {failure_reason}"
|
||||
)
|
||||
mark_attempt_failed(
|
||||
index_attempt=index_attempt,
|
||||
db_session=db_session,
|
||||
failure_reason=failure_reason,
|
||||
)
|
||||
|
||||
|
||||
"""Main funcs"""
|
||||
|
||||
|
||||
def create_indexing_jobs(existing_jobs: dict[int, Future | SimpleJob]) -> None:
|
||||
"""Creates new indexing jobs for each connector / credential pair which is:
|
||||
1. Enabled
|
||||
2. `refresh_frequency` time has passed since the last indexing run for this pair
|
||||
3. There is not already an ongoing indexing attempt for this pair
|
||||
"""
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
ongoing: set[tuple[int | None, int]] = set()
|
||||
for attempt_id in existing_jobs:
|
||||
attempt = get_index_attempt(
|
||||
db_session=db_session, index_attempt_id=attempt_id
|
||||
)
|
||||
if attempt is None:
|
||||
logger.error(
|
||||
f"Unable to find IndexAttempt for ID '{attempt_id}' when creating "
|
||||
"indexing jobs"
|
||||
)
|
||||
continue
|
||||
ongoing.add(
|
||||
(
|
||||
attempt.connector_credential_pair_id,
|
||||
attempt.search_settings_id,
|
||||
)
|
||||
)
|
||||
|
||||
# Get the primary search settings
|
||||
primary_search_settings = get_current_search_settings(db_session)
|
||||
search_settings = [primary_search_settings]
|
||||
|
||||
# Check for secondary search settings
|
||||
secondary_search_settings = get_secondary_search_settings(db_session)
|
||||
if secondary_search_settings is not None:
|
||||
# If secondary settings exist, add them to the list
|
||||
search_settings.append(secondary_search_settings)
|
||||
|
||||
all_connector_credential_pairs = fetch_connector_credential_pairs(db_session)
|
||||
for cc_pair in all_connector_credential_pairs:
|
||||
for search_settings_instance in search_settings:
|
||||
# Check if there is an ongoing indexing attempt for this connector credential pair
|
||||
if (cc_pair.id, search_settings_instance.id) in ongoing:
|
||||
continue
|
||||
|
||||
last_attempt = get_last_attempt_for_cc_pair(
|
||||
cc_pair.id, search_settings_instance.id, db_session
|
||||
)
|
||||
if not _should_create_new_indexing(
|
||||
cc_pair=cc_pair,
|
||||
last_index=last_attempt,
|
||||
search_settings_instance=search_settings_instance,
|
||||
secondary_index_building=len(search_settings) > 1,
|
||||
db_session=db_session,
|
||||
):
|
||||
continue
|
||||
|
||||
create_index_attempt(
|
||||
cc_pair.id, search_settings_instance.id, db_session
|
||||
)
|
||||
|
||||
|
||||
def cleanup_indexing_jobs(
|
||||
existing_jobs: dict[int, Future | SimpleJob],
|
||||
timeout_hours: int = CLEANUP_INDEXING_JOBS_TIMEOUT,
|
||||
) -> dict[int, Future | SimpleJob]:
|
||||
existing_jobs_copy = existing_jobs.copy()
|
||||
|
||||
# clean up completed jobs
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
for attempt_id, job in existing_jobs.items():
|
||||
index_attempt = get_index_attempt(
|
||||
db_session=db_session, index_attempt_id=attempt_id
|
||||
)
|
||||
|
||||
# do nothing for ongoing jobs that haven't been stopped
|
||||
if not job.done():
|
||||
if not index_attempt:
|
||||
continue
|
||||
|
||||
if not index_attempt.is_finished():
|
||||
continue
|
||||
|
||||
if job.status == "error":
|
||||
logger.error(job.exception())
|
||||
|
||||
job.release()
|
||||
del existing_jobs_copy[attempt_id]
|
||||
|
||||
if not index_attempt:
|
||||
logger.error(
|
||||
f"Unable to find IndexAttempt for ID '{attempt_id}' when cleaning "
|
||||
"up indexing jobs"
|
||||
)
|
||||
continue
|
||||
|
||||
if (
|
||||
index_attempt.status == IndexingStatus.IN_PROGRESS
|
||||
or job.status == "error"
|
||||
):
|
||||
_mark_run_failed(
|
||||
db_session=db_session,
|
||||
index_attempt=index_attempt,
|
||||
failure_reason=_UNEXPECTED_STATE_FAILURE_REASON,
|
||||
)
|
||||
|
||||
# clean up in-progress jobs that were never completed
|
||||
connectors = fetch_connectors(db_session)
|
||||
for connector in connectors:
|
||||
in_progress_indexing_attempts = get_inprogress_index_attempts(
|
||||
connector.id, db_session
|
||||
)
|
||||
for index_attempt in in_progress_indexing_attempts:
|
||||
if index_attempt.id in existing_jobs:
|
||||
# If index attempt is canceled, stop the run
|
||||
if index_attempt.status == IndexingStatus.FAILED:
|
||||
existing_jobs[index_attempt.id].cancel()
|
||||
# check to see if the job has been updated in last `timeout_hours` hours, if not
|
||||
# assume it to frozen in some bad state and just mark it as failed. Note: this relies
|
||||
# on the fact that the `time_updated` field is constantly updated every
|
||||
# batch of documents indexed
|
||||
current_db_time = get_db_current_time(db_session=db_session)
|
||||
time_since_update = current_db_time - index_attempt.time_updated
|
||||
if time_since_update.total_seconds() > 60 * 60 * timeout_hours:
|
||||
existing_jobs[index_attempt.id].cancel()
|
||||
_mark_run_failed(
|
||||
db_session=db_session,
|
||||
index_attempt=index_attempt,
|
||||
failure_reason="Indexing run frozen - no updates in the last three hours. "
|
||||
"The run will be re-attempted at next scheduled indexing time.",
|
||||
)
|
||||
else:
|
||||
# If job isn't known, simply mark it as failed
|
||||
_mark_run_failed(
|
||||
db_session=db_session,
|
||||
index_attempt=index_attempt,
|
||||
failure_reason=_UNEXPECTED_STATE_FAILURE_REASON,
|
||||
)
|
||||
|
||||
return existing_jobs_copy
|
||||
|
||||
|
||||
def kickoff_indexing_jobs(
|
||||
existing_jobs: dict[int, Future | SimpleJob],
|
||||
client: Client | SimpleJobClient,
|
||||
secondary_client: Client | SimpleJobClient,
|
||||
) -> dict[int, Future | SimpleJob]:
|
||||
existing_jobs_copy = existing_jobs.copy()
|
||||
engine = get_sqlalchemy_engine()
|
||||
|
||||
# Don't include jobs waiting in the Dask queue that just haven't started running
|
||||
# Also (rarely) don't include for jobs that started but haven't updated the indexing tables yet
|
||||
with Session(engine) as db_session:
|
||||
# get_not_started_index_attempts orders its returned results from oldest to newest
|
||||
# we must process attempts in a FIFO manner to prevent connector starvation
|
||||
new_indexing_attempts = [
|
||||
(attempt, attempt.search_settings)
|
||||
for attempt in get_not_started_index_attempts(db_session)
|
||||
if attempt.id not in existing_jobs
|
||||
]
|
||||
|
||||
logger.debug(f"Found {len(new_indexing_attempts)} new indexing task(s).")
|
||||
|
||||
if not new_indexing_attempts:
|
||||
return existing_jobs
|
||||
|
||||
indexing_attempt_count = 0
|
||||
|
||||
for attempt, search_settings in new_indexing_attempts:
|
||||
use_secondary_index = (
|
||||
search_settings.status == IndexModelStatus.FUTURE
|
||||
if search_settings is not None
|
||||
else False
|
||||
)
|
||||
if attempt.connector_credential_pair.connector is None:
|
||||
logger.warning(
|
||||
f"Skipping index attempt as Connector has been deleted: {attempt}"
|
||||
)
|
||||
with Session(engine) as db_session:
|
||||
mark_attempt_failed(
|
||||
attempt, db_session, failure_reason="Connector is null"
|
||||
)
|
||||
continue
|
||||
if attempt.connector_credential_pair.credential is None:
|
||||
logger.warning(
|
||||
f"Skipping index attempt as Credential has been deleted: {attempt}"
|
||||
)
|
||||
with Session(engine) as db_session:
|
||||
mark_attempt_failed(
|
||||
attempt, db_session, failure_reason="Credential is null"
|
||||
)
|
||||
continue
|
||||
|
||||
if use_secondary_index:
|
||||
run = secondary_client.submit(
|
||||
run_indexing_entrypoint,
|
||||
attempt.id,
|
||||
attempt.connector_credential_pair_id,
|
||||
global_version.get_is_ee_version(),
|
||||
pure=False,
|
||||
)
|
||||
else:
|
||||
run = client.submit(
|
||||
run_indexing_entrypoint,
|
||||
attempt.id,
|
||||
attempt.connector_credential_pair_id,
|
||||
global_version.get_is_ee_version(),
|
||||
pure=False,
|
||||
)
|
||||
|
||||
if run:
|
||||
if indexing_attempt_count == 0:
|
||||
logger.info(
|
||||
f"Indexing dispatch starts: pending={len(new_indexing_attempts)}"
|
||||
)
|
||||
|
||||
indexing_attempt_count += 1
|
||||
secondary_str = " (secondary index)" if use_secondary_index else ""
|
||||
logger.info(
|
||||
f"Indexing dispatched{secondary_str}: "
|
||||
f"attempt_id={attempt.id} "
|
||||
f"connector='{attempt.connector_credential_pair.connector.name}' "
|
||||
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
|
||||
f"credentials='{attempt.connector_credential_pair.credential_id}'"
|
||||
)
|
||||
existing_jobs_copy[attempt.id] = run
|
||||
|
||||
if indexing_attempt_count > 0:
|
||||
logger.info(
|
||||
f"Indexing dispatch results: "
|
||||
f"initial_pending={len(new_indexing_attempts)} "
|
||||
f"started={indexing_attempt_count} "
|
||||
f"remaining={len(new_indexing_attempts) - indexing_attempt_count}"
|
||||
)
|
||||
|
||||
return existing_jobs_copy
|
||||
|
||||
|
||||
def update_loop(
|
||||
delay: int = 10,
|
||||
num_workers: int = NUM_INDEXING_WORKERS,
|
||||
num_secondary_workers: int = NUM_SECONDARY_INDEXING_WORKERS,
|
||||
) -> None:
|
||||
engine = get_sqlalchemy_engine()
|
||||
with Session(engine) as db_session:
|
||||
check_index_swap(db_session=db_session)
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
|
||||
# So that the first time users aren't surprised by really slow speed of first
|
||||
# batch of documents indexed
|
||||
|
||||
if search_settings.provider_type is None:
|
||||
logger.notice("Running a first inference to warm up embedding model")
|
||||
embedding_model = EmbeddingModel.from_db_model(
|
||||
search_settings=search_settings,
|
||||
server_host=INDEXING_MODEL_SERVER_HOST,
|
||||
server_port=MODEL_SERVER_PORT,
|
||||
)
|
||||
|
||||
warm_up_bi_encoder(
|
||||
embedding_model=embedding_model,
|
||||
)
|
||||
|
||||
client_primary: Client | SimpleJobClient
|
||||
client_secondary: Client | SimpleJobClient
|
||||
if DASK_JOB_CLIENT_ENABLED:
|
||||
cluster_primary = LocalCluster(
|
||||
n_workers=num_workers,
|
||||
threads_per_worker=1,
|
||||
# there are warning about high memory usage + "Event loop unresponsive"
|
||||
# which are not relevant to us since our workers are expected to use a
|
||||
# lot of memory + involve CPU intensive tasks that will not relinquish
|
||||
# the event loop
|
||||
silence_logs=logging.ERROR,
|
||||
)
|
||||
cluster_secondary = LocalCluster(
|
||||
n_workers=num_secondary_workers,
|
||||
threads_per_worker=1,
|
||||
silence_logs=logging.ERROR,
|
||||
)
|
||||
client_primary = Client(cluster_primary)
|
||||
client_secondary = Client(cluster_secondary)
|
||||
if LOG_LEVEL.lower() == "debug":
|
||||
client_primary.register_worker_plugin(ResourceLogger())
|
||||
else:
|
||||
client_primary = SimpleJobClient(n_workers=num_workers)
|
||||
client_secondary = SimpleJobClient(n_workers=num_secondary_workers)
|
||||
|
||||
existing_jobs: dict[int, Future | SimpleJob] = {}
|
||||
|
||||
while True:
|
||||
start = time.time()
|
||||
start_time_utc = datetime.utcfromtimestamp(start).strftime("%Y-%m-%d %H:%M:%S")
|
||||
logger.debug(f"Running update, current UTC time: {start_time_utc}")
|
||||
|
||||
if existing_jobs:
|
||||
# TODO: make this debug level once the "no jobs are being scheduled" issue is resolved
|
||||
logger.debug(
|
||||
"Found existing indexing jobs: "
|
||||
f"{[(attempt_id, job.status) for attempt_id, job in existing_jobs.items()]}"
|
||||
)
|
||||
|
||||
try:
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
check_index_swap(db_session)
|
||||
existing_jobs = cleanup_indexing_jobs(existing_jobs=existing_jobs)
|
||||
create_indexing_jobs(existing_jobs=existing_jobs)
|
||||
existing_jobs = kickoff_indexing_jobs(
|
||||
existing_jobs=existing_jobs,
|
||||
client=client_primary,
|
||||
secondary_client=client_secondary,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to run update due to {e}")
|
||||
sleep_time = delay - (time.time() - start)
|
||||
if sleep_time > 0:
|
||||
time.sleep(sleep_time)
|
||||
|
||||
|
||||
def update__main() -> None:
|
||||
set_is_ee_based_on_env_variable()
|
||||
init_sqlalchemy_engine(POSTGRES_INDEXER_APP_NAME)
|
||||
|
||||
logger.notice("Starting indexing service")
|
||||
update_loop()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
update__main()
|
||||
# TODO(rkuo): delete after background indexing via celery is fully vetted
|
||||
# import logging
|
||||
# import time
|
||||
# from datetime import datetime
|
||||
# import dask
|
||||
# from dask.distributed import Client
|
||||
# from dask.distributed import Future
|
||||
# from distributed import LocalCluster
|
||||
# from sqlalchemy import text
|
||||
# from sqlalchemy.exc import ProgrammingError
|
||||
# from sqlalchemy.orm import Session
|
||||
# from danswer.background.indexing.dask_utils import ResourceLogger
|
||||
# from danswer.background.indexing.job_client import SimpleJob
|
||||
# from danswer.background.indexing.job_client import SimpleJobClient
|
||||
# from danswer.background.indexing.run_indexing import run_indexing_entrypoint
|
||||
# from danswer.configs.app_configs import CLEANUP_INDEXING_JOBS_TIMEOUT
|
||||
# from danswer.configs.app_configs import DASK_JOB_CLIENT_ENABLED
|
||||
# from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
|
||||
# from danswer.configs.app_configs import MULTI_TENANT
|
||||
# from danswer.configs.app_configs import NUM_INDEXING_WORKERS
|
||||
# from danswer.configs.app_configs import NUM_SECONDARY_INDEXING_WORKERS
|
||||
# from danswer.configs.constants import DocumentSource
|
||||
# from danswer.configs.constants import POSTGRES_INDEXER_APP_NAME
|
||||
# from danswer.configs.constants import TENANT_ID_PREFIX
|
||||
# from danswer.db.connector import fetch_connectors
|
||||
# from danswer.db.connector_credential_pair import fetch_connector_credential_pairs
|
||||
# from danswer.db.engine import get_db_current_time
|
||||
# from danswer.db.engine import get_session_with_tenant
|
||||
# from danswer.db.engine import get_sqlalchemy_engine
|
||||
# from danswer.db.engine import SqlEngine
|
||||
# from danswer.db.index_attempt import create_index_attempt
|
||||
# from danswer.db.index_attempt import get_index_attempt
|
||||
# from danswer.db.index_attempt import get_inprogress_index_attempts
|
||||
# from danswer.db.index_attempt import get_last_attempt_for_cc_pair
|
||||
# from danswer.db.index_attempt import get_not_started_index_attempts
|
||||
# from danswer.db.index_attempt import mark_attempt_failed
|
||||
# from danswer.db.models import ConnectorCredentialPair
|
||||
# from danswer.db.models import IndexAttempt
|
||||
# from danswer.db.models import IndexingStatus
|
||||
# from danswer.db.models import IndexModelStatus
|
||||
# from danswer.db.models import SearchSettings
|
||||
# from danswer.db.search_settings import get_current_search_settings
|
||||
# from danswer.db.search_settings import get_secondary_search_settings
|
||||
# from danswer.db.swap_index import check_index_swap
|
||||
# from danswer.document_index.vespa.index import VespaIndex
|
||||
# from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
# from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
|
||||
# from danswer.utils.logger import setup_logger
|
||||
# from danswer.utils.variable_functionality import global_version
|
||||
# from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
# from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
|
||||
# from shared_configs.configs import INDEXING_MODEL_SERVER_PORT
|
||||
# from shared_configs.configs import LOG_LEVEL
|
||||
# logger = setup_logger()
|
||||
# # If the indexing dies, it's most likely due to resource constraints,
|
||||
# # restarting just delays the eventual failure, not useful to the user
|
||||
# dask.config.set({"distributed.scheduler.allowed-failures": 0})
|
||||
# _UNEXPECTED_STATE_FAILURE_REASON = (
|
||||
# "Stopped mid run, likely due to the background process being killed"
|
||||
# )
|
||||
# def _should_create_new_indexing(
|
||||
# cc_pair: ConnectorCredentialPair,
|
||||
# last_index: IndexAttempt | None,
|
||||
# search_settings_instance: SearchSettings,
|
||||
# secondary_index_building: bool,
|
||||
# db_session: Session,
|
||||
# ) -> bool:
|
||||
# connector = cc_pair.connector
|
||||
# # don't kick off indexing for `NOT_APPLICABLE` sources
|
||||
# if connector.source == DocumentSource.NOT_APPLICABLE:
|
||||
# return False
|
||||
# # User can still manually create single indexing attempts via the UI for the
|
||||
# # currently in use index
|
||||
# if DISABLE_INDEX_UPDATE_ON_SWAP:
|
||||
# if (
|
||||
# search_settings_instance.status == IndexModelStatus.PRESENT
|
||||
# and secondary_index_building
|
||||
# ):
|
||||
# return False
|
||||
# # When switching over models, always index at least once
|
||||
# if search_settings_instance.status == IndexModelStatus.FUTURE:
|
||||
# if last_index:
|
||||
# # No new index if the last index attempt succeeded
|
||||
# # Once is enough. The model will never be able to swap otherwise.
|
||||
# if last_index.status == IndexingStatus.SUCCESS:
|
||||
# return False
|
||||
# # No new index if the last index attempt is waiting to start
|
||||
# if last_index.status == IndexingStatus.NOT_STARTED:
|
||||
# return False
|
||||
# # No new index if the last index attempt is running
|
||||
# if last_index.status == IndexingStatus.IN_PROGRESS:
|
||||
# return False
|
||||
# else:
|
||||
# if (
|
||||
# connector.id == 0 or connector.source == DocumentSource.INGESTION_API
|
||||
# ): # Ingestion API
|
||||
# return False
|
||||
# return True
|
||||
# # If the connector is paused or is the ingestion API, don't index
|
||||
# # NOTE: during an embedding model switch over, the following logic
|
||||
# # is bypassed by the above check for a future model
|
||||
# if (
|
||||
# not cc_pair.status.is_active()
|
||||
# or connector.id == 0
|
||||
# or connector.source == DocumentSource.INGESTION_API
|
||||
# ):
|
||||
# return False
|
||||
# if not last_index:
|
||||
# return True
|
||||
# if connector.refresh_freq is None:
|
||||
# return False
|
||||
# # Only one scheduled/ongoing job per connector at a time
|
||||
# # this prevents cases where
|
||||
# # (1) the "latest" index_attempt is scheduled so we show
|
||||
# # that in the UI despite another index_attempt being in-progress
|
||||
# # (2) multiple scheduled index_attempts at a time
|
||||
# if (
|
||||
# last_index.status == IndexingStatus.NOT_STARTED
|
||||
# or last_index.status == IndexingStatus.IN_PROGRESS
|
||||
# ):
|
||||
# return False
|
||||
# current_db_time = get_db_current_time(db_session)
|
||||
# time_since_index = current_db_time - last_index.time_updated
|
||||
# return time_since_index.total_seconds() >= connector.refresh_freq
|
||||
# def _mark_run_failed(
|
||||
# db_session: Session, index_attempt: IndexAttempt, failure_reason: str
|
||||
# ) -> None:
|
||||
# """Marks the `index_attempt` row as failed + updates the `
|
||||
# connector_credential_pair` to reflect that the run failed"""
|
||||
# logger.warning(
|
||||
# f"Marking in-progress attempt 'connector: {index_attempt.connector_credential_pair.connector_id}, "
|
||||
# f"credential: {index_attempt.connector_credential_pair.credential_id}' as failed due to {failure_reason}"
|
||||
# )
|
||||
# mark_attempt_failed(
|
||||
# index_attempt=index_attempt,
|
||||
# db_session=db_session,
|
||||
# failure_reason=failure_reason,
|
||||
# )
|
||||
# """Main funcs"""
|
||||
# def create_indexing_jobs(
|
||||
# existing_jobs: dict[int, Future | SimpleJob], tenant_id: str | None
|
||||
# ) -> None:
|
||||
# """Creates new indexing jobs for each connector / credential pair which is:
|
||||
# 1. Enabled
|
||||
# 2. `refresh_frequency` time has passed since the last indexing run for this pair
|
||||
# 3. There is not already an ongoing indexing attempt for this pair
|
||||
# """
|
||||
# with get_session_with_tenant(tenant_id) as db_session:
|
||||
# ongoing: set[tuple[int | None, int]] = set()
|
||||
# for attempt_id in existing_jobs:
|
||||
# attempt = get_index_attempt(
|
||||
# db_session=db_session, index_attempt_id=attempt_id
|
||||
# )
|
||||
# if attempt is None:
|
||||
# logger.error(
|
||||
# f"Unable to find IndexAttempt for ID '{attempt_id}' when creating "
|
||||
# "indexing jobs"
|
||||
# )
|
||||
# continue
|
||||
# ongoing.add(
|
||||
# (
|
||||
# attempt.connector_credential_pair_id,
|
||||
# attempt.search_settings_id,
|
||||
# )
|
||||
# )
|
||||
# # Get the primary search settings
|
||||
# primary_search_settings = get_current_search_settings(db_session)
|
||||
# search_settings = [primary_search_settings]
|
||||
# # Check for secondary search settings
|
||||
# secondary_search_settings = get_secondary_search_settings(db_session)
|
||||
# if secondary_search_settings is not None:
|
||||
# # If secondary settings exist, add them to the list
|
||||
# search_settings.append(secondary_search_settings)
|
||||
# all_connector_credential_pairs = fetch_connector_credential_pairs(db_session)
|
||||
# for cc_pair in all_connector_credential_pairs:
|
||||
# for search_settings_instance in search_settings:
|
||||
# # Check if there is an ongoing indexing attempt for this connector credential pair
|
||||
# if (cc_pair.id, search_settings_instance.id) in ongoing:
|
||||
# continue
|
||||
# last_attempt = get_last_attempt_for_cc_pair(
|
||||
# cc_pair.id, search_settings_instance.id, db_session
|
||||
# )
|
||||
# if not _should_create_new_indexing(
|
||||
# cc_pair=cc_pair,
|
||||
# last_index=last_attempt,
|
||||
# search_settings_instance=search_settings_instance,
|
||||
# secondary_index_building=len(search_settings) > 1,
|
||||
# db_session=db_session,
|
||||
# ):
|
||||
# continue
|
||||
# create_index_attempt(
|
||||
# cc_pair.id, search_settings_instance.id, db_session
|
||||
# )
|
||||
# def cleanup_indexing_jobs(
|
||||
# existing_jobs: dict[int, Future | SimpleJob],
|
||||
# tenant_id: str | None,
|
||||
# timeout_hours: int = CLEANUP_INDEXING_JOBS_TIMEOUT,
|
||||
# ) -> dict[int, Future | SimpleJob]:
|
||||
# existing_jobs_copy = existing_jobs.copy()
|
||||
# # clean up completed jobs
|
||||
# with get_session_with_tenant(tenant_id) as db_session:
|
||||
# for attempt_id, job in existing_jobs.items():
|
||||
# index_attempt = get_index_attempt(
|
||||
# db_session=db_session, index_attempt_id=attempt_id
|
||||
# )
|
||||
# # do nothing for ongoing jobs that haven't been stopped
|
||||
# if not job.done():
|
||||
# if not index_attempt:
|
||||
# continue
|
||||
# if not index_attempt.is_finished():
|
||||
# continue
|
||||
# if job.status == "error":
|
||||
# logger.error(job.exception())
|
||||
# job.release()
|
||||
# del existing_jobs_copy[attempt_id]
|
||||
# if not index_attempt:
|
||||
# logger.error(
|
||||
# f"Unable to find IndexAttempt for ID '{attempt_id}' when cleaning "
|
||||
# "up indexing jobs"
|
||||
# )
|
||||
# continue
|
||||
# if (
|
||||
# index_attempt.status == IndexingStatus.IN_PROGRESS
|
||||
# or job.status == "error"
|
||||
# ):
|
||||
# _mark_run_failed(
|
||||
# db_session=db_session,
|
||||
# index_attempt=index_attempt,
|
||||
# failure_reason=_UNEXPECTED_STATE_FAILURE_REASON,
|
||||
# )
|
||||
# # clean up in-progress jobs that were never completed
|
||||
# try:
|
||||
# connectors = fetch_connectors(db_session)
|
||||
# for connector in connectors:
|
||||
# in_progress_indexing_attempts = get_inprogress_index_attempts(
|
||||
# connector.id, db_session
|
||||
# )
|
||||
# for index_attempt in in_progress_indexing_attempts:
|
||||
# if index_attempt.id in existing_jobs:
|
||||
# # If index attempt is canceled, stop the run
|
||||
# if index_attempt.status == IndexingStatus.FAILED:
|
||||
# existing_jobs[index_attempt.id].cancel()
|
||||
# # check to see if the job has been updated in last `timeout_hours` hours, if not
|
||||
# # assume it to frozen in some bad state and just mark it as failed. Note: this relies
|
||||
# # on the fact that the `time_updated` field is constantly updated every
|
||||
# # batch of documents indexed
|
||||
# current_db_time = get_db_current_time(db_session=db_session)
|
||||
# time_since_update = current_db_time - index_attempt.time_updated
|
||||
# if time_since_update.total_seconds() > 60 * 60 * timeout_hours:
|
||||
# existing_jobs[index_attempt.id].cancel()
|
||||
# _mark_run_failed(
|
||||
# db_session=db_session,
|
||||
# index_attempt=index_attempt,
|
||||
# failure_reason="Indexing run frozen - no updates in the last three hours. "
|
||||
# "The run will be re-attempted at next scheduled indexing time.",
|
||||
# )
|
||||
# else:
|
||||
# # If job isn't known, simply mark it as failed
|
||||
# _mark_run_failed(
|
||||
# db_session=db_session,
|
||||
# index_attempt=index_attempt,
|
||||
# failure_reason=_UNEXPECTED_STATE_FAILURE_REASON,
|
||||
# )
|
||||
# except ProgrammingError:
|
||||
# logger.debug(f"No Connector Table exists for: {tenant_id}")
|
||||
# return existing_jobs_copy
|
||||
# def kickoff_indexing_jobs(
|
||||
# existing_jobs: dict[int, Future | SimpleJob],
|
||||
# client: Client | SimpleJobClient,
|
||||
# secondary_client: Client | SimpleJobClient,
|
||||
# tenant_id: str | None,
|
||||
# ) -> dict[int, Future | SimpleJob]:
|
||||
# existing_jobs_copy = existing_jobs.copy()
|
||||
# current_session = get_session_with_tenant(tenant_id)
|
||||
# # Don't include jobs waiting in the Dask queue that just haven't started running
|
||||
# # Also (rarely) don't include for jobs that started but haven't updated the indexing tables yet
|
||||
# with current_session as db_session:
|
||||
# # get_not_started_index_attempts orders its returned results from oldest to newest
|
||||
# # we must process attempts in a FIFO manner to prevent connector starvation
|
||||
# new_indexing_attempts = [
|
||||
# (attempt, attempt.search_settings)
|
||||
# for attempt in get_not_started_index_attempts(db_session)
|
||||
# if attempt.id not in existing_jobs
|
||||
# ]
|
||||
# logger.debug(f"Found {len(new_indexing_attempts)} new indexing task(s).")
|
||||
# if not new_indexing_attempts:
|
||||
# return existing_jobs
|
||||
# indexing_attempt_count = 0
|
||||
# primary_client_full = False
|
||||
# secondary_client_full = False
|
||||
# for attempt, search_settings in new_indexing_attempts:
|
||||
# if primary_client_full and secondary_client_full:
|
||||
# break
|
||||
# use_secondary_index = (
|
||||
# search_settings.status == IndexModelStatus.FUTURE
|
||||
# if search_settings is not None
|
||||
# else False
|
||||
# )
|
||||
# if attempt.connector_credential_pair.connector is None:
|
||||
# logger.warning(
|
||||
# f"Skipping index attempt as Connector has been deleted: {attempt}"
|
||||
# )
|
||||
# with current_session as db_session:
|
||||
# mark_attempt_failed(
|
||||
# attempt, db_session, failure_reason="Connector is null"
|
||||
# )
|
||||
# continue
|
||||
# if attempt.connector_credential_pair.credential is None:
|
||||
# logger.warning(
|
||||
# f"Skipping index attempt as Credential has been deleted: {attempt}"
|
||||
# )
|
||||
# with current_session as db_session:
|
||||
# mark_attempt_failed(
|
||||
# attempt, db_session, failure_reason="Credential is null"
|
||||
# )
|
||||
# continue
|
||||
# if not use_secondary_index:
|
||||
# if not primary_client_full:
|
||||
# run = client.submit(
|
||||
# run_indexing_entrypoint,
|
||||
# attempt.id,
|
||||
# tenant_id,
|
||||
# attempt.connector_credential_pair_id,
|
||||
# global_version.is_ee_version(),
|
||||
# pure=False,
|
||||
# )
|
||||
# if not run:
|
||||
# primary_client_full = True
|
||||
# else:
|
||||
# if not secondary_client_full:
|
||||
# run = secondary_client.submit(
|
||||
# run_indexing_entrypoint,
|
||||
# attempt.id,
|
||||
# tenant_id,
|
||||
# attempt.connector_credential_pair_id,
|
||||
# global_version.is_ee_version(),
|
||||
# pure=False,
|
||||
# )
|
||||
# if not run:
|
||||
# secondary_client_full = True
|
||||
# if run:
|
||||
# if indexing_attempt_count == 0:
|
||||
# logger.info(
|
||||
# f"Indexing dispatch starts: pending={len(new_indexing_attempts)}"
|
||||
# )
|
||||
# indexing_attempt_count += 1
|
||||
# secondary_str = " (secondary index)" if use_secondary_index else ""
|
||||
# logger.info(
|
||||
# f"Indexing dispatched{secondary_str}: "
|
||||
# f"attempt_id={attempt.id} "
|
||||
# f"connector='{attempt.connector_credential_pair.connector.name}' "
|
||||
# f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
|
||||
# f"credentials='{attempt.connector_credential_pair.credential_id}'"
|
||||
# )
|
||||
# existing_jobs_copy[attempt.id] = run
|
||||
# if indexing_attempt_count > 0:
|
||||
# logger.info(
|
||||
# f"Indexing dispatch results: "
|
||||
# f"initial_pending={len(new_indexing_attempts)} "
|
||||
# f"started={indexing_attempt_count} "
|
||||
# f"remaining={len(new_indexing_attempts) - indexing_attempt_count}"
|
||||
# )
|
||||
# return existing_jobs_copy
|
||||
# def get_all_tenant_ids() -> list[str] | list[None]:
|
||||
# if not MULTI_TENANT:
|
||||
# return [None]
|
||||
# with get_session_with_tenant(tenant_id="public") as session:
|
||||
# result = session.execute(
|
||||
# text(
|
||||
# """
|
||||
# SELECT schema_name
|
||||
# FROM information_schema.schemata
|
||||
# WHERE schema_name NOT IN ('pg_catalog', 'information_schema', 'public')"""
|
||||
# )
|
||||
# )
|
||||
# tenant_ids = [row[0] for row in result]
|
||||
# valid_tenants = [
|
||||
# tenant
|
||||
# for tenant in tenant_ids
|
||||
# if tenant is None or tenant.startswith(TENANT_ID_PREFIX)
|
||||
# ]
|
||||
# return valid_tenants
|
||||
# def update_loop(
|
||||
# delay: int = 10,
|
||||
# num_workers: int = NUM_INDEXING_WORKERS,
|
||||
# num_secondary_workers: int = NUM_SECONDARY_INDEXING_WORKERS,
|
||||
# ) -> None:
|
||||
# if not MULTI_TENANT:
|
||||
# # We can use this function as we are certain only the public schema exists
|
||||
# # (explicitly for the non-`MULTI_TENANT` case)
|
||||
# engine = get_sqlalchemy_engine()
|
||||
# with Session(engine) as db_session:
|
||||
# check_index_swap(db_session=db_session)
|
||||
# search_settings = get_current_search_settings(db_session)
|
||||
# # So that the first time users aren't surprised by really slow speed of first
|
||||
# # batch of documents indexed
|
||||
# if search_settings.provider_type is None:
|
||||
# logger.notice("Running a first inference to warm up embedding model")
|
||||
# embedding_model = EmbeddingModel.from_db_model(
|
||||
# search_settings=search_settings,
|
||||
# server_host=INDEXING_MODEL_SERVER_HOST,
|
||||
# server_port=INDEXING_MODEL_SERVER_PORT,
|
||||
# )
|
||||
# warm_up_bi_encoder(
|
||||
# embedding_model=embedding_model,
|
||||
# )
|
||||
# logger.notice("First inference complete.")
|
||||
# client_primary: Client | SimpleJobClient
|
||||
# client_secondary: Client | SimpleJobClient
|
||||
# if DASK_JOB_CLIENT_ENABLED:
|
||||
# cluster_primary = LocalCluster(
|
||||
# n_workers=num_workers,
|
||||
# threads_per_worker=1,
|
||||
# silence_logs=logging.ERROR,
|
||||
# )
|
||||
# cluster_secondary = LocalCluster(
|
||||
# n_workers=num_secondary_workers,
|
||||
# threads_per_worker=1,
|
||||
# silence_logs=logging.ERROR,
|
||||
# )
|
||||
# client_primary = Client(cluster_primary)
|
||||
# client_secondary = Client(cluster_secondary)
|
||||
# if LOG_LEVEL.lower() == "debug":
|
||||
# client_primary.register_worker_plugin(ResourceLogger())
|
||||
# else:
|
||||
# client_primary = SimpleJobClient(n_workers=num_workers)
|
||||
# client_secondary = SimpleJobClient(n_workers=num_secondary_workers)
|
||||
# existing_jobs: dict[str | None, dict[int, Future | SimpleJob]] = {}
|
||||
# logger.notice("Startup complete. Waiting for indexing jobs...")
|
||||
# while True:
|
||||
# start = time.time()
|
||||
# start_time_utc = datetime.utcfromtimestamp(start).strftime("%Y-%m-%d %H:%M:%S")
|
||||
# logger.debug(f"Running update, current UTC time: {start_time_utc}")
|
||||
# if existing_jobs:
|
||||
# logger.debug(
|
||||
# "Found existing indexing jobs: "
|
||||
# f"{[(tenant_id, list(jobs.keys())) for tenant_id, jobs in existing_jobs.items()]}"
|
||||
# )
|
||||
# try:
|
||||
# tenants = get_all_tenant_ids()
|
||||
# for tenant_id in tenants:
|
||||
# try:
|
||||
# logger.debug(
|
||||
# f"Processing {'index attempts' if tenant_id is None else f'tenant {tenant_id}'}"
|
||||
# )
|
||||
# with get_session_with_tenant(tenant_id) as db_session:
|
||||
# index_to_expire = check_index_swap(db_session=db_session)
|
||||
# if index_to_expire and tenant_id and MULTI_TENANT:
|
||||
# VespaIndex.delete_entries_by_tenant_id(
|
||||
# tenant_id=tenant_id,
|
||||
# index_name=index_to_expire.index_name,
|
||||
# )
|
||||
# if not MULTI_TENANT:
|
||||
# search_settings = get_current_search_settings(db_session)
|
||||
# if search_settings.provider_type is None:
|
||||
# logger.notice(
|
||||
# "Running a first inference to warm up embedding model"
|
||||
# )
|
||||
# embedding_model = EmbeddingModel.from_db_model(
|
||||
# search_settings=search_settings,
|
||||
# server_host=INDEXING_MODEL_SERVER_HOST,
|
||||
# server_port=INDEXING_MODEL_SERVER_PORT,
|
||||
# )
|
||||
# warm_up_bi_encoder(embedding_model=embedding_model)
|
||||
# logger.notice("First inference complete.")
|
||||
# tenant_jobs = existing_jobs.get(tenant_id, {})
|
||||
# tenant_jobs = cleanup_indexing_jobs(
|
||||
# existing_jobs=tenant_jobs, tenant_id=tenant_id
|
||||
# )
|
||||
# create_indexing_jobs(existing_jobs=tenant_jobs, tenant_id=tenant_id)
|
||||
# tenant_jobs = kickoff_indexing_jobs(
|
||||
# existing_jobs=tenant_jobs,
|
||||
# client=client_primary,
|
||||
# secondary_client=client_secondary,
|
||||
# tenant_id=tenant_id,
|
||||
# )
|
||||
# existing_jobs[tenant_id] = tenant_jobs
|
||||
# except Exception as e:
|
||||
# logger.exception(
|
||||
# f"Failed to process tenant {tenant_id or 'default'}: {e}"
|
||||
# )
|
||||
# except Exception as e:
|
||||
# logger.exception(f"Failed to run update due to {e}")
|
||||
# sleep_time = delay - (time.time() - start)
|
||||
# if sleep_time > 0:
|
||||
# time.sleep(sleep_time)
|
||||
# def update__main() -> None:
|
||||
# set_is_ee_based_on_env_variable()
|
||||
# # initialize the Postgres connection pool
|
||||
# SqlEngine.set_app_name(POSTGRES_INDEXER_APP_NAME)
|
||||
# logger.notice("Starting indexing service")
|
||||
# update_loop()
|
||||
# if __name__ == "__main__":
|
||||
# update__main()
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import re
|
||||
from typing import cast
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi.datastructures import Headers
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.chat.models import CitationInfo
|
||||
@@ -33,7 +35,7 @@ def llm_doc_from_inference_section(inference_section: InferenceSection) -> LlmDo
|
||||
|
||||
|
||||
def create_chat_chain(
|
||||
chat_session_id: int,
|
||||
chat_session_id: UUID,
|
||||
db_session: Session,
|
||||
prefetch_tool_calls: bool = True,
|
||||
# Optional id at which we finish processing
|
||||
@@ -166,3 +168,31 @@ def reorganize_citations(
|
||||
new_citation_info[citation.citation_num] = citation
|
||||
|
||||
return new_answer, list(new_citation_info.values())
|
||||
|
||||
|
||||
def extract_headers(
|
||||
headers: dict[str, str] | Headers, pass_through_headers: list[str] | None
|
||||
) -> dict[str, str]:
|
||||
"""
|
||||
Extract headers specified in pass_through_headers from input headers.
|
||||
Handles both dict and FastAPI Headers objects, accounting for lowercase keys.
|
||||
|
||||
Args:
|
||||
headers: Input headers as dict or Headers object.
|
||||
|
||||
Returns:
|
||||
dict: Filtered headers based on pass_through_headers.
|
||||
"""
|
||||
if not pass_through_headers:
|
||||
return {}
|
||||
|
||||
extracted_headers: dict[str, str] = {}
|
||||
for key in pass_through_headers:
|
||||
if key in headers:
|
||||
extracted_headers[key] = headers[key]
|
||||
else:
|
||||
# fastapi makes all header keys lowercase, handling that here
|
||||
lowercase_key = key.lower()
|
||||
if lowercase_key in headers:
|
||||
extracted_headers[lowercase_key] = headers[lowercase_key]
|
||||
return extracted_headers
|
||||
|
||||
@@ -6,7 +6,6 @@ from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
|
||||
from danswer.configs.chat_configs import PERSONAS_YAML
|
||||
from danswer.configs.chat_configs import PROMPTS_YAML
|
||||
from danswer.db.document_set import get_or_create_document_set_by_name
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.input_prompt import insert_input_prompt_if_not_exists
|
||||
from danswer.db.models import DocumentSet as DocumentSetDBModel
|
||||
from danswer.db.models import Persona
|
||||
@@ -18,30 +17,32 @@ from danswer.db.persona import upsert_prompt
|
||||
from danswer.search.enums import RecencyBiasSetting
|
||||
|
||||
|
||||
def load_prompts_from_yaml(prompts_yaml: str = PROMPTS_YAML) -> None:
|
||||
def load_prompts_from_yaml(
|
||||
db_session: Session, prompts_yaml: str = PROMPTS_YAML
|
||||
) -> None:
|
||||
with open(prompts_yaml, "r") as file:
|
||||
data = yaml.safe_load(file)
|
||||
|
||||
all_prompts = data.get("prompts", [])
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
for prompt in all_prompts:
|
||||
upsert_prompt(
|
||||
user=None,
|
||||
prompt_id=prompt.get("id"),
|
||||
name=prompt["name"],
|
||||
description=prompt["description"].strip(),
|
||||
system_prompt=prompt["system"].strip(),
|
||||
task_prompt=prompt["task"].strip(),
|
||||
include_citations=prompt["include_citations"],
|
||||
datetime_aware=prompt.get("datetime_aware", True),
|
||||
default_prompt=True,
|
||||
personas=None,
|
||||
db_session=db_session,
|
||||
commit=True,
|
||||
)
|
||||
for prompt in all_prompts:
|
||||
upsert_prompt(
|
||||
user=None,
|
||||
prompt_id=prompt.get("id"),
|
||||
name=prompt["name"],
|
||||
description=prompt["description"].strip(),
|
||||
system_prompt=prompt["system"].strip(),
|
||||
task_prompt=prompt["task"].strip(),
|
||||
include_citations=prompt["include_citations"],
|
||||
datetime_aware=prompt.get("datetime_aware", True),
|
||||
default_prompt=True,
|
||||
personas=None,
|
||||
db_session=db_session,
|
||||
commit=True,
|
||||
)
|
||||
|
||||
|
||||
def load_personas_from_yaml(
|
||||
db_session: Session,
|
||||
personas_yaml: str = PERSONAS_YAML,
|
||||
default_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
|
||||
) -> None:
|
||||
@@ -49,117 +50,117 @@ def load_personas_from_yaml(
|
||||
data = yaml.safe_load(file)
|
||||
|
||||
all_personas = data.get("personas", [])
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
for persona in all_personas:
|
||||
doc_set_names = persona["document_sets"]
|
||||
doc_sets: list[DocumentSetDBModel] = [
|
||||
get_or_create_document_set_by_name(db_session, name)
|
||||
for name in doc_set_names
|
||||
for persona in all_personas:
|
||||
doc_set_names = persona["document_sets"]
|
||||
doc_sets: list[DocumentSetDBModel] = [
|
||||
get_or_create_document_set_by_name(db_session, name)
|
||||
for name in doc_set_names
|
||||
]
|
||||
|
||||
# Assume if user hasn't set any document sets for the persona, the user may want
|
||||
# to later attach document sets to the persona manually, therefore, don't overwrite/reset
|
||||
# the document sets for the persona
|
||||
doc_set_ids: list[int] | None = None
|
||||
if doc_sets:
|
||||
doc_set_ids = [doc_set.id for doc_set in doc_sets]
|
||||
else:
|
||||
doc_set_ids = None
|
||||
|
||||
prompt_ids: list[int] | None = None
|
||||
prompt_set_names = persona["prompts"]
|
||||
if prompt_set_names:
|
||||
prompts: list[PromptDBModel | None] = [
|
||||
get_prompt_by_name(prompt_name, user=None, db_session=db_session)
|
||||
for prompt_name in prompt_set_names
|
||||
]
|
||||
if any([prompt is None for prompt in prompts]):
|
||||
raise ValueError("Invalid Persona configs, not all prompts exist")
|
||||
|
||||
# Assume if user hasn't set any document sets for the persona, the user may want
|
||||
# to later attach document sets to the persona manually, therefore, don't overwrite/reset
|
||||
# the document sets for the persona
|
||||
doc_set_ids: list[int] | None = None
|
||||
if doc_sets:
|
||||
doc_set_ids = [doc_set.id for doc_set in doc_sets]
|
||||
else:
|
||||
doc_set_ids = None
|
||||
if prompts:
|
||||
prompt_ids = [prompt.id for prompt in prompts if prompt is not None]
|
||||
|
||||
prompt_ids: list[int] | None = None
|
||||
prompt_set_names = persona["prompts"]
|
||||
if prompt_set_names:
|
||||
prompts: list[PromptDBModel | None] = [
|
||||
get_prompt_by_name(prompt_name, user=None, db_session=db_session)
|
||||
for prompt_name in prompt_set_names
|
||||
]
|
||||
if any([prompt is None for prompt in prompts]):
|
||||
raise ValueError("Invalid Persona configs, not all prompts exist")
|
||||
|
||||
if prompts:
|
||||
prompt_ids = [prompt.id for prompt in prompts if prompt is not None]
|
||||
|
||||
p_id = persona.get("id")
|
||||
tool_ids = []
|
||||
if persona.get("image_generation"):
|
||||
image_gen_tool = (
|
||||
db_session.query(ToolDBModel)
|
||||
.filter(ToolDBModel.name == "ImageGenerationTool")
|
||||
.first()
|
||||
)
|
||||
if image_gen_tool:
|
||||
tool_ids.append(image_gen_tool.id)
|
||||
|
||||
llm_model_provider_override = persona.get("llm_model_provider_override")
|
||||
llm_model_version_override = persona.get("llm_model_version_override")
|
||||
|
||||
# Set specific overrides for image generation persona
|
||||
if persona.get("image_generation"):
|
||||
llm_model_version_override = "gpt-4o"
|
||||
|
||||
existing_persona = (
|
||||
db_session.query(Persona)
|
||||
.filter(Persona.name == persona["name"])
|
||||
p_id = persona.get("id")
|
||||
tool_ids = []
|
||||
if persona.get("image_generation"):
|
||||
image_gen_tool = (
|
||||
db_session.query(ToolDBModel)
|
||||
.filter(ToolDBModel.name == "ImageGenerationTool")
|
||||
.first()
|
||||
)
|
||||
if image_gen_tool:
|
||||
tool_ids.append(image_gen_tool.id)
|
||||
|
||||
upsert_persona(
|
||||
user=None,
|
||||
persona_id=(-1 * p_id) if p_id is not None else None,
|
||||
name=persona["name"],
|
||||
description=persona["description"],
|
||||
num_chunks=persona.get("num_chunks")
|
||||
if persona.get("num_chunks") is not None
|
||||
else default_chunks,
|
||||
llm_relevance_filter=persona.get("llm_relevance_filter"),
|
||||
starter_messages=persona.get("starter_messages"),
|
||||
llm_filter_extraction=persona.get("llm_filter_extraction"),
|
||||
icon_shape=persona.get("icon_shape"),
|
||||
icon_color=persona.get("icon_color"),
|
||||
llm_model_provider_override=llm_model_provider_override,
|
||||
llm_model_version_override=llm_model_version_override,
|
||||
recency_bias=RecencyBiasSetting(persona["recency_bias"]),
|
||||
prompt_ids=prompt_ids,
|
||||
document_set_ids=doc_set_ids,
|
||||
tool_ids=tool_ids,
|
||||
default_persona=True,
|
||||
is_public=True,
|
||||
display_priority=existing_persona.display_priority
|
||||
if existing_persona is not None
|
||||
else persona.get("display_priority"),
|
||||
is_visible=existing_persona.is_visible
|
||||
if existing_persona is not None
|
||||
else persona.get("is_visible"),
|
||||
db_session=db_session,
|
||||
)
|
||||
llm_model_provider_override = persona.get("llm_model_provider_override")
|
||||
llm_model_version_override = persona.get("llm_model_version_override")
|
||||
|
||||
# Set specific overrides for image generation persona
|
||||
if persona.get("image_generation"):
|
||||
llm_model_version_override = "gpt-4o"
|
||||
|
||||
existing_persona = (
|
||||
db_session.query(Persona).filter(Persona.name == persona["name"]).first()
|
||||
)
|
||||
|
||||
upsert_persona(
|
||||
user=None,
|
||||
persona_id=(-1 * p_id) if p_id is not None else None,
|
||||
name=persona["name"],
|
||||
description=persona["description"],
|
||||
num_chunks=persona.get("num_chunks")
|
||||
if persona.get("num_chunks") is not None
|
||||
else default_chunks,
|
||||
llm_relevance_filter=persona.get("llm_relevance_filter"),
|
||||
starter_messages=persona.get("starter_messages"),
|
||||
llm_filter_extraction=persona.get("llm_filter_extraction"),
|
||||
icon_shape=persona.get("icon_shape"),
|
||||
icon_color=persona.get("icon_color"),
|
||||
llm_model_provider_override=llm_model_provider_override,
|
||||
llm_model_version_override=llm_model_version_override,
|
||||
recency_bias=RecencyBiasSetting(persona["recency_bias"]),
|
||||
prompt_ids=prompt_ids,
|
||||
document_set_ids=doc_set_ids,
|
||||
tool_ids=tool_ids,
|
||||
builtin_persona=True,
|
||||
is_public=True,
|
||||
display_priority=existing_persona.display_priority
|
||||
if existing_persona is not None
|
||||
else persona.get("display_priority"),
|
||||
is_visible=existing_persona.is_visible
|
||||
if existing_persona is not None
|
||||
else persona.get("is_visible"),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
|
||||
def load_input_prompts_from_yaml(input_prompts_yaml: str = INPUT_PROMPT_YAML) -> None:
|
||||
def load_input_prompts_from_yaml(
|
||||
db_session: Session, input_prompts_yaml: str = INPUT_PROMPT_YAML
|
||||
) -> None:
|
||||
with open(input_prompts_yaml, "r") as file:
|
||||
data = yaml.safe_load(file)
|
||||
|
||||
all_input_prompts = data.get("input_prompts", [])
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
for input_prompt in all_input_prompts:
|
||||
# If these prompts are deleted (which is a hard delete in the DB), on server startup
|
||||
# they will be recreated, but the user can always just deactivate them, just a light inconvenience
|
||||
insert_input_prompt_if_not_exists(
|
||||
user=None,
|
||||
input_prompt_id=input_prompt.get("id"),
|
||||
prompt=input_prompt["prompt"],
|
||||
content=input_prompt["content"],
|
||||
is_public=input_prompt["is_public"],
|
||||
active=input_prompt.get("active", True),
|
||||
db_session=db_session,
|
||||
commit=True,
|
||||
)
|
||||
for input_prompt in all_input_prompts:
|
||||
# If these prompts are deleted (which is a hard delete in the DB), on server startup
|
||||
# they will be recreated, but the user can always just deactivate them, just a light inconvenience
|
||||
|
||||
insert_input_prompt_if_not_exists(
|
||||
user=None,
|
||||
input_prompt_id=input_prompt.get("id"),
|
||||
prompt=input_prompt["prompt"],
|
||||
content=input_prompt["content"],
|
||||
is_public=input_prompt["is_public"],
|
||||
active=input_prompt.get("active", True),
|
||||
db_session=db_session,
|
||||
commit=True,
|
||||
)
|
||||
|
||||
|
||||
def load_chat_yamls(
|
||||
db_session: Session,
|
||||
prompt_yaml: str = PROMPTS_YAML,
|
||||
personas_yaml: str = PERSONAS_YAML,
|
||||
input_prompts_yaml: str = INPUT_PROMPT_YAML,
|
||||
) -> None:
|
||||
load_prompts_from_yaml(prompt_yaml)
|
||||
load_personas_from_yaml(personas_yaml)
|
||||
load_input_prompts_from_yaml(input_prompts_yaml)
|
||||
load_prompts_from_yaml(db_session, prompt_yaml)
|
||||
load_personas_from_yaml(db_session, personas_yaml)
|
||||
load_input_prompts_from_yaml(db_session, input_prompts_yaml)
|
||||
|
||||
@@ -11,7 +11,6 @@ from danswer.search.enums import SearchType
|
||||
from danswer.search.models import RetrievalDocs
|
||||
from danswer.search.models import SearchResponse
|
||||
from danswer.tools.custom.base_tool_types import ToolResultType
|
||||
from danswer.tools.graphing.models import GraphGenerationDisplay
|
||||
|
||||
|
||||
class LlmDoc(BaseModel):
|
||||
@@ -49,8 +48,6 @@ class QADocsResponse(RetrievalDocs):
|
||||
class StreamStopReason(Enum):
|
||||
CONTEXT_LENGTH = "context_length"
|
||||
CANCELLED = "cancelled"
|
||||
FINISHED = "finished"
|
||||
NEW_RESPONSE = "new_response"
|
||||
|
||||
|
||||
class StreamStopInfo(BaseModel):
|
||||
@@ -176,7 +173,6 @@ AnswerQuestionPossibleReturn = (
|
||||
| ImageGenerationDisplay
|
||||
| CustomToolResponse
|
||||
| StreamingError
|
||||
| GraphGenerationDisplay
|
||||
| StreamStopInfo
|
||||
)
|
||||
|
||||
|
||||
@@ -18,8 +18,10 @@ from danswer.chat.models import MessageResponseIDInfo
|
||||
from danswer.chat.models import MessageSpecificCitations
|
||||
from danswer.chat.models import QADocsResponse
|
||||
from danswer.chat.models import StreamingError
|
||||
from danswer.chat.models import StreamStopInfo
|
||||
from danswer.chat.models import StreamStopReason
|
||||
from danswer.configs.app_configs import AZURE_DALLE_API_BASE
|
||||
from danswer.configs.app_configs import AZURE_DALLE_API_KEY
|
||||
from danswer.configs.app_configs import AZURE_DALLE_API_VERSION
|
||||
from danswer.configs.app_configs import AZURE_DALLE_DEPLOYMENT_NAME
|
||||
from danswer.configs.chat_configs import BING_API_KEY
|
||||
from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
|
||||
from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
|
||||
@@ -74,16 +76,13 @@ from danswer.search.utils import relevant_sections_to_indices
|
||||
from danswer.server.query_and_chat.models import ChatMessageDetail
|
||||
from danswer.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from danswer.server.utils import get_json_line
|
||||
from danswer.tools.analysis.analysis_tool import CSVAnalysisTool
|
||||
from danswer.tools.built_in_tools import get_built_in_tool_by_id
|
||||
from danswer.tools.custom.custom_tool import build_custom_tools_from_openapi_schema
|
||||
from danswer.tools.custom.custom_tool import (
|
||||
build_custom_tools_from_openapi_schema_and_headers,
|
||||
)
|
||||
from danswer.tools.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID
|
||||
from danswer.tools.custom.custom_tool import CustomToolCallSummary
|
||||
from danswer.tools.force import ForceUseTool
|
||||
from danswer.tools.graphing.graphing_tool import GraphingResponse
|
||||
from danswer.tools.graphing.graphing_tool import GraphingTool
|
||||
from danswer.tools.graphing.models import GraphGenerationDisplay
|
||||
from danswer.tools.graphing.models import GRAPHING_RESPONSE_ID
|
||||
from danswer.tools.images.image_generation_tool import IMAGE_GENERATION_RESPONSE_ID
|
||||
from danswer.tools.images.image_generation_tool import ImageGenerationResponse
|
||||
from danswer.tools.images.image_generation_tool import ImageGenerationTool
|
||||
@@ -106,6 +105,7 @@ from danswer.tools.tool import ToolResponse
|
||||
from danswer.tools.tool_runner import ToolCallFinalResult
|
||||
from danswer.tools.utils import compute_all_tool_tokens
|
||||
from danswer.tools.utils import explicit_tool_calling_supported
|
||||
from danswer.utils.headers import header_dict_to_header_list
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.timing import log_generator_function_time
|
||||
|
||||
@@ -250,7 +250,6 @@ def _get_force_search_settings(
|
||||
ChatPacket = (
|
||||
StreamingError
|
||||
| QADocsResponse
|
||||
| GraphingResponse
|
||||
| LLMRelevanceFilterResponse
|
||||
| FinalUsedContextDocsResponse
|
||||
| ChatMessageDetail
|
||||
@@ -259,7 +258,6 @@ ChatPacket = (
|
||||
| CitationInfo
|
||||
| ImageGenerationDisplay
|
||||
| CustomToolResponse
|
||||
| GraphGenerationDisplay
|
||||
| MessageSpecificCitations
|
||||
| MessageResponseIDInfo
|
||||
)
|
||||
@@ -279,7 +277,9 @@ def stream_chat_message_objects(
|
||||
# on the `new_msg_req.message`. Currently, requires a state where the last message is a
|
||||
use_existing_user_message: bool = False,
|
||||
litellm_additional_headers: dict[str, str] | None = None,
|
||||
custom_tool_additional_headers: dict[str, str] | None = None,
|
||||
is_connected: Callable[[], bool] | None = None,
|
||||
enforce_chat_session_id_for_search_docs: bool = True,
|
||||
) -> ChatPacketStream:
|
||||
"""Streams in order:
|
||||
1. [conditional] Retrieved documents if a search needs to be run
|
||||
@@ -451,6 +451,7 @@ def stream_chat_message_objects(
|
||||
chat_session=chat_session,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
enforce_chat_session_id_for_search_docs=enforce_chat_session_id_for_search_docs,
|
||||
)
|
||||
|
||||
# Generates full documents currently
|
||||
@@ -537,21 +538,8 @@ def stream_chat_message_objects(
|
||||
tool_dict: dict[int, list[Tool]] = {} # tool_id to tool
|
||||
for db_tool_model in persona.tools:
|
||||
# handle in-code tools specially
|
||||
|
||||
if db_tool_model.in_code_tool_id:
|
||||
tool_cls = get_built_in_tool_by_id(db_tool_model.id, db_session)
|
||||
if (
|
||||
tool_cls.__name__ == CSVAnalysisTool.__name__
|
||||
and not latest_query_files
|
||||
):
|
||||
tool_dict[db_tool_model.id] = [CSVAnalysisTool()]
|
||||
|
||||
if (
|
||||
tool_cls.__name__ == GraphingTool.__name__
|
||||
and not latest_query_files
|
||||
):
|
||||
tool_dict[db_tool_model.id] = [GraphingTool(output_dir="output")]
|
||||
|
||||
if tool_cls.__name__ == SearchTool.__name__ and not latest_query_files:
|
||||
search_tool = SearchTool(
|
||||
db_session=db_session,
|
||||
@@ -578,7 +566,26 @@ def stream_chat_message_objects(
|
||||
and llm.config.api_key
|
||||
and llm.config.model_provider == "openai"
|
||||
):
|
||||
img_generation_llm_config = llm.config
|
||||
img_generation_llm_config = LLMConfig(
|
||||
model_provider=llm.config.model_provider,
|
||||
model_name="dall-e-3",
|
||||
temperature=GEN_AI_TEMPERATURE,
|
||||
api_key=llm.config.api_key,
|
||||
api_base=llm.config.api_base,
|
||||
api_version=llm.config.api_version,
|
||||
)
|
||||
elif (
|
||||
llm.config.model_provider == "azure"
|
||||
and AZURE_DALLE_API_KEY is not None
|
||||
):
|
||||
img_generation_llm_config = LLMConfig(
|
||||
model_provider="azure",
|
||||
model_name=f"azure/{AZURE_DALLE_DEPLOYMENT_NAME}",
|
||||
temperature=GEN_AI_TEMPERATURE,
|
||||
api_key=AZURE_DALLE_API_KEY,
|
||||
api_base=AZURE_DALLE_API_BASE,
|
||||
api_version=AZURE_DALLE_API_VERSION,
|
||||
)
|
||||
else:
|
||||
llm_providers = fetch_existing_llm_providers(db_session)
|
||||
openai_provider = next(
|
||||
@@ -597,7 +604,7 @@ def stream_chat_message_objects(
|
||||
)
|
||||
img_generation_llm_config = LLMConfig(
|
||||
model_provider=openai_provider.provider,
|
||||
model_name=openai_provider.default_model_name,
|
||||
model_name="dall-e-3",
|
||||
temperature=GEN_AI_TEMPERATURE,
|
||||
api_key=openai_provider.api_key,
|
||||
api_base=openai_provider.api_base,
|
||||
@@ -609,6 +616,7 @@ def stream_chat_message_objects(
|
||||
api_base=img_generation_llm_config.api_base,
|
||||
api_version=img_generation_llm_config.api_version,
|
||||
additional_headers=litellm_additional_headers,
|
||||
model=img_generation_llm_config.model_name,
|
||||
)
|
||||
]
|
||||
elif tool_cls.__name__ == InternetSearchTool.__name__:
|
||||
@@ -622,27 +630,29 @@ def stream_chat_message_objects(
|
||||
]
|
||||
|
||||
continue
|
||||
|
||||
# handle all custom tools
|
||||
if db_tool_model.openapi_schema:
|
||||
tool_dict[db_tool_model.id] = cast(
|
||||
list[Tool],
|
||||
build_custom_tools_from_openapi_schema(
|
||||
build_custom_tools_from_openapi_schema_and_headers(
|
||||
db_tool_model.openapi_schema,
|
||||
dynamic_schema_info=DynamicSchemaInfo(
|
||||
chat_session_id=chat_session_id,
|
||||
message_id=user_message.id if user_message else None,
|
||||
),
|
||||
custom_headers=(db_tool_model.custom_headers or [])
|
||||
+ (
|
||||
header_dict_to_header_list(
|
||||
custom_tool_additional_headers or {}
|
||||
)
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
tools: list[Tool] = []
|
||||
for tool_list in tool_dict.values():
|
||||
tools.extend(tool_list)
|
||||
# Saving Gen AI answer and responding with message info
|
||||
tool_name_to_tool_id: dict[str, int] = {}
|
||||
for tool_id, tool_list in tool_dict.items():
|
||||
for tool in tool_list:
|
||||
tool_name_to_tool_id[tool.name] = tool_id
|
||||
|
||||
# factor in tool definition size when pruning
|
||||
document_pruning_config.tool_num_tokens = compute_all_tool_tokens(
|
||||
@@ -688,185 +698,102 @@ def stream_chat_message_objects(
|
||||
ai_message_files = None # any files to associate with the AI message e.g. dall-e generated images
|
||||
dropped_indices = None
|
||||
tool_result = None
|
||||
yielded_message_id_info = True
|
||||
|
||||
for packet in answer.processed_streamed_output:
|
||||
if isinstance(packet, StreamStopInfo):
|
||||
if packet.stop_reason is not StreamStopReason.NEW_RESPONSE:
|
||||
break
|
||||
if isinstance(packet, ToolResponse):
|
||||
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
|
||||
(
|
||||
qa_docs_response,
|
||||
reference_db_search_docs,
|
||||
dropped_indices,
|
||||
) = _handle_search_tool_response_summary(
|
||||
packet=packet,
|
||||
db_session=db_session,
|
||||
selected_search_docs=selected_db_search_docs,
|
||||
# Deduping happens at the last step to avoid harming quality by dropping content early on
|
||||
dedupe_docs=(
|
||||
retrieval_options.dedupe_docs
|
||||
if retrieval_options
|
||||
else False
|
||||
),
|
||||
)
|
||||
yield qa_docs_response
|
||||
elif packet.id == SECTION_RELEVANCE_LIST_ID:
|
||||
relevance_sections = packet.response
|
||||
|
||||
db_citations = None
|
||||
if reference_db_search_docs is not None:
|
||||
llm_indices = relevant_sections_to_indices(
|
||||
relevance_sections=relevance_sections,
|
||||
items=[
|
||||
translate_db_search_doc_to_server_search_doc(doc)
|
||||
for doc in reference_db_search_docs
|
||||
],
|
||||
)
|
||||
|
||||
if reference_db_search_docs:
|
||||
db_citations = _translate_citations(
|
||||
citations_list=answer.citations,
|
||||
db_docs=reference_db_search_docs,
|
||||
if dropped_indices:
|
||||
llm_indices = drop_llm_indices(
|
||||
llm_indices=llm_indices,
|
||||
search_docs=reference_db_search_docs,
|
||||
dropped_indices=dropped_indices,
|
||||
)
|
||||
|
||||
yield LLMRelevanceFilterResponse(
|
||||
llm_selected_doc_indices=llm_indices
|
||||
)
|
||||
|
||||
elif packet.id == FINAL_CONTEXT_DOCUMENTS_ID:
|
||||
yield FinalUsedContextDocsResponse(
|
||||
final_context_docs=packet.response
|
||||
)
|
||||
|
||||
# Saving Gen AI answer and responding with message info
|
||||
if tool_result is None:
|
||||
tool_call = None
|
||||
else:
|
||||
tool_call = ToolCall(
|
||||
tool_id=tool_name_to_tool_id[tool_result.tool_name],
|
||||
tool_name=tool_result.tool_name,
|
||||
tool_arguments={
|
||||
k: v if not isinstance(v, bytes) else v.decode("utf-8")
|
||||
for k, v in tool_result.tool_args.items()
|
||||
},
|
||||
tool_result=tool_result.tool_result,
|
||||
elif packet.id == IMAGE_GENERATION_RESPONSE_ID:
|
||||
img_generation_response = cast(
|
||||
list[ImageGenerationResponse], packet.response
|
||||
)
|
||||
|
||||
gen_ai_response_message = partial_response(
|
||||
reserved_message_id=reserved_message_id,
|
||||
message=answer.llm_answer,
|
||||
rephrased_query=cast(
|
||||
QADocsResponse, qa_docs_response
|
||||
).rephrased_query
|
||||
if qa_docs_response is not None
|
||||
else None,
|
||||
reference_docs=reference_db_search_docs,
|
||||
files=ai_message_files,
|
||||
token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
|
||||
citations=cast(MessageSpecificCitations, db_citations).citation_map
|
||||
if db_citations is not None
|
||||
else None,
|
||||
error=None,
|
||||
tool_call=tool_call,
|
||||
)
|
||||
|
||||
db_session.commit() # actually save user / assistant message
|
||||
|
||||
msg_detail_response = translate_db_message_to_chat_message_detail(
|
||||
gen_ai_response_message
|
||||
)
|
||||
|
||||
yield msg_detail_response
|
||||
reserved_message_id = reserve_message_id(
|
||||
db_session=db_session,
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message=gen_ai_response_message.id
|
||||
if user_message is not None
|
||||
else gen_ai_response_message.id,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
)
|
||||
yielded_message_id_info = False
|
||||
|
||||
partial_response = partial(
|
||||
create_new_chat_message,
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message=gen_ai_response_message,
|
||||
prompt_id=prompt_id,
|
||||
overridden_model=overridden_model,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
alternate_assistant_id=new_msg_req.alternate_assistant_id,
|
||||
db_session=db_session,
|
||||
commit=False,
|
||||
)
|
||||
reference_db_search_docs = None
|
||||
file_ids = save_files_from_urls(
|
||||
[img.url for img in img_generation_response]
|
||||
)
|
||||
ai_message_files = [
|
||||
FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE)
|
||||
for file_id in file_ids
|
||||
]
|
||||
yield ImageGenerationDisplay(
|
||||
file_ids=[str(file_id) for file_id in file_ids]
|
||||
)
|
||||
elif packet.id == INTERNET_SEARCH_RESPONSE_ID:
|
||||
(
|
||||
qa_docs_response,
|
||||
reference_db_search_docs,
|
||||
) = _handle_internet_search_tool_response_summary(
|
||||
packet=packet,
|
||||
db_session=db_session,
|
||||
)
|
||||
yield qa_docs_response
|
||||
elif packet.id == CUSTOM_TOOL_RESPONSE_ID:
|
||||
custom_tool_response = cast(CustomToolCallSummary, packet.response)
|
||||
yield CustomToolResponse(
|
||||
response=custom_tool_response.tool_result,
|
||||
tool_name=custom_tool_response.tool_name,
|
||||
)
|
||||
|
||||
else:
|
||||
if not yielded_message_id_info:
|
||||
yield MessageResponseIDInfo(
|
||||
user_message_id=gen_ai_response_message.id,
|
||||
reserved_assistant_message_id=reserved_message_id,
|
||||
)
|
||||
yielded_message_id_info = True
|
||||
|
||||
if isinstance(packet, ToolResponse):
|
||||
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
|
||||
(
|
||||
qa_docs_response,
|
||||
reference_db_search_docs,
|
||||
dropped_indices,
|
||||
) = _handle_search_tool_response_summary(
|
||||
packet=packet,
|
||||
db_session=db_session,
|
||||
selected_search_docs=selected_db_search_docs,
|
||||
# Deduping happens at the last step to avoid harming quality by dropping content early on
|
||||
dedupe_docs=retrieval_options.dedupe_docs
|
||||
if retrieval_options
|
||||
else False,
|
||||
)
|
||||
yield qa_docs_response
|
||||
elif packet.id == SECTION_RELEVANCE_LIST_ID:
|
||||
relevance_sections = packet.response
|
||||
|
||||
if reference_db_search_docs is not None:
|
||||
llm_indices = relevant_sections_to_indices(
|
||||
relevance_sections=relevance_sections,
|
||||
items=[
|
||||
translate_db_search_doc_to_server_search_doc(doc)
|
||||
for doc in reference_db_search_docs
|
||||
],
|
||||
)
|
||||
|
||||
if dropped_indices:
|
||||
llm_indices = drop_llm_indices(
|
||||
llm_indices=llm_indices,
|
||||
search_docs=reference_db_search_docs,
|
||||
dropped_indices=dropped_indices,
|
||||
)
|
||||
|
||||
yield LLMRelevanceFilterResponse(
|
||||
llm_selected_doc_indices=llm_indices
|
||||
)
|
||||
|
||||
elif packet.id == FINAL_CONTEXT_DOCUMENTS_ID:
|
||||
yield FinalUsedContextDocsResponse(
|
||||
final_context_docs=packet.response
|
||||
)
|
||||
elif packet.id == GRAPHING_RESPONSE_ID:
|
||||
graph_generation = cast(GraphingResponse, packet.response)
|
||||
yield graph_generation
|
||||
|
||||
# yield GraphGenerationDisplay(
|
||||
# file_id=graph_generation.extra_graph_display.file_id,
|
||||
# line_graph=graph_generation.extra_graph_display.line_graph,
|
||||
# )
|
||||
|
||||
elif packet.id == IMAGE_GENERATION_RESPONSE_ID:
|
||||
img_generation_response = cast(
|
||||
list[ImageGenerationResponse], packet.response
|
||||
)
|
||||
|
||||
file_ids = save_files_from_urls(
|
||||
[img.url for img in img_generation_response]
|
||||
)
|
||||
ai_message_files = [
|
||||
FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE)
|
||||
for file_id in file_ids
|
||||
]
|
||||
yield ImageGenerationDisplay(
|
||||
file_ids=[str(file_id) for file_id in file_ids]
|
||||
)
|
||||
elif packet.id == INTERNET_SEARCH_RESPONSE_ID:
|
||||
(
|
||||
qa_docs_response,
|
||||
reference_db_search_docs,
|
||||
) = _handle_internet_search_tool_response_summary(
|
||||
packet=packet,
|
||||
db_session=db_session,
|
||||
)
|
||||
yield qa_docs_response
|
||||
elif packet.id == CUSTOM_TOOL_RESPONSE_ID:
|
||||
custom_tool_response = cast(
|
||||
CustomToolCallSummary, packet.response
|
||||
)
|
||||
yield CustomToolResponse(
|
||||
response=custom_tool_response.tool_result,
|
||||
tool_name=custom_tool_response.tool_name,
|
||||
)
|
||||
else:
|
||||
if isinstance(packet, ToolCallFinalResult):
|
||||
tool_result = packet
|
||||
yield cast(ChatPacket, packet)
|
||||
|
||||
if isinstance(packet, ToolCallFinalResult):
|
||||
tool_result = packet
|
||||
yield cast(ChatPacket, packet)
|
||||
logger.debug("Reached end of stream")
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
logger.exception(f"Failed to process chat message: {error_msg}")
|
||||
except ValueError as e:
|
||||
logger.exception("Failed to process chat message.")
|
||||
|
||||
error_msg = str(e)
|
||||
yield StreamingError(error=error_msg)
|
||||
db_session.rollback()
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to process chat message.")
|
||||
|
||||
error_msg = str(e)
|
||||
stack_trace = traceback.format_exc()
|
||||
client_error_msg = litellm_exception_to_error_msg(e, llm)
|
||||
if llm.config.api_key and len(llm.config.api_key) > 2:
|
||||
@@ -887,8 +814,11 @@ def stream_chat_message_objects(
|
||||
)
|
||||
yield AllCitations(citations=answer.citations)
|
||||
|
||||
if answer.llm_answer == "":
|
||||
return
|
||||
# Saving Gen AI answer and responding with message info
|
||||
tool_name_to_tool_id: dict[str, int] = {}
|
||||
for tool_id, tool_list in tool_dict.items():
|
||||
for tool in tool_list:
|
||||
tool_name_to_tool_id[tool.name] = tool_id
|
||||
|
||||
gen_ai_response_message = partial_response(
|
||||
reserved_message_id=reserved_message_id,
|
||||
@@ -903,14 +833,18 @@ def stream_chat_message_objects(
|
||||
if message_specific_citations
|
||||
else None,
|
||||
error=None,
|
||||
tool_call=ToolCall(
|
||||
tool_id=tool_name_to_tool_id[tool_result.tool_name],
|
||||
tool_name=tool_result.tool_name,
|
||||
tool_arguments=tool_result.tool_args,
|
||||
tool_result=tool_result.tool_result,
|
||||
)
|
||||
if tool_result
|
||||
else None,
|
||||
tool_calls=(
|
||||
[
|
||||
ToolCall(
|
||||
tool_id=tool_name_to_tool_id[tool_result.tool_name],
|
||||
tool_name=tool_result.tool_name,
|
||||
tool_arguments=tool_result.tool_args,
|
||||
tool_result=tool_result.tool_result,
|
||||
)
|
||||
]
|
||||
if tool_result
|
||||
else []
|
||||
),
|
||||
)
|
||||
|
||||
logger.debug("Committing messages")
|
||||
@@ -935,6 +869,7 @@ def stream_chat_message(
|
||||
user: User | None,
|
||||
use_existing_user_message: bool = False,
|
||||
litellm_additional_headers: dict[str, str] | None = None,
|
||||
custom_tool_additional_headers: dict[str, str] | None = None,
|
||||
is_connected: Callable[[], bool] | None = None,
|
||||
) -> Iterator[str]:
|
||||
with get_session_context_manager() as db_session:
|
||||
@@ -944,6 +879,7 @@ def stream_chat_message(
|
||||
db_session=db_session,
|
||||
use_existing_user_message=use_existing_user_message,
|
||||
litellm_additional_headers=litellm_additional_headers,
|
||||
custom_tool_additional_headers=custom_tool_additional_headers,
|
||||
is_connected=is_connected,
|
||||
)
|
||||
for obj in objects:
|
||||
|
||||
@@ -53,7 +53,6 @@ MASK_CREDENTIAL_PREFIX = (
|
||||
os.environ.get("MASK_CREDENTIAL_PREFIX", "True").lower() != "false"
|
||||
)
|
||||
|
||||
|
||||
SESSION_EXPIRE_TIME_SECONDS = int(
|
||||
os.environ.get("SESSION_EXPIRE_TIME_SECONDS") or 86400 * 7
|
||||
) # 7 days
|
||||
@@ -116,10 +115,16 @@ VESPA_HOST = os.environ.get("VESPA_HOST") or "localhost"
|
||||
VESPA_CONFIG_SERVER_HOST = os.environ.get("VESPA_CONFIG_SERVER_HOST") or VESPA_HOST
|
||||
VESPA_PORT = os.environ.get("VESPA_PORT") or "8081"
|
||||
VESPA_TENANT_PORT = os.environ.get("VESPA_TENANT_PORT") or "19071"
|
||||
|
||||
VESPA_CLOUD_URL = os.environ.get("VESPA_CLOUD_URL", "")
|
||||
|
||||
# The default below is for dockerized deployment
|
||||
VESPA_DEPLOYMENT_ZIP = (
|
||||
os.environ.get("VESPA_DEPLOYMENT_ZIP") or "/app/danswer/vespa-app.zip"
|
||||
)
|
||||
VESPA_CLOUD_CERT_PATH = os.environ.get("VESPA_CLOUD_CERT_PATH")
|
||||
VESPA_CLOUD_KEY_PATH = os.environ.get("VESPA_CLOUD_KEY_PATH")
|
||||
|
||||
# Number of documents in a batch during indexing (further batching done by chunks before passing to bi-encoder)
|
||||
try:
|
||||
INDEX_BATCH_SIZE = int(os.environ.get("INDEX_BATCH_SIZE", 16))
|
||||
@@ -138,6 +143,12 @@ POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost"
|
||||
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5433"
|
||||
POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"
|
||||
|
||||
POSTGRES_API_SERVER_POOL_SIZE = int(
|
||||
os.environ.get("POSTGRES_API_SERVER_POOL_SIZE") or 40
|
||||
)
|
||||
POSTGRES_API_SERVER_POOL_OVERFLOW = int(
|
||||
os.environ.get("POSTGRES_API_SERVER_POOL_OVERFLOW") or 10
|
||||
)
|
||||
# defaults to False
|
||||
POSTGRES_POOL_PRE_PING = os.environ.get("POSTGRES_POOL_PRE_PING", "").lower() == "true"
|
||||
|
||||
@@ -159,10 +170,33 @@ REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD") or ""
|
||||
REDIS_DB_NUMBER = int(os.environ.get("REDIS_DB_NUMBER", 0))
|
||||
|
||||
# Used by celery as broker and backend
|
||||
REDIS_DB_NUMBER_CELERY = int(os.environ.get("REDIS_DB_NUMBER_CELERY", 15))
|
||||
REDIS_DB_NUMBER_CELERY_RESULT_BACKEND = int(
|
||||
os.environ.get("REDIS_DB_NUMBER_CELERY_RESULT_BACKEND", 14)
|
||||
)
|
||||
REDIS_DB_NUMBER_CELERY = int(os.environ.get("REDIS_DB_NUMBER_CELERY", 15)) # broker
|
||||
|
||||
REDIS_SSL_CERT_REQS = os.getenv("REDIS_SSL_CERT_REQS", "CERT_NONE")
|
||||
REDIS_SSL_CA_CERTS = os.getenv("REDIS_SSL_CA_CERTS", "")
|
||||
# will propagate to both our redis client as well as celery's redis client
|
||||
REDIS_HEALTH_CHECK_INTERVAL = int(os.environ.get("REDIS_HEALTH_CHECK_INTERVAL", 60))
|
||||
|
||||
# our redis client only, not celery's
|
||||
REDIS_POOL_MAX_CONNECTIONS = int(os.environ.get("REDIS_POOL_MAX_CONNECTIONS", 128))
|
||||
|
||||
# https://docs.celeryq.dev/en/stable/userguide/configuration.html#redis-backend-settings
|
||||
# should be one of "required", "optional", or "none"
|
||||
REDIS_SSL_CERT_REQS = os.getenv("REDIS_SSL_CERT_REQS", "none")
|
||||
REDIS_SSL_CA_CERTS = os.getenv("REDIS_SSL_CA_CERTS", None)
|
||||
|
||||
CELERY_RESULT_EXPIRES = int(os.environ.get("CELERY_RESULT_EXPIRES", 86400)) # seconds
|
||||
|
||||
# https://docs.celeryq.dev/en/stable/userguide/configuration.html#broker-pool-limit
|
||||
# Setting to None may help when there is a proxy in the way closing idle connections
|
||||
CELERY_BROKER_POOL_LIMIT_DEFAULT = 10
|
||||
try:
|
||||
CELERY_BROKER_POOL_LIMIT = int(
|
||||
os.environ.get("CELERY_BROKER_POOL_LIMIT", CELERY_BROKER_POOL_LIMIT_DEFAULT)
|
||||
)
|
||||
except ValueError:
|
||||
CELERY_BROKER_POOL_LIMIT = CELERY_BROKER_POOL_LIMIT_DEFAULT
|
||||
|
||||
#####
|
||||
# Connector Configs
|
||||
@@ -240,6 +274,10 @@ JIRA_CONNECTOR_LABELS_TO_SKIP = [
|
||||
for ignored_tag in os.environ.get("JIRA_CONNECTOR_LABELS_TO_SKIP", "").split(",")
|
||||
if ignored_tag
|
||||
]
|
||||
# Maximum size for Jira tickets in bytes (default: 100KB)
|
||||
JIRA_CONNECTOR_MAX_TICKET_SIZE = int(
|
||||
os.environ.get("JIRA_CONNECTOR_MAX_TICKET_SIZE", 100 * 1024)
|
||||
)
|
||||
|
||||
GONG_CONNECTOR_START_TIME = os.environ.get("GONG_CONNECTOR_START_TIME")
|
||||
|
||||
@@ -263,7 +301,7 @@ ALLOW_SIMULTANEOUS_PRUNING = (
|
||||
os.environ.get("ALLOW_SIMULTANEOUS_PRUNING", "").lower() == "true"
|
||||
)
|
||||
|
||||
# This is the maxiumum rate at which documents are queried for a pruning job. 0 disables the limitation.
|
||||
# This is the maximum rate at which documents are queried for a pruning job. 0 disables the limitation.
|
||||
MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE = int(
|
||||
os.environ.get("MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE", 0)
|
||||
)
|
||||
@@ -327,12 +365,10 @@ INDEXING_TRACER_INTERVAL = int(os.environ.get("INDEXING_TRACER_INTERVAL", 0))
|
||||
# exception without aborting the attempt.
|
||||
INDEXING_EXCEPTION_LIMIT = int(os.environ.get("INDEXING_EXCEPTION_LIMIT", 0))
|
||||
|
||||
|
||||
#####
|
||||
# Miscellaneous
|
||||
#####
|
||||
# File based Key Value store no longer used
|
||||
DYNAMIC_CONFIG_STORE = "PostgresBackedDynamicConfigStore"
|
||||
|
||||
JOB_TIMEOUT = 60 * 60 * 6 # 6 hours default
|
||||
# used to allow the background indexing jobs to use a different embedding
|
||||
# model server than the API server
|
||||
@@ -370,6 +406,11 @@ CUSTOM_ANSWER_VALIDITY_CONDITIONS = json.loads(
|
||||
os.environ.get("CUSTOM_ANSWER_VALIDITY_CONDITIONS", "[]")
|
||||
)
|
||||
|
||||
VESPA_REQUEST_TIMEOUT = int(os.environ.get("VESPA_REQUEST_TIMEOUT") or "5")
|
||||
|
||||
SYSTEM_RECURSION_LIMIT = int(os.environ.get("SYSTEM_RECURSION_LIMIT") or "1000")
|
||||
|
||||
PARSE_WITH_TRAFILATURA = os.environ.get("PARSE_WITH_TRAFILATURA", "").lower() == "true"
|
||||
|
||||
#####
|
||||
# Enterprise Edition Configs
|
||||
@@ -381,3 +422,39 @@ CUSTOM_ANSWER_VALIDITY_CONDITIONS = json.loads(
|
||||
ENTERPRISE_EDITION_ENABLED = (
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() == "true"
|
||||
)
|
||||
|
||||
# Azure DALL-E Configurations
|
||||
AZURE_DALLE_API_VERSION = os.environ.get("AZURE_DALLE_API_VERSION")
|
||||
AZURE_DALLE_API_KEY = os.environ.get("AZURE_DALLE_API_KEY")
|
||||
AZURE_DALLE_API_BASE = os.environ.get("AZURE_DALLE_API_BASE")
|
||||
AZURE_DALLE_DEPLOYMENT_NAME = os.environ.get("AZURE_DALLE_DEPLOYMENT_NAME")
|
||||
|
||||
|
||||
# Cloud configuration
|
||||
|
||||
# Multi-tenancy configuration
|
||||
MULTI_TENANT = os.environ.get("MULTI_TENANT", "").lower() == "true"
|
||||
|
||||
# Use managed Vespa (Vespa Cloud). If set, must also set VESPA_CLOUD_URL, VESPA_CLOUD_CERT_PATH and VESPA_CLOUD_KEY_PATH
|
||||
MANAGED_VESPA = os.environ.get("MANAGED_VESPA", "").lower() == "true"
|
||||
|
||||
ENABLE_EMAIL_INVITES = os.environ.get("ENABLE_EMAIL_INVITES", "").lower() == "true"
|
||||
|
||||
# Security and authentication
|
||||
SECRET_JWT_KEY = os.environ.get(
|
||||
"SECRET_JWT_KEY", ""
|
||||
) # Used for encryption of the JWT token for user's tenant context
|
||||
DATA_PLANE_SECRET = os.environ.get(
|
||||
"DATA_PLANE_SECRET", ""
|
||||
) # Used for secure communication between the control and data plane
|
||||
EXPECTED_API_KEY = os.environ.get(
|
||||
"EXPECTED_API_KEY", ""
|
||||
) # Additional security check for the control plane API
|
||||
|
||||
# API configuration
|
||||
CONTROL_PLANE_API_BASE_URL = os.environ.get(
|
||||
"CONTROL_PLANE_API_BASE_URL", "http://localhost:8082"
|
||||
)
|
||||
|
||||
# JWT configuration
|
||||
JWT_ALGORITHM = "HS256"
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import platform
|
||||
import socket
|
||||
from enum import auto
|
||||
from enum import Enum
|
||||
|
||||
@@ -29,14 +31,22 @@ DISABLED_GEN_AI_MSG = (
|
||||
"You can still use Danswer as a search engine."
|
||||
)
|
||||
|
||||
# Prefix used for all tenant ids
|
||||
TENANT_ID_PREFIX = "tenant_"
|
||||
|
||||
# Postgres connection constants for application_name
|
||||
POSTGRES_WEB_APP_NAME = "web"
|
||||
POSTGRES_INDEXER_APP_NAME = "indexer"
|
||||
POSTGRES_CELERY_APP_NAME = "celery"
|
||||
POSTGRES_CELERY_BEAT_APP_NAME = "celery_beat"
|
||||
POSTGRES_CELERY_WORKER_APP_NAME = "celery_worker"
|
||||
POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME = "celery_worker_primary"
|
||||
POSTGRES_CELERY_WORKER_LIGHT_APP_NAME = "celery_worker_light"
|
||||
POSTGRES_CELERY_WORKER_HEAVY_APP_NAME = "celery_worker_heavy"
|
||||
POSTGRES_CELERY_WORKER_INDEXING_APP_NAME = "celery_worker_indexing"
|
||||
POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME = "celery_worker_indexing_child"
|
||||
POSTGRES_PERMISSIONS_APP_NAME = "permissions"
|
||||
POSTGRES_UNKNOWN_APP_NAME = "unknown"
|
||||
POSTGRES_DEFAULT_SCHEMA = "public"
|
||||
|
||||
# API Keys
|
||||
DANSWER_API_KEY_PREFIX = "API_KEY__"
|
||||
@@ -46,6 +56,7 @@ UNNAMED_KEY_PLACEHOLDER = "Unnamed"
|
||||
# Key-Value store keys
|
||||
KV_REINDEX_KEY = "needs_reindexing"
|
||||
KV_SEARCH_SETTINGS = "search_settings"
|
||||
KV_UNSTRUCTURED_API_KEY = "unstructured_api_key"
|
||||
KV_USER_STORE_KEY = "INVITED_USERS"
|
||||
KV_NO_AUTH_USER_PREFERENCES_KEY = "no_auth_user_preferences"
|
||||
KV_CRED_KEY = "credential_id_{}"
|
||||
@@ -62,6 +73,17 @@ KV_ENTERPRISE_SETTINGS_KEY = "danswer_enterprise_settings"
|
||||
KV_CUSTOM_ANALYTICS_SCRIPT_KEY = "__custom_analytics_script__"
|
||||
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT = 60
|
||||
CELERY_PRIMARY_WORKER_LOCK_TIMEOUT = 120
|
||||
|
||||
# needs to be long enough to cover the maximum time it takes to download an object
|
||||
# if we can get callbacks as object bytes download, we could lower this a lot.
|
||||
CELERY_INDEXING_LOCK_TIMEOUT = 60 * 60 # 60 min
|
||||
|
||||
# needs to be long enough to cover the maximum time it takes to download an object
|
||||
# if we can get callbacks as object bytes download, we could lower this a lot.
|
||||
CELERY_PRUNING_LOCK_TIMEOUT = 300 # 5 min
|
||||
|
||||
DANSWER_REDIS_FUNCTION_LOCK_PREFIX = "da_function_lock:"
|
||||
|
||||
|
||||
class DocumentSource(str, Enum):
|
||||
@@ -99,15 +121,21 @@ class DocumentSource(str, Enum):
|
||||
CLICKUP = "clickup"
|
||||
MEDIAWIKI = "mediawiki"
|
||||
WIKIPEDIA = "wikipedia"
|
||||
ASANA = "asana"
|
||||
S3 = "s3"
|
||||
R2 = "r2"
|
||||
GOOGLE_CLOUD_STORAGE = "google_cloud_storage"
|
||||
OCI_STORAGE = "oci_storage"
|
||||
XENFORO = "xenforo"
|
||||
NOT_APPLICABLE = "not_applicable"
|
||||
|
||||
|
||||
DocumentSourceRequiringTenantContext: list[DocumentSource] = [DocumentSource.FILE]
|
||||
|
||||
|
||||
class NotificationType(str, Enum):
|
||||
REINDEX = "reindex"
|
||||
PERSONA_SHARED = "persona_shared"
|
||||
|
||||
|
||||
class BlobType(str, Enum):
|
||||
@@ -132,6 +160,15 @@ class AuthType(str, Enum):
|
||||
OIDC = "oidc"
|
||||
SAML = "saml"
|
||||
|
||||
# google auth and basic
|
||||
CLOUD = "cloud"
|
||||
|
||||
|
||||
class SessionType(str, Enum):
|
||||
CHAT = "Chat"
|
||||
SEARCH = "Search"
|
||||
SLACK = "Slack"
|
||||
|
||||
|
||||
class QAFeedbackType(str, Enum):
|
||||
LIKE = "like" # User likes the answer, used for metrics
|
||||
@@ -165,7 +202,6 @@ class FileOrigin(str, Enum):
|
||||
CONNECTOR = "connector"
|
||||
GENERATED_REPORT = "generated_report"
|
||||
OTHER = "other"
|
||||
GRAPH_GEN = "graph_gen"
|
||||
|
||||
|
||||
class PostgresAdvisoryLocks(Enum):
|
||||
@@ -173,16 +209,23 @@ class PostgresAdvisoryLocks(Enum):
|
||||
|
||||
|
||||
class DanswerCeleryQueues:
|
||||
VESPA_DOCSET_SYNC_GENERATOR = "vespa_docset_sync_generator"
|
||||
VESPA_USERGROUP_SYNC_GENERATOR = "vespa_usergroup_sync_generator"
|
||||
VESPA_METADATA_SYNC = "vespa_metadata_sync"
|
||||
CONNECTOR_DELETION = "connector_deletion"
|
||||
CONNECTOR_PRUNING = "connector_pruning"
|
||||
CONNECTOR_INDEXING = "connector_indexing"
|
||||
|
||||
|
||||
class DanswerRedisLocks:
|
||||
PRIMARY_WORKER = "da_lock:primary_worker"
|
||||
CHECK_VESPA_SYNC_BEAT_LOCK = "da_lock:check_vespa_sync_beat"
|
||||
CHECK_CONNECTOR_DELETION_BEAT_LOCK = "da_lock:check_connector_deletion_beat"
|
||||
CHECK_PRUNE_BEAT_LOCK = "da_lock:check_prune_beat"
|
||||
CHECK_INDEXING_BEAT_LOCK = "da_lock:check_indexing_beat"
|
||||
MONITOR_VESPA_SYNC_BEAT_LOCK = "da_lock:monitor_vespa_sync_beat"
|
||||
|
||||
PRUNING_LOCK_PREFIX = "da_lock:pruning"
|
||||
INDEXING_METADATA_PREFIX = "da_metadata:indexing"
|
||||
|
||||
|
||||
class DanswerCeleryPriority(int, Enum):
|
||||
HIGHEST = 0
|
||||
@@ -190,3 +233,13 @@ class DanswerCeleryPriority(int, Enum):
|
||||
MEDIUM = auto()
|
||||
LOW = auto()
|
||||
LOWEST = auto()
|
||||
|
||||
|
||||
REDIS_SOCKET_KEEPALIVE_OPTIONS = {}
|
||||
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPINTVL] = 15
|
||||
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPCNT] = 3
|
||||
|
||||
if platform.system() == "Darwin":
|
||||
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPALIVE] = 60 # type: ignore
|
||||
else:
|
||||
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPIDLE] = 60 # type: ignore
|
||||
|
||||
22
backend/danswer/configs/tool_configs.py
Normal file
22
backend/danswer/configs/tool_configs.py
Normal file
@@ -0,0 +1,22 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
|
||||
# if specified, will pass through request headers to the call to API calls made by custom tools
|
||||
CUSTOM_TOOL_PASS_THROUGH_HEADERS: list[str] | None = None
|
||||
_CUSTOM_TOOL_PASS_THROUGH_HEADERS_RAW = os.environ.get(
|
||||
"CUSTOM_TOOL_PASS_THROUGH_HEADERS"
|
||||
)
|
||||
if _CUSTOM_TOOL_PASS_THROUGH_HEADERS_RAW:
|
||||
try:
|
||||
CUSTOM_TOOL_PASS_THROUGH_HEADERS = json.loads(
|
||||
_CUSTOM_TOOL_PASS_THROUGH_HEADERS_RAW
|
||||
)
|
||||
except Exception:
|
||||
# need to import here to avoid circular imports
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
logger.error(
|
||||
"Failed to parse CUSTOM_TOOL_PASS_THROUGH_HEADERS, must be a valid JSON object"
|
||||
)
|
||||
233
backend/danswer/connectors/asana/asana_api.py
Executable file
233
backend/danswer/connectors/asana/asana_api.py
Executable file
@@ -0,0 +1,233 @@
|
||||
import time
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from typing import Dict
|
||||
|
||||
import asana # type: ignore
|
||||
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
# https://github.com/Asana/python-asana/tree/master?tab=readme-ov-file#documentation-for-api-endpoints
|
||||
class AsanaTask:
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
title: str,
|
||||
text: str,
|
||||
link: str,
|
||||
last_modified: datetime,
|
||||
project_gid: str,
|
||||
project_name: str,
|
||||
) -> None:
|
||||
self.id = id
|
||||
self.title = title
|
||||
self.text = text
|
||||
self.link = link
|
||||
self.last_modified = last_modified
|
||||
self.project_gid = project_gid
|
||||
self.project_name = project_name
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"ID: {self.id}\nTitle: {self.title}\nLast modified: {self.last_modified}\nText: {self.text}"
|
||||
|
||||
|
||||
class AsanaAPI:
|
||||
def __init__(
|
||||
self, api_token: str, workspace_gid: str, team_gid: str | None
|
||||
) -> None:
|
||||
self._user = None # type: ignore
|
||||
self.workspace_gid = workspace_gid
|
||||
self.team_gid = team_gid
|
||||
|
||||
self.configuration = asana.Configuration()
|
||||
self.api_client = asana.ApiClient(self.configuration)
|
||||
self.tasks_api = asana.TasksApi(self.api_client)
|
||||
self.stories_api = asana.StoriesApi(self.api_client)
|
||||
self.users_api = asana.UsersApi(self.api_client)
|
||||
self.project_api = asana.ProjectsApi(self.api_client)
|
||||
self.workspaces_api = asana.WorkspacesApi(self.api_client)
|
||||
|
||||
self.api_error_count = 0
|
||||
self.configuration.access_token = api_token
|
||||
self.task_count = 0
|
||||
|
||||
def get_tasks(
|
||||
self, project_gids: list[str] | None, start_date: str
|
||||
) -> Iterator[AsanaTask]:
|
||||
"""Get all tasks from the projects with the given gids that were modified since the given date.
|
||||
If project_gids is None, get all tasks from all projects in the workspace."""
|
||||
logger.info("Starting to fetch Asana projects")
|
||||
projects = self.project_api.get_projects(
|
||||
opts={
|
||||
"workspace": self.workspace_gid,
|
||||
"opt_fields": "gid,name,archived,modified_at",
|
||||
}
|
||||
)
|
||||
start_seconds = int(time.mktime(datetime.now().timetuple()))
|
||||
projects_list = []
|
||||
project_count = 0
|
||||
for project_info in projects:
|
||||
project_gid = project_info["gid"]
|
||||
if project_gids is None or project_gid in project_gids:
|
||||
projects_list.append(project_gid)
|
||||
else:
|
||||
logger.debug(
|
||||
f"Skipping project: {project_gid} - not in accepted project_gids"
|
||||
)
|
||||
project_count += 1
|
||||
if project_count % 100 == 0:
|
||||
logger.info(f"Processed {project_count} projects")
|
||||
|
||||
logger.info(f"Found {len(projects_list)} projects to process")
|
||||
for project_gid in projects_list:
|
||||
for task in self._get_tasks_for_project(
|
||||
project_gid, start_date, start_seconds
|
||||
):
|
||||
yield task
|
||||
logger.info(f"Completed fetching {self.task_count} tasks from Asana")
|
||||
if self.api_error_count > 0:
|
||||
logger.warning(
|
||||
f"Encountered {self.api_error_count} API errors during task fetching"
|
||||
)
|
||||
|
||||
def _get_tasks_for_project(
|
||||
self, project_gid: str, start_date: str, start_seconds: int
|
||||
) -> Iterator[AsanaTask]:
|
||||
project = self.project_api.get_project(project_gid, opts={})
|
||||
if project["archived"]:
|
||||
logger.info(f"Skipping archived project: {project['name']} ({project_gid})")
|
||||
return []
|
||||
if not project["team"] or not project["team"]["gid"]:
|
||||
logger.info(
|
||||
f"Skipping project without a team: {project['name']} ({project_gid})"
|
||||
)
|
||||
return []
|
||||
if project["privacy_setting"] == "private":
|
||||
if self.team_gid and project["team"]["gid"] != self.team_gid:
|
||||
logger.info(
|
||||
f"Skipping private project not in configured team: {project['name']} ({project_gid})"
|
||||
)
|
||||
return []
|
||||
else:
|
||||
logger.info(
|
||||
f"Processing private project in configured team: {project['name']} ({project_gid})"
|
||||
)
|
||||
|
||||
simple_start_date = start_date.split(".")[0].split("+")[0]
|
||||
logger.info(
|
||||
f"Fetching tasks modified since {simple_start_date} for project: {project['name']} ({project_gid})"
|
||||
)
|
||||
|
||||
opts = {
|
||||
"opt_fields": "name,memberships,memberships.project,completed_at,completed_by,created_at,"
|
||||
"created_by,custom_fields,dependencies,due_at,due_on,external,html_notes,liked,likes,"
|
||||
"modified_at,notes,num_hearts,parent,projects,resource_subtype,resource_type,start_on,"
|
||||
"workspace,permalink_url",
|
||||
"modified_since": start_date,
|
||||
}
|
||||
tasks_from_api = self.tasks_api.get_tasks_for_project(project_gid, opts)
|
||||
for data in tasks_from_api:
|
||||
self.task_count += 1
|
||||
if self.task_count % 10 == 0:
|
||||
end_seconds = time.mktime(datetime.now().timetuple())
|
||||
runtime_seconds = end_seconds - start_seconds
|
||||
if runtime_seconds > 0:
|
||||
logger.info(
|
||||
f"Processed {self.task_count} tasks in {runtime_seconds:.0f} seconds "
|
||||
f"({self.task_count / runtime_seconds:.2f} tasks/second)"
|
||||
)
|
||||
|
||||
logger.debug(f"Processing Asana task: {data['name']}")
|
||||
|
||||
text = self._construct_task_text(data)
|
||||
|
||||
try:
|
||||
text += self._fetch_and_add_comments(data["gid"])
|
||||
|
||||
last_modified_date = self.format_date(data["modified_at"])
|
||||
text += f"Last modified: {last_modified_date}\n"
|
||||
|
||||
task = AsanaTask(
|
||||
id=data["gid"],
|
||||
title=data["name"],
|
||||
text=text,
|
||||
link=data["permalink_url"],
|
||||
last_modified=datetime.fromisoformat(data["modified_at"]),
|
||||
project_gid=project_gid,
|
||||
project_name=project["name"],
|
||||
)
|
||||
yield task
|
||||
except Exception:
|
||||
logger.error(
|
||||
f"Error processing task {data['gid']} in project {project_gid}",
|
||||
exc_info=True,
|
||||
)
|
||||
self.api_error_count += 1
|
||||
|
||||
def _construct_task_text(self, data: Dict) -> str:
|
||||
text = f"{data['name']}\n\n"
|
||||
|
||||
if data["notes"]:
|
||||
text += f"{data['notes']}\n\n"
|
||||
|
||||
if data["created_by"] and data["created_by"]["gid"]:
|
||||
creator = self.get_user(data["created_by"]["gid"])["name"]
|
||||
created_date = self.format_date(data["created_at"])
|
||||
text += f"Created by: {creator} on {created_date}\n"
|
||||
|
||||
if data["due_on"]:
|
||||
due_date = self.format_date(data["due_on"])
|
||||
text += f"Due date: {due_date}\n"
|
||||
|
||||
if data["completed_at"]:
|
||||
completed_date = self.format_date(data["completed_at"])
|
||||
text += f"Completed on: {completed_date}\n"
|
||||
|
||||
text += "\n"
|
||||
return text
|
||||
|
||||
def _fetch_and_add_comments(self, task_gid: str) -> str:
|
||||
text = ""
|
||||
stories_opts: Dict[str, str] = {}
|
||||
story_start = time.time()
|
||||
stories = self.stories_api.get_stories_for_task(task_gid, stories_opts)
|
||||
|
||||
story_count = 0
|
||||
comment_count = 0
|
||||
|
||||
for story in stories:
|
||||
story_count += 1
|
||||
if story["resource_subtype"] == "comment_added":
|
||||
comment = self.stories_api.get_story(
|
||||
story["gid"], opts={"opt_fields": "text,created_by,created_at"}
|
||||
)
|
||||
commenter = self.get_user(comment["created_by"]["gid"])["name"]
|
||||
text += f"Comment by {commenter}: {comment['text']}\n\n"
|
||||
comment_count += 1
|
||||
|
||||
story_duration = time.time() - story_start
|
||||
logger.debug(
|
||||
f"Processed {story_count} stories (including {comment_count} comments) in {story_duration:.2f} seconds"
|
||||
)
|
||||
|
||||
return text
|
||||
|
||||
def get_user(self, user_gid: str) -> Dict:
|
||||
if self._user is not None:
|
||||
return self._user
|
||||
self._user = self.users_api.get_user(user_gid, {"opt_fields": "name,email"})
|
||||
|
||||
if not self._user:
|
||||
logger.warning(f"Unable to fetch user information for user_gid: {user_gid}")
|
||||
return {"name": "Unknown"}
|
||||
return self._user
|
||||
|
||||
def format_date(self, date_str: str) -> str:
|
||||
date = datetime.fromisoformat(date_str)
|
||||
return time.strftime("%Y-%m-%d", date.timetuple())
|
||||
|
||||
def get_time(self) -> str:
|
||||
return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
||||
120
backend/danswer/connectors/asana/connector.py
Executable file
120
backend/danswer/connectors/asana/connector.py
Executable file
@@ -0,0 +1,120 @@
|
||||
import datetime
|
||||
from typing import Any
|
||||
|
||||
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.asana import asana_api
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class AsanaConnector(LoadConnector, PollConnector):
|
||||
def __init__(
|
||||
self,
|
||||
asana_workspace_id: str,
|
||||
asana_project_ids: str | None = None,
|
||||
asana_team_id: str | None = None,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE,
|
||||
) -> None:
|
||||
self.workspace_id = asana_workspace_id
|
||||
self.project_ids_to_index: list[str] | None = (
|
||||
asana_project_ids.split(",") if asana_project_ids is not None else None
|
||||
)
|
||||
self.asana_team_id = asana_team_id
|
||||
self.batch_size = batch_size
|
||||
self.continue_on_failure = continue_on_failure
|
||||
logger.info(
|
||||
f"AsanaConnector initialized with workspace_id: {asana_workspace_id}"
|
||||
)
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
self.api_token = credentials["asana_api_token_secret"]
|
||||
self.asana_client = asana_api.AsanaAPI(
|
||||
api_token=self.api_token,
|
||||
workspace_gid=self.workspace_id,
|
||||
team_gid=self.asana_team_id,
|
||||
)
|
||||
logger.info("Asana credentials loaded and API client initialized")
|
||||
return None
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch | None
|
||||
) -> GenerateDocumentsOutput:
|
||||
start_time = datetime.datetime.fromtimestamp(start).isoformat()
|
||||
logger.info(f"Starting Asana poll from {start_time}")
|
||||
asana = asana_api.AsanaAPI(
|
||||
api_token=self.api_token,
|
||||
workspace_gid=self.workspace_id,
|
||||
team_gid=self.asana_team_id,
|
||||
)
|
||||
docs_batch: list[Document] = []
|
||||
tasks = asana.get_tasks(self.project_ids_to_index, start_time)
|
||||
|
||||
for task in tasks:
|
||||
doc = self._message_to_doc(task)
|
||||
docs_batch.append(doc)
|
||||
|
||||
if len(docs_batch) >= self.batch_size:
|
||||
logger.info(f"Yielding batch of {len(docs_batch)} documents")
|
||||
yield docs_batch
|
||||
docs_batch = []
|
||||
|
||||
if docs_batch:
|
||||
logger.info(f"Yielding final batch of {len(docs_batch)} documents")
|
||||
yield docs_batch
|
||||
|
||||
logger.info("Asana poll completed")
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
logger.notice("Starting full index of all Asana tasks")
|
||||
return self.poll_source(start=0, end=None)
|
||||
|
||||
def _message_to_doc(self, task: asana_api.AsanaTask) -> Document:
|
||||
logger.debug(f"Converting Asana task {task.id} to Document")
|
||||
return Document(
|
||||
id=task.id,
|
||||
sections=[Section(link=task.link, text=task.text)],
|
||||
doc_updated_at=task.last_modified,
|
||||
source=DocumentSource.ASANA,
|
||||
semantic_identifier=task.title,
|
||||
metadata={
|
||||
"group": task.project_gid,
|
||||
"project": task.project_name,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import time
|
||||
import os
|
||||
|
||||
logger.notice("Starting Asana connector test")
|
||||
connector = AsanaConnector(
|
||||
os.environ["WORKSPACE_ID"],
|
||||
os.environ["PROJECT_IDS"],
|
||||
os.environ["TEAM_ID"],
|
||||
)
|
||||
connector.load_credentials(
|
||||
{
|
||||
"asana_api_token_secret": os.environ["API_TOKEN"],
|
||||
}
|
||||
)
|
||||
logger.info("Loading all documents from Asana")
|
||||
all_docs = connector.load_from_state()
|
||||
current = time.time()
|
||||
one_day_ago = current - 24 * 60 * 60 # 1 day
|
||||
logger.info("Polling for documents updated in the last 24 hours")
|
||||
latest_docs = connector.poll_source(one_day_ago, current)
|
||||
for docs in latest_docs:
|
||||
for doc in docs:
|
||||
print(doc.id)
|
||||
logger.notice("Asana connector test completed")
|
||||
@@ -194,8 +194,8 @@ class BlobStorageConnector(LoadConnector, PollConnector):
|
||||
|
||||
try:
|
||||
text = extract_file_text(
|
||||
name,
|
||||
BytesIO(downloaded_file),
|
||||
file_name=name,
|
||||
break_on_unprocessable=False,
|
||||
)
|
||||
batch.append(
|
||||
|
||||
32
backend/danswer/connectors/confluence/confluence_utils.py
Normal file
32
backend/danswer/connectors/confluence/confluence_utils.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import bs4
|
||||
|
||||
|
||||
def build_confluence_document_id(base_url: str, content_url: str) -> str:
|
||||
"""For confluence, the document id is the page url for a page based document
|
||||
or the attachment download url for an attachment based document
|
||||
|
||||
Args:
|
||||
base_url (str): The base url of the Confluence instance
|
||||
content_url (str): The url of the page or attachment download url
|
||||
|
||||
Returns:
|
||||
str: The document id
|
||||
"""
|
||||
return f"{base_url}{content_url}"
|
||||
|
||||
|
||||
def get_used_attachments(text: str) -> list[str]:
|
||||
"""Parse a Confluence html page to generate a list of current
|
||||
attachment in used
|
||||
|
||||
Args:
|
||||
text (str): The page content
|
||||
|
||||
Returns:
|
||||
list[str]: List of filenames currently in use by the page text
|
||||
"""
|
||||
files_in_used = []
|
||||
soup = bs4.BeautifulSoup(text, "html.parser")
|
||||
for attachment in soup.findAll("ri:attachment"):
|
||||
files_in_used.append(attachment.attrs["ri:filename"])
|
||||
return files_in_used
|
||||
@@ -6,7 +6,8 @@ from datetime import datetime
|
||||
from datetime import timezone
|
||||
from functools import lru_cache
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from urllib.parse import parse_qs
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import bs4
|
||||
from atlassian import Confluence # type:ignore
|
||||
@@ -22,13 +23,16 @@ from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING
|
||||
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.confluence.confluence_utils import (
|
||||
build_confluence_document_id,
|
||||
)
|
||||
from danswer.connectors.confluence.confluence_utils import get_used_attachments
|
||||
from danswer.connectors.confluence.rate_limit_handler import (
|
||||
make_confluence_call_handle_rate_limit,
|
||||
)
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from danswer.connectors.models import BasicExpertInfo
|
||||
from danswer.connectors.models import ConnectorMissingCredentialError
|
||||
from danswer.connectors.models import Document
|
||||
@@ -52,8 +56,40 @@ NO_PARENT_OR_NO_PERMISSIONS_ERROR_STR = (
|
||||
)
|
||||
|
||||
|
||||
class DanswerConfluence(Confluence):
|
||||
"""
|
||||
This is a custom Confluence class that overrides the default Confluence class to add a custom CQL method.
|
||||
This is necessary because the default Confluence class does not properly support cql expansions.
|
||||
"""
|
||||
|
||||
def __init__(self, url: str, *args: Any, **kwargs: Any) -> None:
|
||||
super(DanswerConfluence, self).__init__(url, *args, **kwargs)
|
||||
|
||||
def danswer_cql(
|
||||
self,
|
||||
cql: str,
|
||||
expand: str | None = None,
|
||||
cursor: str | None = None,
|
||||
limit: int = 500,
|
||||
include_archived_spaces: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
url_suffix = f"rest/api/content/search?cql={cql}"
|
||||
if expand:
|
||||
url_suffix += f"&expand={expand}"
|
||||
if cursor:
|
||||
url_suffix += f"&cursor={cursor}"
|
||||
url_suffix += f"&limit={limit}"
|
||||
if include_archived_spaces:
|
||||
url_suffix += "&includeArchivedSpaces=true"
|
||||
try:
|
||||
response = self.get(url_suffix)
|
||||
return response
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def _get_user(user_id: str, confluence_client: Confluence) -> str:
|
||||
def _get_user(user_id: str, confluence_client: DanswerConfluence) -> str:
|
||||
"""Get Confluence Display Name based on the account-id or userkey value
|
||||
|
||||
Args:
|
||||
@@ -69,6 +105,7 @@ def _get_user(user_id: str, confluence_client: Confluence) -> str:
|
||||
confluence_client.get_user_details_by_accountid
|
||||
)
|
||||
try:
|
||||
logger.info(f"_get_user - get_user_details_by_accountid: id={user_id}")
|
||||
return get_user_details_by_accountid(user_id).get("displayName", user_not_found)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
@@ -77,7 +114,7 @@ def _get_user(user_id: str, confluence_client: Confluence) -> str:
|
||||
return user_not_found
|
||||
|
||||
|
||||
def parse_html_page(text: str, confluence_client: Confluence) -> str:
|
||||
def parse_html_page(text: str, confluence_client: DanswerConfluence) -> str:
|
||||
"""Parse a Confluence html page and replace the 'user Id' by the real
|
||||
User Display Name
|
||||
|
||||
@@ -105,28 +142,10 @@ def parse_html_page(text: str, confluence_client: Confluence) -> str:
|
||||
return format_document_soup(soup)
|
||||
|
||||
|
||||
def get_used_attachments(text: str, confluence_client: Confluence) -> list[str]:
|
||||
"""Parse a Confluence html page to generate a list of current
|
||||
attachment in used
|
||||
|
||||
Args:
|
||||
text (str): The page content
|
||||
confluence_client (Confluence): Confluence client
|
||||
|
||||
Returns:
|
||||
list[str]: List of filename currently in used
|
||||
"""
|
||||
files_in_used = []
|
||||
soup = bs4.BeautifulSoup(text, "html.parser")
|
||||
for attachment in soup.findAll("ri:attachment"):
|
||||
files_in_used.append(attachment.attrs["ri:filename"])
|
||||
return files_in_used
|
||||
|
||||
|
||||
def _comment_dfs(
|
||||
comments_str: str,
|
||||
comment_pages: Collection[dict[str, Any]],
|
||||
confluence_client: Confluence,
|
||||
confluence_client: DanswerConfluence,
|
||||
) -> str:
|
||||
get_page_child_by_type = make_confluence_call_handle_rate_limit(
|
||||
confluence_client.get_page_child_by_type
|
||||
@@ -138,6 +157,9 @@ def _comment_dfs(
|
||||
comment_html, confluence_client
|
||||
)
|
||||
try:
|
||||
logger.info(
|
||||
f"_comment_dfs - get_page_by_child_type: id={comment_page['id']}"
|
||||
)
|
||||
child_comment_pages = get_page_child_by_type(
|
||||
comment_page["id"],
|
||||
type="comment",
|
||||
@@ -177,130 +199,103 @@ class RecursiveIndexer:
|
||||
index_recursively: bool,
|
||||
origin_page_id: str,
|
||||
) -> None:
|
||||
self.batch_size = 1
|
||||
# batch_size
|
||||
self.batch_size = batch_size
|
||||
self.confluence_client = confluence_client
|
||||
self.index_recursively = index_recursively
|
||||
self.origin_page_id = origin_page_id
|
||||
self.pages = self.recurse_children_pages(0, self.origin_page_id)
|
||||
self.pages = self.recurse_children_pages(self.origin_page_id)
|
||||
|
||||
def get_origin_page(self) -> list[dict[str, Any]]:
|
||||
return [self._fetch_origin_page()]
|
||||
|
||||
def get_pages(self, ind: int, size: int) -> list[dict]:
|
||||
if ind * size > len(self.pages):
|
||||
return []
|
||||
return self.pages[ind * size : (ind + 1) * size]
|
||||
def get_pages(self) -> list[dict[str, Any]]:
|
||||
return self.pages
|
||||
|
||||
def _fetch_origin_page(
|
||||
self,
|
||||
) -> dict[str, Any]:
|
||||
def _fetch_origin_page(self) -> dict[str, Any]:
|
||||
get_page_by_id = make_confluence_call_handle_rate_limit(
|
||||
self.confluence_client.get_page_by_id
|
||||
)
|
||||
try:
|
||||
logger.info(
|
||||
f"_fetch_origin_page - get_page_by_id: id={self.origin_page_id}"
|
||||
)
|
||||
origin_page = get_page_by_id(
|
||||
self.origin_page_id, expand="body.storage.value,version"
|
||||
self.origin_page_id, expand="body.storage.value,version,space"
|
||||
)
|
||||
return origin_page
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Appending orgin page with id {self.origin_page_id} failed: {e}"
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Appending origin page with id {self.origin_page_id} failed."
|
||||
)
|
||||
return {}
|
||||
|
||||
def recurse_children_pages(
|
||||
self,
|
||||
start_ind: int,
|
||||
page_id: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
pages: list[dict[str, Any]] = []
|
||||
current_level_pages: list[dict[str, Any]] = []
|
||||
next_level_pages: list[dict[str, Any]] = []
|
||||
queue: list[str] = [page_id]
|
||||
visited_pages: set[str] = set()
|
||||
|
||||
# Initial fetch of first level children
|
||||
index = start_ind
|
||||
while batch := self._fetch_single_depth_child_pages(
|
||||
index, self.batch_size, page_id
|
||||
):
|
||||
current_level_pages.extend(batch)
|
||||
index += len(batch)
|
||||
|
||||
pages.extend(current_level_pages)
|
||||
|
||||
# Recursively index children and children's children, etc.
|
||||
while current_level_pages:
|
||||
for child in current_level_pages:
|
||||
child_index = 0
|
||||
while child_batch := self._fetch_single_depth_child_pages(
|
||||
child_index, self.batch_size, child["id"]
|
||||
):
|
||||
next_level_pages.extend(child_batch)
|
||||
child_index += len(child_batch)
|
||||
|
||||
pages.extend(next_level_pages)
|
||||
current_level_pages = next_level_pages
|
||||
next_level_pages = []
|
||||
|
||||
try:
|
||||
origin_page = self._fetch_origin_page()
|
||||
pages.append(origin_page)
|
||||
except Exception as e:
|
||||
logger.warning(f"Appending origin page with id {page_id} failed: {e}")
|
||||
|
||||
return pages
|
||||
|
||||
def _fetch_single_depth_child_pages(
|
||||
self, start_ind: int, batch_size: int, page_id: str
|
||||
) -> list[dict[str, Any]]:
|
||||
child_pages: list[dict[str, Any]] = []
|
||||
get_page_by_id = make_confluence_call_handle_rate_limit(
|
||||
self.confluence_client.get_page_by_id
|
||||
)
|
||||
|
||||
get_page_child_by_type = make_confluence_call_handle_rate_limit(
|
||||
self.confluence_client.get_page_child_by_type
|
||||
)
|
||||
|
||||
try:
|
||||
child_page = get_page_child_by_type(
|
||||
page_id,
|
||||
type="page",
|
||||
start=start_ind,
|
||||
limit=batch_size,
|
||||
expand="body.storage.value,version",
|
||||
)
|
||||
while queue:
|
||||
current_page_id = queue.pop(0)
|
||||
if current_page_id in visited_pages:
|
||||
continue
|
||||
visited_pages.add(current_page_id)
|
||||
|
||||
child_pages.extend(child_page)
|
||||
return child_pages
|
||||
try:
|
||||
# Fetch the page itself
|
||||
logger.info(
|
||||
f"recurse_children_pages - get_page_by_id: id={current_page_id}"
|
||||
)
|
||||
page = get_page_by_id(
|
||||
current_page_id, expand="body.storage.value,version,space"
|
||||
)
|
||||
pages.append(page)
|
||||
except Exception:
|
||||
logger.exception(f"Failed to fetch page {current_page_id}.")
|
||||
continue
|
||||
|
||||
except Exception:
|
||||
logger.warning(
|
||||
f"Batch failed with page {page_id} at offset {start_ind} "
|
||||
f"with size {batch_size}, processing pages individually..."
|
||||
)
|
||||
if not self.index_recursively:
|
||||
continue
|
||||
|
||||
for i in range(batch_size):
|
||||
ind = start_ind + i
|
||||
try:
|
||||
child_page = get_page_child_by_type(
|
||||
page_id,
|
||||
type="page",
|
||||
start=ind,
|
||||
limit=1,
|
||||
expand="body.storage.value,version",
|
||||
)
|
||||
child_pages.extend(child_page)
|
||||
except Exception as e:
|
||||
logger.warning(f"Page {page_id} at offset {ind} failed: {e}")
|
||||
raise e
|
||||
# Fetch child pages
|
||||
start = 0
|
||||
while True:
|
||||
logger.info(
|
||||
f"recurse_children_pages - get_page_by_child_type: id={current_page_id}"
|
||||
)
|
||||
child_pages_response = get_page_child_by_type(
|
||||
current_page_id,
|
||||
type="page",
|
||||
start=start,
|
||||
limit=self.batch_size,
|
||||
expand="",
|
||||
)
|
||||
if not child_pages_response:
|
||||
break
|
||||
for child_page in child_pages_response:
|
||||
child_page_id = child_page["id"]
|
||||
queue.append(child_page_id)
|
||||
start += len(child_pages_response)
|
||||
|
||||
return child_pages
|
||||
return pages
|
||||
|
||||
|
||||
class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
def __init__(
|
||||
self,
|
||||
wiki_base: str,
|
||||
space: str,
|
||||
is_cloud: bool,
|
||||
space: str = "",
|
||||
page_id: str = "",
|
||||
index_recursively: bool = True,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
@@ -309,104 +304,167 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
# skip it. This is generally used to avoid indexing extra sensitive
|
||||
# pages.
|
||||
labels_to_skip: list[str] = CONFLUENCE_CONNECTOR_LABELS_TO_SKIP,
|
||||
cql_query: str | None = None,
|
||||
) -> None:
|
||||
self.batch_size = batch_size
|
||||
self.continue_on_failure = continue_on_failure
|
||||
self.labels_to_skip = set(labels_to_skip)
|
||||
self.recursive_indexer: RecursiveIndexer | None = None
|
||||
self.index_recursively = index_recursively
|
||||
self.index_recursively = False if cql_query else index_recursively
|
||||
|
||||
# Remove trailing slash from wiki_base if present
|
||||
self.wiki_base = wiki_base.rstrip("/")
|
||||
self.space = space
|
||||
self.page_id = page_id
|
||||
self.page_id = "" if cql_query else page_id
|
||||
self.space_level_scan = bool(not self.page_id)
|
||||
|
||||
self.is_cloud = is_cloud
|
||||
|
||||
self.space_level_scan = False
|
||||
self.confluence_client: Confluence | None = None
|
||||
self.confluence_client: DanswerConfluence | None = None
|
||||
|
||||
if self.page_id is None or self.page_id == "":
|
||||
self.space_level_scan = True
|
||||
# if a cql_query is provided, we will use it to fetch the pages
|
||||
# if no cql_query is provided, we will use the space to fetch the pages
|
||||
# if no space is provided and no cql_query, we will default to fetching all pages, regardless of space
|
||||
if cql_query:
|
||||
self.cql_query = cql_query
|
||||
elif space:
|
||||
self.cql_query = f"type=page and space='{space}'"
|
||||
else:
|
||||
self.cql_query = "type=page"
|
||||
|
||||
logger.info(
|
||||
f"wiki_base: {self.wiki_base}, space: {self.space}, page_id: {self.page_id},"
|
||||
+ f" space_level_scan: {self.space_level_scan}, index_recursively: {self.index_recursively}"
|
||||
f"wiki_base: {self.wiki_base}, space: {space}, page_id: {self.page_id},"
|
||||
+ f" space_level_scan: {self.space_level_scan}, index_recursively: {self.index_recursively},"
|
||||
+ f" cql_query: {self.cql_query}"
|
||||
)
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
username = credentials["confluence_username"]
|
||||
access_token = credentials["confluence_access_token"]
|
||||
self.confluence_client = Confluence(
|
||||
|
||||
# see https://github.com/atlassian-api/atlassian-python-api/blob/master/atlassian/rest_client.py
|
||||
# for a list of other hidden constructor args
|
||||
self.confluence_client = DanswerConfluence(
|
||||
url=self.wiki_base,
|
||||
# passing in username causes issues for Confluence data center
|
||||
username=username if self.is_cloud else None,
|
||||
password=access_token if self.is_cloud else None,
|
||||
token=access_token if not self.is_cloud else None,
|
||||
backoff_and_retry=True,
|
||||
max_backoff_retries=60,
|
||||
max_backoff_seconds=60,
|
||||
)
|
||||
return None
|
||||
|
||||
def _fetch_pages(
|
||||
self,
|
||||
confluence_client: Confluence,
|
||||
start_ind: int,
|
||||
) -> list[dict[str, Any]]:
|
||||
def _fetch_space(start_ind: int, batch_size: int) -> list[dict[str, Any]]:
|
||||
get_all_pages_from_space = make_confluence_call_handle_rate_limit(
|
||||
confluence_client.get_all_pages_from_space
|
||||
cursor: str | None,
|
||||
) -> tuple[list[dict[str, Any]], str | None]:
|
||||
if self.confluence_client is None:
|
||||
raise Exception("Confluence client is not initialized")
|
||||
|
||||
def _fetch_space(
|
||||
cursor: str | None, batch_size: int
|
||||
) -> tuple[list[dict[str, Any]], str | None]:
|
||||
if not self.confluence_client:
|
||||
raise Exception("Confluence client is not initialized")
|
||||
get_all_pages = make_confluence_call_handle_rate_limit(
|
||||
self.confluence_client.danswer_cql
|
||||
)
|
||||
|
||||
include_archived_spaces = (
|
||||
CONFLUENCE_CONNECTOR_INDEX_ARCHIVED_PAGES
|
||||
if not self.is_cloud
|
||||
else False
|
||||
)
|
||||
|
||||
try:
|
||||
return get_all_pages_from_space(
|
||||
self.space,
|
||||
start=start_ind,
|
||||
limit=batch_size,
|
||||
status=(
|
||||
None if CONFLUENCE_CONNECTOR_INDEX_ARCHIVED_PAGES else "current"
|
||||
),
|
||||
expand="body.storage.value,version",
|
||||
logger.info(
|
||||
f"_fetch_space - get_all_pages: cursor={cursor} limit={batch_size}"
|
||||
)
|
||||
response = get_all_pages(
|
||||
cql=self.cql_query,
|
||||
cursor=cursor,
|
||||
limit=batch_size,
|
||||
expand="body.storage.value,version,space",
|
||||
include_archived_spaces=include_archived_spaces,
|
||||
)
|
||||
pages = response.get("results", [])
|
||||
next_cursor = None
|
||||
if "_links" in response and "next" in response["_links"]:
|
||||
next_link = response["_links"]["next"]
|
||||
parsed_url = urlparse(next_link)
|
||||
query_params = parse_qs(parsed_url.query)
|
||||
cursor_list = query_params.get("cursor", [])
|
||||
if cursor_list:
|
||||
next_cursor = cursor_list[0]
|
||||
return pages, next_cursor
|
||||
except Exception:
|
||||
logger.warning(
|
||||
f"Batch failed with space {self.space} at offset {start_ind} "
|
||||
f"with size {batch_size}, processing pages individually..."
|
||||
f"Batch failed with cql {self.cql_query} with cursor {cursor} "
|
||||
f"and size {batch_size}, processing pages individually..."
|
||||
)
|
||||
|
||||
view_pages: list[dict[str, Any]] = []
|
||||
for i in range(self.batch_size):
|
||||
for _ in range(self.batch_size):
|
||||
try:
|
||||
# Could be that one of the pages here failed due to this bug:
|
||||
# https://jira.atlassian.com/browse/CONFCLOUD-76433
|
||||
view_pages.extend(
|
||||
get_all_pages_from_space(
|
||||
self.space,
|
||||
start=start_ind + i,
|
||||
limit=1,
|
||||
status=(
|
||||
None
|
||||
if CONFLUENCE_CONNECTOR_INDEX_ARCHIVED_PAGES
|
||||
else "current"
|
||||
),
|
||||
expand="body.storage.value,version",
|
||||
)
|
||||
logger.info(
|
||||
f"_fetch_space - get_all_pages: cursor={cursor} limit=1"
|
||||
)
|
||||
response = get_all_pages(
|
||||
cql=self.cql_query,
|
||||
cursor=cursor,
|
||||
limit=1,
|
||||
expand="body.view.value,version,space",
|
||||
include_archived_spaces=include_archived_spaces,
|
||||
)
|
||||
pages = response.get("results", [])
|
||||
view_pages.extend(pages)
|
||||
if "_links" in response and "next" in response["_links"]:
|
||||
next_link = response["_links"]["next"]
|
||||
parsed_url = urlparse(next_link)
|
||||
query_params = parse_qs(parsed_url.query)
|
||||
cursor_list = query_params.get("cursor", [])
|
||||
if cursor_list:
|
||||
cursor = cursor_list[0]
|
||||
else:
|
||||
cursor = None
|
||||
else:
|
||||
cursor = None
|
||||
break
|
||||
except HTTPError as e:
|
||||
logger.warning(
|
||||
f"Page failed with space {self.space} at offset {start_ind + i}, "
|
||||
f"Page failed with cql {self.cql_query} with cursor {cursor}, "
|
||||
f"trying alternative expand option: {e}"
|
||||
)
|
||||
# Use view instead, which captures most info but is less complete
|
||||
view_pages.extend(
|
||||
get_all_pages_from_space(
|
||||
self.space,
|
||||
start=start_ind + i,
|
||||
limit=1,
|
||||
expand="body.view.value,version",
|
||||
)
|
||||
logger.info(
|
||||
f"_fetch_space - get_all_pages - trying alternative expand: cursor={cursor} limit=1"
|
||||
)
|
||||
response = get_all_pages(
|
||||
cql=self.cql_query,
|
||||
cursor=cursor,
|
||||
limit=1,
|
||||
expand="body.view.value,version,space",
|
||||
)
|
||||
pages = response.get("results", [])
|
||||
view_pages.extend(pages)
|
||||
if "_links" in response and "next" in response["_links"]:
|
||||
next_link = response["_links"]["next"]
|
||||
parsed_url = urlparse(next_link)
|
||||
query_params = parse_qs(parsed_url.query)
|
||||
cursor_list = query_params.get("cursor", [])
|
||||
if cursor_list:
|
||||
cursor = cursor_list[0]
|
||||
else:
|
||||
cursor = None
|
||||
else:
|
||||
cursor = None
|
||||
break
|
||||
|
||||
return view_pages
|
||||
return view_pages, cursor
|
||||
|
||||
def _fetch_page() -> tuple[list[dict[str, Any]], str | None]:
|
||||
if self.confluence_client is None:
|
||||
raise Exception("Confluence client is not initialized")
|
||||
|
||||
def _fetch_page(start_ind: int, batch_size: int) -> list[dict[str, Any]]:
|
||||
if self.recursive_indexer is None:
|
||||
self.recursive_indexer = RecursiveIndexer(
|
||||
origin_page_id=self.page_id,
|
||||
@@ -415,41 +473,22 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
index_recursively=self.index_recursively,
|
||||
)
|
||||
|
||||
if self.index_recursively:
|
||||
return self.recursive_indexer.get_pages(start_ind, batch_size)
|
||||
else:
|
||||
return self.recursive_indexer.get_origin_page()
|
||||
|
||||
pages: list[dict[str, Any]] = []
|
||||
pages = self.recursive_indexer.get_pages()
|
||||
return pages, None # Since we fetched all pages, no cursor
|
||||
|
||||
try:
|
||||
pages = (
|
||||
_fetch_space(start_ind, self.batch_size)
|
||||
pages, next_cursor = (
|
||||
_fetch_space(cursor, self.batch_size)
|
||||
if self.space_level_scan
|
||||
else _fetch_page(start_ind, self.batch_size)
|
||||
else _fetch_page()
|
||||
)
|
||||
return pages
|
||||
|
||||
return pages, next_cursor
|
||||
except Exception as e:
|
||||
if not self.continue_on_failure:
|
||||
raise e
|
||||
|
||||
# error checking phase, only reachable if `self.continue_on_failure=True`
|
||||
for i in range(self.batch_size):
|
||||
try:
|
||||
pages = (
|
||||
_fetch_space(start_ind, self.batch_size)
|
||||
if self.space_level_scan
|
||||
else _fetch_page(start_ind, self.batch_size)
|
||||
)
|
||||
return pages
|
||||
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Ran into exception when fetching pages from Confluence"
|
||||
)
|
||||
|
||||
return pages
|
||||
logger.exception("Ran into exception when fetching pages from Confluence")
|
||||
return [], None
|
||||
|
||||
def _fetch_comments(self, confluence_client: Confluence, page_id: str) -> str:
|
||||
get_page_child_by_type = make_confluence_call_handle_rate_limit(
|
||||
@@ -457,24 +496,22 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
)
|
||||
|
||||
try:
|
||||
comment_pages = cast(
|
||||
Collection[dict[str, Any]],
|
||||
logger.info(f"_fetch_comments - get_page_child_by_type: id={page_id}")
|
||||
comment_pages = list(
|
||||
get_page_child_by_type(
|
||||
page_id,
|
||||
type="comment",
|
||||
start=None,
|
||||
limit=None,
|
||||
expand="body.storage.value",
|
||||
),
|
||||
)
|
||||
)
|
||||
return _comment_dfs("", comment_pages, confluence_client)
|
||||
except Exception as e:
|
||||
if not self.continue_on_failure:
|
||||
raise e
|
||||
|
||||
logger.exception(
|
||||
"Ran into exception when fetching comments from Confluence"
|
||||
)
|
||||
logger.exception("Fetching comments from Confluence exceptioned")
|
||||
return ""
|
||||
|
||||
def _fetch_labels(self, confluence_client: Confluence, page_id: str) -> list[str]:
|
||||
@@ -482,13 +519,14 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
confluence_client.get_page_labels
|
||||
)
|
||||
try:
|
||||
logger.info(f"_fetch_labels - get_page_labels: id={page_id}")
|
||||
labels_response = get_page_labels(page_id)
|
||||
return [label["name"] for label in labels_response["results"]]
|
||||
except Exception as e:
|
||||
if not self.continue_on_failure:
|
||||
raise e
|
||||
|
||||
logger.exception("Ran into exception when fetching labels from Confluence")
|
||||
logger.exception("Fetching labels from Confluence exceptioned")
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
@@ -525,6 +563,7 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
)
|
||||
return None
|
||||
|
||||
logger.info(f"_attachment_to_content - _session.get: link={download_link}")
|
||||
response = confluence_client._session.get(download_link)
|
||||
if response.status_code != 200:
|
||||
logger.warning(
|
||||
@@ -533,7 +572,9 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
return None
|
||||
|
||||
extracted_text = extract_file_text(
|
||||
attachment["title"], io.BytesIO(response.content), False
|
||||
io.BytesIO(response.content),
|
||||
file_name=attachment["title"],
|
||||
break_on_unprocessable=False,
|
||||
)
|
||||
if len(extracted_text) > CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD:
|
||||
logger.warning(
|
||||
@@ -546,22 +587,22 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
return extracted_text
|
||||
|
||||
def _fetch_attachments(
|
||||
self, confluence_client: Confluence, page_id: str, files_in_used: list[str]
|
||||
self, confluence_client: Confluence, page_id: str, files_in_use: list[str]
|
||||
) -> tuple[str, list[dict[str, Any]]]:
|
||||
unused_attachments: list = []
|
||||
unused_attachments: list[dict[str, Any]] = []
|
||||
files_attachment_content: list[str] = []
|
||||
|
||||
get_attachments_from_content = make_confluence_call_handle_rate_limit(
|
||||
confluence_client.get_attachments_from_content
|
||||
)
|
||||
files_attachment_content: list = []
|
||||
|
||||
try:
|
||||
expand = "history.lastUpdated,metadata.labels"
|
||||
attachments_container = get_attachments_from_content(
|
||||
page_id, start=0, limit=500, expand=expand
|
||||
page_id, start=None, limit=None, expand=expand
|
||||
)
|
||||
for attachment in attachments_container["results"]:
|
||||
if attachment["title"] not in files_in_used:
|
||||
for attachment in attachments_container.get("results", []):
|
||||
if attachment["title"] not in files_in_use:
|
||||
unused_attachments.append(attachment)
|
||||
continue
|
||||
|
||||
@@ -579,36 +620,33 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
f"User does not have access to attachments on page '{page_id}'"
|
||||
)
|
||||
return "", []
|
||||
|
||||
if not self.continue_on_failure:
|
||||
raise e
|
||||
logger.exception(
|
||||
f"Ran into exception when fetching attachments from Confluence: {e}"
|
||||
)
|
||||
logger.exception("Fetching attachments from Confluence exceptioned.")
|
||||
|
||||
return "\n".join(files_attachment_content), unused_attachments
|
||||
|
||||
def _get_doc_batch(
|
||||
self, start_ind: int, time_filter: Callable[[datetime], bool] | None = None
|
||||
) -> tuple[list[Document], list[dict[str, Any]], int]:
|
||||
doc_batch: list[Document] = []
|
||||
self, cursor: str | None
|
||||
) -> tuple[list[Any], str | None, list[dict[str, Any]]]:
|
||||
if self.confluence_client is None:
|
||||
raise Exception("Confluence client is not initialized")
|
||||
|
||||
doc_batch: list[Any] = []
|
||||
unused_attachments: list[dict[str, Any]] = []
|
||||
|
||||
if self.confluence_client is None:
|
||||
raise ConnectorMissingCredentialError("Confluence")
|
||||
batch = self._fetch_pages(self.confluence_client, start_ind)
|
||||
batch, next_cursor = self._fetch_pages(cursor)
|
||||
|
||||
for page in batch:
|
||||
last_modified = _datetime_from_string(page["version"]["when"])
|
||||
author = cast(str | None, page["version"].get("by", {}).get("email"))
|
||||
|
||||
if time_filter and not time_filter(last_modified):
|
||||
continue
|
||||
author = page["version"].get("by", {}).get("email")
|
||||
|
||||
page_id = page["id"]
|
||||
|
||||
if self.labels_to_skip or not CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING:
|
||||
page_labels = self._fetch_labels(self.confluence_client, page_id)
|
||||
else:
|
||||
page_labels = []
|
||||
|
||||
# check disallowed labels
|
||||
if self.labels_to_skip:
|
||||
@@ -618,28 +656,32 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
f"Page with ID '{page_id}' has a label which has been "
|
||||
f"designated as disallowed: {label_intersection}. Skipping."
|
||||
)
|
||||
|
||||
continue
|
||||
|
||||
page_html = (
|
||||
page["body"].get("storage", page["body"].get("view", {})).get("value")
|
||||
)
|
||||
page_url = self.wiki_base + page["_links"]["webui"]
|
||||
# The url and the id are the same
|
||||
page_url = build_confluence_document_id(
|
||||
self.wiki_base, page["_links"]["webui"]
|
||||
)
|
||||
if not page_html:
|
||||
logger.debug("Page is empty, skipping: %s", page_url)
|
||||
continue
|
||||
page_text = parse_html_page(page_html, self.confluence_client)
|
||||
|
||||
files_in_used = get_used_attachments(page_html, self.confluence_client)
|
||||
files_in_use = get_used_attachments(page_html)
|
||||
attachment_text, unused_page_attachments = self._fetch_attachments(
|
||||
self.confluence_client, page_id, files_in_used
|
||||
self.confluence_client, page_id, files_in_use
|
||||
)
|
||||
unused_attachments.extend(unused_page_attachments)
|
||||
|
||||
page_text += attachment_text
|
||||
page_text += "\n" + attachment_text if attachment_text else ""
|
||||
comments_text = self._fetch_comments(self.confluence_client, page_id)
|
||||
page_text += comments_text
|
||||
doc_metadata: dict[str, str | list[str]] = {"Wiki Space Name": self.space}
|
||||
doc_metadata: dict[str, str | list[str]] = {
|
||||
"Wiki Space Name": page["space"]["name"]
|
||||
}
|
||||
if not CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING and page_labels:
|
||||
doc_metadata["labels"] = page_labels
|
||||
|
||||
@@ -658,8 +700,8 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
)
|
||||
return (
|
||||
doc_batch,
|
||||
next_cursor,
|
||||
unused_attachments,
|
||||
len(batch),
|
||||
)
|
||||
|
||||
def _get_attachment_batch(
|
||||
@@ -667,8 +709,8 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
start_ind: int,
|
||||
attachments: list[dict[str, Any]],
|
||||
time_filter: Callable[[datetime], bool] | None = None,
|
||||
) -> tuple[list[Document], int]:
|
||||
doc_batch: list[Document] = []
|
||||
) -> tuple[list[Any], int]:
|
||||
doc_batch: list[Any] = []
|
||||
|
||||
if self.confluence_client is None:
|
||||
raise ConnectorMissingCredentialError("Confluence")
|
||||
@@ -683,8 +725,9 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
if time_filter and not time_filter(last_updated):
|
||||
continue
|
||||
|
||||
attachment_url = self._attachment_to_download_link(
|
||||
self.confluence_client, attachment
|
||||
# The url and the id are the same
|
||||
attachment_url = build_confluence_document_id(
|
||||
self.wiki_base, attachment["_links"]["download"]
|
||||
)
|
||||
attachment_content = self._attachment_to_content(
|
||||
self.confluence_client, attachment
|
||||
@@ -695,7 +738,7 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
creator_email = attachment["history"]["createdBy"].get("email")
|
||||
|
||||
comment = attachment["metadata"].get("comment", "")
|
||||
doc_metadata: dict[str, str | list[str]] = {"comment": comment}
|
||||
doc_metadata: dict[str, Any] = {"comment": comment}
|
||||
|
||||
attachment_labels: list[str] = []
|
||||
if not CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING:
|
||||
@@ -722,69 +765,36 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
|
||||
return doc_batch, end_ind - start_ind
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
unused_attachments = []
|
||||
|
||||
if self.confluence_client is None:
|
||||
raise ConnectorMissingCredentialError("Confluence")
|
||||
|
||||
start_ind = 0
|
||||
while True:
|
||||
doc_batch, unused_attachments_batch, num_pages = self._get_doc_batch(
|
||||
start_ind
|
||||
)
|
||||
unused_attachments.extend(unused_attachments_batch)
|
||||
start_ind += num_pages
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
|
||||
if num_pages < self.batch_size:
|
||||
break
|
||||
|
||||
start_ind = 0
|
||||
while True:
|
||||
attachment_batch, num_attachments = self._get_attachment_batch(
|
||||
start_ind, unused_attachments
|
||||
)
|
||||
start_ind += num_attachments
|
||||
if attachment_batch:
|
||||
yield attachment_batch
|
||||
|
||||
if num_attachments < self.batch_size:
|
||||
break
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
def _handle_batch_retrieval(
|
||||
self,
|
||||
start: float | None = None,
|
||||
end: float | None = None,
|
||||
) -> GenerateDocumentsOutput:
|
||||
unused_attachments = []
|
||||
start_time = datetime.fromtimestamp(start, tz=timezone.utc) if start else None
|
||||
end_time = datetime.fromtimestamp(end, tz=timezone.utc) if end else None
|
||||
|
||||
if self.confluence_client is None:
|
||||
raise ConnectorMissingCredentialError("Confluence")
|
||||
|
||||
start_time = datetime.fromtimestamp(start, tz=timezone.utc)
|
||||
end_time = datetime.fromtimestamp(end, tz=timezone.utc)
|
||||
|
||||
start_ind = 0
|
||||
unused_attachments: list[dict[str, Any]] = []
|
||||
cursor = None
|
||||
while True:
|
||||
doc_batch, unused_attachments_batch, num_pages = self._get_doc_batch(
|
||||
start_ind, time_filter=lambda t: start_time <= t <= end_time
|
||||
)
|
||||
unused_attachments.extend(unused_attachments_batch)
|
||||
|
||||
start_ind += num_pages
|
||||
doc_batch, cursor, new_unused_attachments = self._get_doc_batch(cursor)
|
||||
unused_attachments.extend(new_unused_attachments)
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
|
||||
if num_pages < self.batch_size:
|
||||
if not cursor:
|
||||
break
|
||||
|
||||
# Process attachments if any
|
||||
start_ind = 0
|
||||
while True:
|
||||
attachment_batch, num_attachments = self._get_attachment_batch(
|
||||
start_ind,
|
||||
unused_attachments,
|
||||
time_filter=lambda t: start_time <= t <= end_time,
|
||||
start_ind=start_ind,
|
||||
attachments=unused_attachments,
|
||||
time_filter=(lambda t: start_time <= t <= end_time)
|
||||
if start_time and end_time
|
||||
else None,
|
||||
)
|
||||
|
||||
start_ind += num_attachments
|
||||
if attachment_batch:
|
||||
yield attachment_batch
|
||||
@@ -792,6 +802,12 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
if num_attachments < self.batch_size:
|
||||
break
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
return self._handle_batch_retrieval()
|
||||
|
||||
def poll_source(self, start: float, end: float) -> GenerateDocumentsOutput:
|
||||
return self._handle_batch_retrieval(start=start, end=end)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
connector = ConfluenceConnector(
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import math
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
@@ -21,56 +22,198 @@ class ConfluenceRateLimitError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
# commenting out while we try using confluence's rate limiter instead
|
||||
# # https://developer.atlassian.com/cloud/confluence/rate-limiting/
|
||||
# def make_confluence_call_handle_rate_limit(confluence_call: F) -> F:
|
||||
# def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
|
||||
# max_retries = 5
|
||||
# starting_delay = 5
|
||||
# backoff = 2
|
||||
|
||||
# # max_delay is used when the server doesn't hand back "Retry-After"
|
||||
# # and we have to decide the retry delay ourselves
|
||||
# max_delay = 30 # Atlassian uses max_delay = 30 in their examples
|
||||
|
||||
# # max_retry_after is used when we do get a "Retry-After" header
|
||||
# max_retry_after = 300 # should we really cap the maximum retry delay?
|
||||
|
||||
# NEXT_RETRY_KEY = BaseConnector.REDIS_KEY_PREFIX + "confluence_next_retry"
|
||||
|
||||
# # for testing purposes, rate limiting is written to fall back to a simpler
|
||||
# # rate limiting approach when redis is not available
|
||||
# r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# for attempt in range(max_retries):
|
||||
# try:
|
||||
# # if multiple connectors are waiting for the next attempt, there could be an issue
|
||||
# # where many connectors are "released" onto the server at the same time.
|
||||
# # That's not ideal ... but coming up with a mechanism for queueing
|
||||
# # all of these connectors is a bigger problem that we want to take on
|
||||
# # right now
|
||||
# try:
|
||||
# next_attempt = r.get(NEXT_RETRY_KEY)
|
||||
# if next_attempt is None:
|
||||
# next_attempt = 0
|
||||
# else:
|
||||
# next_attempt = int(cast(int, next_attempt))
|
||||
|
||||
# # TODO: all connectors need to be interruptible moving forward
|
||||
# while time.monotonic() < next_attempt:
|
||||
# time.sleep(1)
|
||||
# except ConnectionError:
|
||||
# pass
|
||||
|
||||
# return confluence_call(*args, **kwargs)
|
||||
# except HTTPError as e:
|
||||
# # Check if the response or headers are None to avoid potential AttributeError
|
||||
# if e.response is None or e.response.headers is None:
|
||||
# logger.warning("HTTPError with `None` as response or as headers")
|
||||
# raise e
|
||||
|
||||
# retry_after_header = e.response.headers.get("Retry-After")
|
||||
# if (
|
||||
# e.response.status_code == 429
|
||||
# or RATE_LIMIT_MESSAGE_LOWERCASE in e.response.text.lower()
|
||||
# ):
|
||||
# retry_after = None
|
||||
# if retry_after_header is not None:
|
||||
# try:
|
||||
# retry_after = int(retry_after_header)
|
||||
# except ValueError:
|
||||
# pass
|
||||
|
||||
# if retry_after is not None:
|
||||
# if retry_after > max_retry_after:
|
||||
# logger.warning(
|
||||
# f"Clamping retry_after from {retry_after} to {max_delay} seconds..."
|
||||
# )
|
||||
# retry_after = max_delay
|
||||
|
||||
# logger.warning(
|
||||
# f"Rate limit hit. Retrying after {retry_after} seconds..."
|
||||
# )
|
||||
# try:
|
||||
# r.set(
|
||||
# NEXT_RETRY_KEY,
|
||||
# math.ceil(time.monotonic() + retry_after),
|
||||
# )
|
||||
# except ConnectionError:
|
||||
# pass
|
||||
# else:
|
||||
# logger.warning(
|
||||
# "Rate limit hit. Retrying with exponential backoff..."
|
||||
# )
|
||||
# delay = min(starting_delay * (backoff**attempt), max_delay)
|
||||
# delay_until = math.ceil(time.monotonic() + delay)
|
||||
|
||||
# try:
|
||||
# r.set(NEXT_RETRY_KEY, delay_until)
|
||||
# except ConnectionError:
|
||||
# while time.monotonic() < delay_until:
|
||||
# time.sleep(1)
|
||||
# else:
|
||||
# # re-raise, let caller handle
|
||||
# raise
|
||||
# except AttributeError as e:
|
||||
# # Some error within the Confluence library, unclear why it fails.
|
||||
# # Users reported it to be intermittent, so just retry
|
||||
# logger.warning(f"Confluence Internal Error, retrying... {e}")
|
||||
# delay = min(starting_delay * (backoff**attempt), max_delay)
|
||||
# delay_until = math.ceil(time.monotonic() + delay)
|
||||
# try:
|
||||
# r.set(NEXT_RETRY_KEY, delay_until)
|
||||
# except ConnectionError:
|
||||
# while time.monotonic() < delay_until:
|
||||
# time.sleep(1)
|
||||
|
||||
# if attempt == max_retries - 1:
|
||||
# raise e
|
||||
|
||||
# return cast(F, wrapped_call)
|
||||
|
||||
|
||||
def _handle_http_error(e: HTTPError, attempt: int) -> int:
|
||||
MIN_DELAY = 2
|
||||
MAX_DELAY = 60
|
||||
STARTING_DELAY = 5
|
||||
BACKOFF = 2
|
||||
|
||||
# Check if the response or headers are None to avoid potential AttributeError
|
||||
if e.response is None or e.response.headers is None:
|
||||
logger.warning("HTTPError with `None` as response or as headers")
|
||||
raise e
|
||||
|
||||
if (
|
||||
e.response.status_code != 429
|
||||
and RATE_LIMIT_MESSAGE_LOWERCASE not in e.response.text.lower()
|
||||
):
|
||||
raise e
|
||||
|
||||
retry_after = None
|
||||
|
||||
retry_after_header = e.response.headers.get("Retry-After")
|
||||
if retry_after_header is not None:
|
||||
try:
|
||||
retry_after = int(retry_after_header)
|
||||
if retry_after > MAX_DELAY:
|
||||
logger.warning(
|
||||
f"Clamping retry_after from {retry_after} to {MAX_DELAY} seconds..."
|
||||
)
|
||||
retry_after = MAX_DELAY
|
||||
if retry_after < MIN_DELAY:
|
||||
retry_after = MIN_DELAY
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
if retry_after is not None:
|
||||
logger.warning(
|
||||
f"Rate limiting with retry header. Retrying after {retry_after} seconds..."
|
||||
)
|
||||
delay = retry_after
|
||||
else:
|
||||
logger.warning(
|
||||
"Rate limiting without retry header. Retrying with exponential backoff..."
|
||||
)
|
||||
delay = min(STARTING_DELAY * (BACKOFF**attempt), MAX_DELAY)
|
||||
|
||||
delay_until = math.ceil(time.monotonic() + delay)
|
||||
return delay_until
|
||||
|
||||
|
||||
# https://developer.atlassian.com/cloud/confluence/rate-limiting/
|
||||
# this uses the native rate limiting option provided by the
|
||||
# confluence client and otherwise applies a simpler set of error handling
|
||||
def make_confluence_call_handle_rate_limit(confluence_call: F) -> F:
|
||||
def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
|
||||
max_retries = 5
|
||||
starting_delay = 5
|
||||
backoff = 2
|
||||
max_delay = 600
|
||||
MAX_RETRIES = 5
|
||||
|
||||
TIMEOUT = 3600
|
||||
timeout_at = time.monotonic() + TIMEOUT
|
||||
|
||||
for attempt in range(MAX_RETRIES):
|
||||
if time.monotonic() > timeout_at:
|
||||
raise TimeoutError(
|
||||
f"Confluence call attempts took longer than {TIMEOUT} seconds."
|
||||
)
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
# we're relying more on the client to rate limit itself
|
||||
# and applying our own retries in a more specific set of circumstances
|
||||
return confluence_call(*args, **kwargs)
|
||||
except HTTPError as e:
|
||||
# Check if the response or headers are None to avoid potential AttributeError
|
||||
if e.response is None or e.response.headers is None:
|
||||
logger.warning("HTTPError with `None` as response or as headers")
|
||||
raise e
|
||||
|
||||
retry_after_header = e.response.headers.get("Retry-After")
|
||||
if (
|
||||
e.response.status_code == 429
|
||||
or RATE_LIMIT_MESSAGE_LOWERCASE in e.response.text.lower()
|
||||
):
|
||||
retry_after = None
|
||||
if retry_after_header is not None:
|
||||
try:
|
||||
retry_after = int(retry_after_header)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
if retry_after is not None:
|
||||
logger.warning(
|
||||
f"Rate limit hit. Retrying after {retry_after} seconds..."
|
||||
)
|
||||
time.sleep(retry_after)
|
||||
else:
|
||||
logger.warning(
|
||||
"Rate limit hit. Retrying with exponential backoff..."
|
||||
)
|
||||
delay = min(starting_delay * (backoff**attempt), max_delay)
|
||||
time.sleep(delay)
|
||||
else:
|
||||
# re-raise, let caller handle
|
||||
raise
|
||||
delay_until = _handle_http_error(e, attempt)
|
||||
while time.monotonic() < delay_until:
|
||||
# in the future, check a signal here to exit
|
||||
time.sleep(1)
|
||||
except AttributeError as e:
|
||||
# Some error within the Confluence library, unclear why it fails.
|
||||
# Users reported it to be intermittent, so just retry
|
||||
logger.warning(f"Confluence Internal Error, retrying... {e}")
|
||||
delay = min(starting_delay * (backoff**attempt), max_delay)
|
||||
time.sleep(delay)
|
||||
|
||||
if attempt == max_retries - 1:
|
||||
if attempt == MAX_RETRIES - 1:
|
||||
raise e
|
||||
|
||||
logger.exception(
|
||||
"Confluence Client raised an AttributeError. Retrying..."
|
||||
)
|
||||
time.sleep(5)
|
||||
|
||||
return cast(F, wrapped_call)
|
||||
|
||||
@@ -9,6 +9,7 @@ from jira.resources import Issue
|
||||
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.app_configs import JIRA_CONNECTOR_LABELS_TO_SKIP
|
||||
from danswer.configs.app_configs import JIRA_CONNECTOR_MAX_TICKET_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
@@ -134,10 +135,18 @@ def fetch_jira_issues_batch(
|
||||
else extract_text_from_adf(jira.raw["fields"]["description"])
|
||||
)
|
||||
comments = _get_comment_strs(jira, comment_email_blacklist)
|
||||
semantic_rep = f"{description}\n" + "\n".join(
|
||||
ticket_content = f"{description}\n" + "\n".join(
|
||||
[f"Comment: {comment}" for comment in comments if comment]
|
||||
)
|
||||
|
||||
# Check ticket size
|
||||
if len(ticket_content.encode("utf-8")) > JIRA_CONNECTOR_MAX_TICKET_SIZE:
|
||||
logger.info(
|
||||
f"Skipping {jira.key} because it exceeds the maximum size of "
|
||||
f"{JIRA_CONNECTOR_MAX_TICKET_SIZE} bytes."
|
||||
)
|
||||
continue
|
||||
|
||||
page_url = f"{jira_client.client_info()}/browse/{jira.key}"
|
||||
|
||||
people = set()
|
||||
@@ -180,7 +189,7 @@ def fetch_jira_issues_batch(
|
||||
doc_batch.append(
|
||||
Document(
|
||||
id=page_url,
|
||||
sections=[Section(link=page_url, text=semantic_rep)],
|
||||
sections=[Section(link=page_url, text=ticket_content)],
|
||||
source=DocumentSource.JIRA,
|
||||
semantic_identifier=jira.fields.summary,
|
||||
doc_updated_at=time_str_to_utc(jira.fields.updated),
|
||||
@@ -236,10 +245,12 @@ class JiraConnector(LoadConnector, PollConnector):
|
||||
if self.jira_client is None:
|
||||
raise ConnectorMissingCredentialError("Jira")
|
||||
|
||||
# Quote the project name to handle reserved words
|
||||
quoted_project = f'"{self.jira_project}"'
|
||||
start_ind = 0
|
||||
while True:
|
||||
doc_batch, fetched_batch_size = fetch_jira_issues_batch(
|
||||
jql=f"project = {self.jira_project}",
|
||||
jql=f"project = {quoted_project}",
|
||||
start_index=start_ind,
|
||||
jira_client=self.jira_client,
|
||||
batch_size=self.batch_size,
|
||||
@@ -267,8 +278,10 @@ class JiraConnector(LoadConnector, PollConnector):
|
||||
"%Y-%m-%d %H:%M"
|
||||
)
|
||||
|
||||
# Quote the project name to handle reserved words
|
||||
quoted_project = f'"{self.jira_project}"'
|
||||
jql = (
|
||||
f"project = {self.jira_project} AND "
|
||||
f"project = {quoted_project} AND "
|
||||
f"updated >= '{start_date_str}' AND "
|
||||
f"updated <= '{end_date_str}'"
|
||||
)
|
||||
|
||||
@@ -97,8 +97,8 @@ class DropboxConnector(LoadConnector, PollConnector):
|
||||
link = self._get_shared_link(entry.path_display)
|
||||
try:
|
||||
text = extract_file_text(
|
||||
entry.name,
|
||||
BytesIO(downloaded_file),
|
||||
file_name=entry.name,
|
||||
break_on_unprocessable=False,
|
||||
)
|
||||
batch.append(
|
||||
|
||||
@@ -4,6 +4,8 @@ from typing import Type
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import DocumentSourceRequiringTenantContext
|
||||
from danswer.connectors.asana.connector import AsanaConnector
|
||||
from danswer.connectors.axero.connector import AxeroConnector
|
||||
from danswer.connectors.blob.connector import BlobStorageConnector
|
||||
from danswer.connectors.bookstack.connector import BookstackConnector
|
||||
@@ -41,6 +43,7 @@ from danswer.connectors.slack.load_connector import SlackLoadConnector
|
||||
from danswer.connectors.teams.connector import TeamsConnector
|
||||
from danswer.connectors.web.connector import WebConnector
|
||||
from danswer.connectors.wikipedia.connector import WikipediaConnector
|
||||
from danswer.connectors.xenforo.connector import XenforoConnector
|
||||
from danswer.connectors.zendesk.connector import ZendeskConnector
|
||||
from danswer.connectors.zulip.connector import ZulipConnector
|
||||
from danswer.db.credentials import backend_update_credential_json
|
||||
@@ -61,6 +64,7 @@ def identify_connector_class(
|
||||
DocumentSource.SLACK: {
|
||||
InputType.LOAD_STATE: SlackLoadConnector,
|
||||
InputType.POLL: SlackPollConnector,
|
||||
InputType.PRUNE: SlackPollConnector,
|
||||
},
|
||||
DocumentSource.GITHUB: GithubConnector,
|
||||
DocumentSource.GMAIL: GmailConnector,
|
||||
@@ -91,10 +95,12 @@ def identify_connector_class(
|
||||
DocumentSource.CLICKUP: ClickupConnector,
|
||||
DocumentSource.MEDIAWIKI: MediaWikiConnector,
|
||||
DocumentSource.WIKIPEDIA: WikipediaConnector,
|
||||
DocumentSource.ASANA: AsanaConnector,
|
||||
DocumentSource.S3: BlobStorageConnector,
|
||||
DocumentSource.R2: BlobStorageConnector,
|
||||
DocumentSource.GOOGLE_CLOUD_STORAGE: BlobStorageConnector,
|
||||
DocumentSource.OCI_STORAGE: BlobStorageConnector,
|
||||
DocumentSource.XENFORO: XenforoConnector,
|
||||
}
|
||||
connector_by_source = connector_map.get(source, {})
|
||||
|
||||
@@ -124,13 +130,18 @@ def identify_connector_class(
|
||||
|
||||
|
||||
def instantiate_connector(
|
||||
db_session: Session,
|
||||
source: DocumentSource,
|
||||
input_type: InputType,
|
||||
connector_specific_config: dict[str, Any],
|
||||
credential: Credential,
|
||||
db_session: Session,
|
||||
tenant_id: str | None = None,
|
||||
) -> BaseConnector:
|
||||
connector_class = identify_connector_class(source, input_type)
|
||||
|
||||
if source in DocumentSourceRequiringTenantContext:
|
||||
connector_specific_config["tenant_id"] = tenant_id
|
||||
|
||||
connector = connector_class(**connector_specific_config)
|
||||
new_credentials = connector.load_credentials(credential.credential_json)
|
||||
|
||||
|
||||
@@ -10,13 +10,14 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import POSTGRES_DEFAULT_SCHEMA
|
||||
from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.models import BasicExpertInfo
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.file_processing.extract_file_text import check_file_ext_is_valid
|
||||
from danswer.file_processing.extract_file_text import detect_encoding
|
||||
from danswer.file_processing.extract_file_text import extract_file_text
|
||||
@@ -27,6 +28,7 @@ from danswer.file_processing.extract_file_text import read_pdf_file
|
||||
from danswer.file_processing.extract_file_text import read_text_file
|
||||
from danswer.file_store.file_store import get_default_file_store
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -74,13 +76,14 @@ def _process_file(
|
||||
)
|
||||
|
||||
# Using the PDF reader function directly to pass in password cleanly
|
||||
elif extension == ".pdf":
|
||||
elif extension == ".pdf" and pdf_pass is not None:
|
||||
file_content_raw, file_metadata = read_pdf_file(file=file, pdf_pass=pdf_pass)
|
||||
|
||||
else:
|
||||
file_content_raw = extract_file_text(
|
||||
file_name=file_name,
|
||||
file=file,
|
||||
file_name=file_name,
|
||||
break_on_unprocessable=True,
|
||||
)
|
||||
|
||||
all_metadata = {**metadata, **file_metadata} if metadata else file_metadata
|
||||
@@ -158,10 +161,12 @@ class LocalFileConnector(LoadConnector):
|
||||
def __init__(
|
||||
self,
|
||||
file_locations: list[Path | str],
|
||||
tenant_id: str = POSTGRES_DEFAULT_SCHEMA,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
) -> None:
|
||||
self.file_locations = [Path(file_location) for file_location in file_locations]
|
||||
self.batch_size = batch_size
|
||||
self.tenant_id = tenant_id
|
||||
self.pdf_pass: str | None = None
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
@@ -170,7 +175,9 @@ class LocalFileConnector(LoadConnector):
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
documents: list[Document] = []
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
token = current_tenant_id.set(self.tenant_id)
|
||||
|
||||
with get_session_with_tenant(self.tenant_id) as db_session:
|
||||
for file_path in self.file_locations:
|
||||
current_datetime = datetime.now(timezone.utc)
|
||||
files = _read_files_and_metadata(
|
||||
@@ -192,6 +199,8 @@ class LocalFileConnector(LoadConnector):
|
||||
if documents:
|
||||
yield documents
|
||||
|
||||
current_tenant_id.reset(token)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
connector = LocalFileConnector(file_locations=[os.environ["TEST_FILE"]])
|
||||
|
||||
@@ -25,7 +25,7 @@ from danswer.connectors.gmail.constants import (
|
||||
from danswer.connectors.gmail.constants import SCOPES
|
||||
from danswer.db.credentials import update_credential_json
|
||||
from danswer.db.models import User
|
||||
from danswer.dynamic_configs.factory import get_dynamic_config_store
|
||||
from danswer.key_value_store.factory import get_kv_store
|
||||
from danswer.server.documents.models import CredentialBase
|
||||
from danswer.server.documents.models import GoogleAppCredentials
|
||||
from danswer.server.documents.models import GoogleServiceAccountKey
|
||||
@@ -72,7 +72,7 @@ def get_gmail_creds_for_service_account(
|
||||
|
||||
|
||||
def verify_csrf(credential_id: int, state: str) -> None:
|
||||
csrf = get_dynamic_config_store().load(KV_CRED_KEY.format(str(credential_id)))
|
||||
csrf = get_kv_store().load(KV_CRED_KEY.format(str(credential_id)))
|
||||
if csrf != state:
|
||||
raise PermissionError(
|
||||
"State from Gmail Connector callback does not match expected"
|
||||
@@ -80,7 +80,7 @@ def verify_csrf(credential_id: int, state: str) -> None:
|
||||
|
||||
|
||||
def get_gmail_auth_url(credential_id: int) -> str:
|
||||
creds_str = str(get_dynamic_config_store().load(KV_GMAIL_CRED_KEY))
|
||||
creds_str = str(get_kv_store().load(KV_GMAIL_CRED_KEY))
|
||||
credential_json = json.loads(creds_str)
|
||||
flow = InstalledAppFlow.from_client_config(
|
||||
credential_json,
|
||||
@@ -92,14 +92,14 @@ def get_gmail_auth_url(credential_id: int) -> str:
|
||||
parsed_url = cast(ParseResult, urlparse(auth_url))
|
||||
params = parse_qs(parsed_url.query)
|
||||
|
||||
get_dynamic_config_store().store(
|
||||
get_kv_store().store(
|
||||
KV_CRED_KEY.format(credential_id), params.get("state", [None])[0], encrypt=True
|
||||
) # type: ignore
|
||||
return str(auth_url)
|
||||
|
||||
|
||||
def get_auth_url(credential_id: int) -> str:
|
||||
creds_str = str(get_dynamic_config_store().load(KV_GMAIL_CRED_KEY))
|
||||
creds_str = str(get_kv_store().load(KV_GMAIL_CRED_KEY))
|
||||
credential_json = json.loads(creds_str)
|
||||
flow = InstalledAppFlow.from_client_config(
|
||||
credential_json,
|
||||
@@ -111,7 +111,7 @@ def get_auth_url(credential_id: int) -> str:
|
||||
parsed_url = cast(ParseResult, urlparse(auth_url))
|
||||
params = parse_qs(parsed_url.query)
|
||||
|
||||
get_dynamic_config_store().store(
|
||||
get_kv_store().store(
|
||||
KV_CRED_KEY.format(credential_id), params.get("state", [None])[0], encrypt=True
|
||||
) # type: ignore
|
||||
return str(auth_url)
|
||||
@@ -158,42 +158,40 @@ def build_service_account_creds(
|
||||
|
||||
|
||||
def get_google_app_gmail_cred() -> GoogleAppCredentials:
|
||||
creds_str = str(get_dynamic_config_store().load(KV_GMAIL_CRED_KEY))
|
||||
creds_str = str(get_kv_store().load(KV_GMAIL_CRED_KEY))
|
||||
return GoogleAppCredentials(**json.loads(creds_str))
|
||||
|
||||
|
||||
def upsert_google_app_gmail_cred(app_credentials: GoogleAppCredentials) -> None:
|
||||
get_dynamic_config_store().store(
|
||||
KV_GMAIL_CRED_KEY, app_credentials.json(), encrypt=True
|
||||
)
|
||||
get_kv_store().store(KV_GMAIL_CRED_KEY, app_credentials.json(), encrypt=True)
|
||||
|
||||
|
||||
def delete_google_app_gmail_cred() -> None:
|
||||
get_dynamic_config_store().delete(KV_GMAIL_CRED_KEY)
|
||||
get_kv_store().delete(KV_GMAIL_CRED_KEY)
|
||||
|
||||
|
||||
def get_gmail_service_account_key() -> GoogleServiceAccountKey:
|
||||
creds_str = str(get_dynamic_config_store().load(KV_GMAIL_SERVICE_ACCOUNT_KEY))
|
||||
creds_str = str(get_kv_store().load(KV_GMAIL_SERVICE_ACCOUNT_KEY))
|
||||
return GoogleServiceAccountKey(**json.loads(creds_str))
|
||||
|
||||
|
||||
def upsert_gmail_service_account_key(
|
||||
service_account_key: GoogleServiceAccountKey,
|
||||
) -> None:
|
||||
get_dynamic_config_store().store(
|
||||
get_kv_store().store(
|
||||
KV_GMAIL_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True
|
||||
)
|
||||
|
||||
|
||||
def upsert_service_account_key(service_account_key: GoogleServiceAccountKey) -> None:
|
||||
get_dynamic_config_store().store(
|
||||
get_kv_store().store(
|
||||
KV_GMAIL_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True
|
||||
)
|
||||
|
||||
|
||||
def delete_gmail_service_account_key() -> None:
|
||||
get_dynamic_config_store().delete(KV_GMAIL_SERVICE_ACCOUNT_KEY)
|
||||
get_kv_store().delete(KV_GMAIL_SERVICE_ACCOUNT_KEY)
|
||||
|
||||
|
||||
def delete_service_account_key() -> None:
|
||||
get_dynamic_config_store().delete(KV_GMAIL_SERVICE_ACCOUNT_KEY)
|
||||
get_kv_store().delete(KV_GMAIL_SERVICE_ACCOUNT_KEY)
|
||||
|
||||
@@ -6,7 +6,6 @@ from datetime import timezone
|
||||
from enum import Enum
|
||||
from itertools import chain
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
|
||||
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
|
||||
@@ -21,19 +20,13 @@ from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import IGNORE_FOR_QA
|
||||
from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
|
||||
from danswer.connectors.google_drive.connector_auth import (
|
||||
get_google_drive_creds_for_authorized_user,
|
||||
)
|
||||
from danswer.connectors.google_drive.connector_auth import (
|
||||
get_google_drive_creds_for_service_account,
|
||||
)
|
||||
from danswer.connectors.google_drive.connector_auth import get_google_drive_creds
|
||||
from danswer.connectors.google_drive.constants import (
|
||||
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY,
|
||||
)
|
||||
from danswer.connectors.google_drive.constants import (
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
|
||||
)
|
||||
from danswer.connectors.google_drive.constants import DB_CREDENTIALS_DICT_TOKEN_KEY
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
@@ -43,6 +36,8 @@ from danswer.connectors.models import Section
|
||||
from danswer.file_processing.extract_file_text import docx_to_text
|
||||
from danswer.file_processing.extract_file_text import pptx_to_text
|
||||
from danswer.file_processing.extract_file_text import read_pdf_file
|
||||
from danswer.file_processing.unstructured import get_unstructured_api_key
|
||||
from danswer.file_processing.unstructured import unstructured_to_text
|
||||
from danswer.utils.batching import batch_generator
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
@@ -334,16 +329,24 @@ def extract_text(file: dict[str, str], service: discovery.Resource) -> str:
|
||||
GDriveMimeType.MARKDOWN.value,
|
||||
]:
|
||||
return service.files().get_media(fileId=file["id"]).execute().decode("utf-8")
|
||||
elif mime_type == GDriveMimeType.WORD_DOC.value:
|
||||
if mime_type in [
|
||||
GDriveMimeType.WORD_DOC.value,
|
||||
GDriveMimeType.POWERPOINT.value,
|
||||
GDriveMimeType.PDF.value,
|
||||
]:
|
||||
response = service.files().get_media(fileId=file["id"]).execute()
|
||||
return docx_to_text(file=io.BytesIO(response))
|
||||
elif mime_type == GDriveMimeType.PDF.value:
|
||||
response = service.files().get_media(fileId=file["id"]).execute()
|
||||
text, _ = read_pdf_file(file=io.BytesIO(response))
|
||||
return text
|
||||
elif mime_type == GDriveMimeType.POWERPOINT.value:
|
||||
response = service.files().get_media(fileId=file["id"]).execute()
|
||||
return pptx_to_text(file=io.BytesIO(response))
|
||||
if get_unstructured_api_key():
|
||||
return unstructured_to_text(
|
||||
file=io.BytesIO(response), file_name=file.get("name", file["id"])
|
||||
)
|
||||
|
||||
if mime_type == GDriveMimeType.WORD_DOC.value:
|
||||
return docx_to_text(file=io.BytesIO(response))
|
||||
elif mime_type == GDriveMimeType.PDF.value:
|
||||
text, _ = read_pdf_file(file=io.BytesIO(response))
|
||||
return text
|
||||
elif mime_type == GDriveMimeType.POWERPOINT.value:
|
||||
return pptx_to_text(file=io.BytesIO(response))
|
||||
|
||||
return UNSUPPORTED_FILE_TYPE_CONTENT
|
||||
|
||||
@@ -407,42 +410,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector):
|
||||
(2) A credential which holds a service account key JSON file, which
|
||||
can then be used to impersonate any user in the workspace.
|
||||
"""
|
||||
creds: OAuthCredentials | ServiceAccountCredentials | None = None
|
||||
new_creds_dict = None
|
||||
if DB_CREDENTIALS_DICT_TOKEN_KEY in credentials:
|
||||
access_token_json_str = cast(
|
||||
str, credentials[DB_CREDENTIALS_DICT_TOKEN_KEY]
|
||||
)
|
||||
creds = get_google_drive_creds_for_authorized_user(
|
||||
token_json_str=access_token_json_str
|
||||
)
|
||||
|
||||
# tell caller to update token stored in DB if it has changed
|
||||
# (e.g. the token has been refreshed)
|
||||
new_creds_json_str = creds.to_json() if creds else ""
|
||||
if new_creds_json_str != access_token_json_str:
|
||||
new_creds_dict = {DB_CREDENTIALS_DICT_TOKEN_KEY: new_creds_json_str}
|
||||
|
||||
if DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY in credentials:
|
||||
service_account_key_json_str = credentials[
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY
|
||||
]
|
||||
creds = get_google_drive_creds_for_service_account(
|
||||
service_account_key_json_str=service_account_key_json_str
|
||||
)
|
||||
|
||||
# "Impersonate" a user if one is specified
|
||||
delegated_user_email = cast(
|
||||
str | None, credentials.get(DB_CREDENTIALS_DICT_DELEGATED_USER_KEY)
|
||||
)
|
||||
if delegated_user_email:
|
||||
creds = creds.with_subject(delegated_user_email) if creds else None # type: ignore
|
||||
|
||||
if creds is None:
|
||||
raise PermissionError(
|
||||
"Unable to access Google Drive - unknown credential structure."
|
||||
)
|
||||
|
||||
creds, new_creds_dict = get_google_drive_creds(credentials)
|
||||
self.creds = creds
|
||||
return new_creds_dict
|
||||
|
||||
@@ -494,8 +462,34 @@ class GoogleDriveConnector(LoadConnector, PollConnector):
|
||||
for permission in file["permissions"]
|
||||
):
|
||||
continue
|
||||
try:
|
||||
text_contents = extract_text(file, service) or ""
|
||||
except HttpError as e:
|
||||
reason = (
|
||||
e.error_details[0]["reason"]
|
||||
if e.error_details
|
||||
else e.reason
|
||||
)
|
||||
message = (
|
||||
e.error_details[0]["message"]
|
||||
if e.error_details
|
||||
else e.reason
|
||||
)
|
||||
|
||||
text_contents = extract_text(file, service) or ""
|
||||
# these errors don't represent a failure in the connector, but simply files
|
||||
# that can't / shouldn't be indexed
|
||||
ERRORS_TO_CONTINUE_ON = [
|
||||
"cannotExportFile",
|
||||
"exportSizeLimitExceeded",
|
||||
"cannotDownloadFile",
|
||||
]
|
||||
if e.status_code == 403 and reason in ERRORS_TO_CONTINUE_ON:
|
||||
logger.warning(
|
||||
f"Could not export file '{file['name']}' due to '{message}', skipping..."
|
||||
)
|
||||
continue
|
||||
|
||||
raise
|
||||
|
||||
doc_batch.append(
|
||||
Document(
|
||||
@@ -509,6 +503,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector):
|
||||
file["modifiedTime"]
|
||||
).astimezone(timezone.utc),
|
||||
metadata={} if text_contents else {IGNORE_FOR_QA: "True"},
|
||||
additional_info=file.get("id"),
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
|
||||
@@ -10,11 +10,13 @@ from google.oauth2.service_account import Credentials as ServiceAccountCredentia
|
||||
from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.app_configs import ENTERPRISE_EDITION_ENABLED
|
||||
from danswer.configs.app_configs import WEB_DOMAIN
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import KV_CRED_KEY
|
||||
from danswer.configs.constants import KV_GOOGLE_DRIVE_CRED_KEY
|
||||
from danswer.configs.constants import KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY
|
||||
from danswer.connectors.google_drive.constants import BASE_SCOPES
|
||||
from danswer.connectors.google_drive.constants import (
|
||||
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY,
|
||||
)
|
||||
@@ -22,10 +24,11 @@ from danswer.connectors.google_drive.constants import (
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
|
||||
)
|
||||
from danswer.connectors.google_drive.constants import DB_CREDENTIALS_DICT_TOKEN_KEY
|
||||
from danswer.connectors.google_drive.constants import SCOPES
|
||||
from danswer.connectors.google_drive.constants import FETCH_GROUPS_SCOPES
|
||||
from danswer.connectors.google_drive.constants import FETCH_PERMISSIONS_SCOPES
|
||||
from danswer.db.credentials import update_credential_json
|
||||
from danswer.db.models import User
|
||||
from danswer.dynamic_configs.factory import get_dynamic_config_store
|
||||
from danswer.key_value_store.factory import get_kv_store
|
||||
from danswer.server.documents.models import CredentialBase
|
||||
from danswer.server.documents.models import GoogleAppCredentials
|
||||
from danswer.server.documents.models import GoogleServiceAccountKey
|
||||
@@ -34,15 +37,25 @@ from danswer.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def build_gdrive_scopes() -> list[str]:
|
||||
base_scopes: list[str] = BASE_SCOPES
|
||||
permissions_scopes: list[str] = FETCH_PERMISSIONS_SCOPES
|
||||
groups_scopes: list[str] = FETCH_GROUPS_SCOPES
|
||||
|
||||
if ENTERPRISE_EDITION_ENABLED:
|
||||
return base_scopes + permissions_scopes + groups_scopes
|
||||
return base_scopes + permissions_scopes
|
||||
|
||||
|
||||
def _build_frontend_google_drive_redirect() -> str:
|
||||
return f"{WEB_DOMAIN}/admin/connectors/google-drive/auth/callback"
|
||||
|
||||
|
||||
def get_google_drive_creds_for_authorized_user(
|
||||
token_json_str: str,
|
||||
token_json_str: str, scopes: list[str] = build_gdrive_scopes()
|
||||
) -> OAuthCredentials | None:
|
||||
creds_json = json.loads(token_json_str)
|
||||
creds = OAuthCredentials.from_authorized_user_info(creds_json, SCOPES)
|
||||
creds = OAuthCredentials.from_authorized_user_info(creds_json, scopes)
|
||||
if creds.valid:
|
||||
return creds
|
||||
|
||||
@@ -59,20 +72,69 @@ def get_google_drive_creds_for_authorized_user(
|
||||
return None
|
||||
|
||||
|
||||
def get_google_drive_creds_for_service_account(
|
||||
service_account_key_json_str: str,
|
||||
def _get_google_drive_creds_for_service_account(
|
||||
service_account_key_json_str: str, scopes: list[str] = build_gdrive_scopes()
|
||||
) -> ServiceAccountCredentials | None:
|
||||
service_account_key = json.loads(service_account_key_json_str)
|
||||
creds = ServiceAccountCredentials.from_service_account_info(
|
||||
service_account_key, scopes=SCOPES
|
||||
service_account_key, scopes=scopes
|
||||
)
|
||||
if not creds.valid or not creds.expired:
|
||||
creds.refresh(Request())
|
||||
return creds if creds.valid else None
|
||||
|
||||
|
||||
def get_google_drive_creds(
|
||||
credentials: dict[str, str], scopes: list[str] = build_gdrive_scopes()
|
||||
) -> tuple[ServiceAccountCredentials | OAuthCredentials, dict[str, str] | None]:
|
||||
oauth_creds = None
|
||||
service_creds = None
|
||||
new_creds_dict = None
|
||||
if DB_CREDENTIALS_DICT_TOKEN_KEY in credentials:
|
||||
access_token_json_str = cast(str, credentials[DB_CREDENTIALS_DICT_TOKEN_KEY])
|
||||
oauth_creds = get_google_drive_creds_for_authorized_user(
|
||||
token_json_str=access_token_json_str, scopes=scopes
|
||||
)
|
||||
|
||||
# tell caller to update token stored in DB if it has changed
|
||||
# (e.g. the token has been refreshed)
|
||||
new_creds_json_str = oauth_creds.to_json() if oauth_creds else ""
|
||||
if new_creds_json_str != access_token_json_str:
|
||||
new_creds_dict = {DB_CREDENTIALS_DICT_TOKEN_KEY: new_creds_json_str}
|
||||
|
||||
elif DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY in credentials:
|
||||
service_account_key_json_str = credentials[
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY
|
||||
]
|
||||
service_creds = _get_google_drive_creds_for_service_account(
|
||||
service_account_key_json_str=service_account_key_json_str,
|
||||
scopes=scopes,
|
||||
)
|
||||
|
||||
# "Impersonate" a user if one is specified
|
||||
delegated_user_email = cast(
|
||||
str | None, credentials.get(DB_CREDENTIALS_DICT_DELEGATED_USER_KEY)
|
||||
)
|
||||
if delegated_user_email:
|
||||
service_creds = (
|
||||
service_creds.with_subject(delegated_user_email)
|
||||
if service_creds
|
||||
else None
|
||||
)
|
||||
|
||||
creds: ServiceAccountCredentials | OAuthCredentials | None = (
|
||||
oauth_creds or service_creds
|
||||
)
|
||||
if creds is None:
|
||||
raise PermissionError(
|
||||
"Unable to access Google Drive - unknown credential structure."
|
||||
)
|
||||
|
||||
return creds, new_creds_dict
|
||||
|
||||
|
||||
def verify_csrf(credential_id: int, state: str) -> None:
|
||||
csrf = get_dynamic_config_store().load(KV_CRED_KEY.format(str(credential_id)))
|
||||
csrf = get_kv_store().load(KV_CRED_KEY.format(str(credential_id)))
|
||||
if csrf != state:
|
||||
raise PermissionError(
|
||||
"State from Google Drive Connector callback does not match expected"
|
||||
@@ -80,11 +142,11 @@ def verify_csrf(credential_id: int, state: str) -> None:
|
||||
|
||||
|
||||
def get_auth_url(credential_id: int) -> str:
|
||||
creds_str = str(get_dynamic_config_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
|
||||
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
|
||||
credential_json = json.loads(creds_str)
|
||||
flow = InstalledAppFlow.from_client_config(
|
||||
credential_json,
|
||||
scopes=SCOPES,
|
||||
scopes=build_gdrive_scopes(),
|
||||
redirect_uri=_build_frontend_google_drive_redirect(),
|
||||
)
|
||||
auth_url, _ = flow.authorization_url(prompt="consent")
|
||||
@@ -92,7 +154,7 @@ def get_auth_url(credential_id: int) -> str:
|
||||
parsed_url = cast(ParseResult, urlparse(auth_url))
|
||||
params = parse_qs(parsed_url.query)
|
||||
|
||||
get_dynamic_config_store().store(
|
||||
get_kv_store().store(
|
||||
KV_CRED_KEY.format(credential_id), params.get("state", [None])[0], encrypt=True
|
||||
) # type: ignore
|
||||
return str(auth_url)
|
||||
@@ -107,7 +169,7 @@ def update_credential_access_tokens(
|
||||
app_credentials = get_google_app_cred()
|
||||
flow = InstalledAppFlow.from_client_config(
|
||||
app_credentials.model_dump(),
|
||||
scopes=SCOPES,
|
||||
scopes=build_gdrive_scopes(),
|
||||
redirect_uri=_build_frontend_google_drive_redirect(),
|
||||
)
|
||||
flow.fetch_token(code=auth_code)
|
||||
@@ -140,32 +202,28 @@ def build_service_account_creds(
|
||||
|
||||
|
||||
def get_google_app_cred() -> GoogleAppCredentials:
|
||||
creds_str = str(get_dynamic_config_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
|
||||
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
|
||||
return GoogleAppCredentials(**json.loads(creds_str))
|
||||
|
||||
|
||||
def upsert_google_app_cred(app_credentials: GoogleAppCredentials) -> None:
|
||||
get_dynamic_config_store().store(
|
||||
KV_GOOGLE_DRIVE_CRED_KEY, app_credentials.json(), encrypt=True
|
||||
)
|
||||
get_kv_store().store(KV_GOOGLE_DRIVE_CRED_KEY, app_credentials.json(), encrypt=True)
|
||||
|
||||
|
||||
def delete_google_app_cred() -> None:
|
||||
get_dynamic_config_store().delete(KV_GOOGLE_DRIVE_CRED_KEY)
|
||||
get_kv_store().delete(KV_GOOGLE_DRIVE_CRED_KEY)
|
||||
|
||||
|
||||
def get_service_account_key() -> GoogleServiceAccountKey:
|
||||
creds_str = str(
|
||||
get_dynamic_config_store().load(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY)
|
||||
)
|
||||
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY))
|
||||
return GoogleServiceAccountKey(**json.loads(creds_str))
|
||||
|
||||
|
||||
def upsert_service_account_key(service_account_key: GoogleServiceAccountKey) -> None:
|
||||
get_dynamic_config_store().store(
|
||||
get_kv_store().store(
|
||||
KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True
|
||||
)
|
||||
|
||||
|
||||
def delete_service_account_key() -> None:
|
||||
get_dynamic_config_store().delete(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY)
|
||||
get_kv_store().delete(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
DB_CREDENTIALS_DICT_TOKEN_KEY = "google_drive_tokens"
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "google_drive_service_account_key"
|
||||
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY = "google_drive_delegated_user"
|
||||
SCOPES = [
|
||||
"https://www.googleapis.com/auth/drive.readonly",
|
||||
"https://www.googleapis.com/auth/drive.metadata.readonly",
|
||||
]
|
||||
|
||||
BASE_SCOPES = ["https://www.googleapis.com/auth/drive.readonly"]
|
||||
FETCH_PERMISSIONS_SCOPES = ["https://www.googleapis.com/auth/drive.metadata.readonly"]
|
||||
FETCH_GROUPS_SCOPES = ["https://www.googleapis.com/auth/cloud-identity.groups.readonly"]
|
||||
|
||||
@@ -11,6 +11,8 @@ GenerateDocumentsOutput = Iterator[list[Document]]
|
||||
|
||||
|
||||
class BaseConnector(abc.ABC):
|
||||
REDIS_KEY_PREFIX = "da_connector_data:"
|
||||
|
||||
@abc.abstractmethod
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -45,8 +45,7 @@ class FamilyFileGeneratorInMemory(generate_family_file.FamilyFileGenerator):
|
||||
|
||||
if any(x not in generate_family_file.NAME_CHARACTERS for x in name):
|
||||
raise ValueError(
|
||||
'ERROR: Name of family "{}" must be ASCII letters and digits [a-zA-Z0-9]',
|
||||
name,
|
||||
f'ERROR: Name of family "{name}" must be ASCII letters and digits [a-zA-Z0-9]',
|
||||
)
|
||||
|
||||
if isinstance(dointerwiki, bool):
|
||||
|
||||
@@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
import datetime
|
||||
import itertools
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterator
|
||||
from typing import Any
|
||||
from typing import ClassVar
|
||||
|
||||
@@ -19,6 +20,9 @@ from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from danswer.connectors.mediawiki.family import family_class_dispatch
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def pywikibot_timestamp_to_utc_datetime(
|
||||
@@ -74,7 +78,7 @@ def get_doc_from_page(
|
||||
sections=sections,
|
||||
semantic_identifier=page.title(),
|
||||
metadata={"categories": [category.title() for category in page.categories()]},
|
||||
id=page.pageid,
|
||||
id=f"MEDIAWIKI_{page.pageid}_{page.full_url()}",
|
||||
)
|
||||
|
||||
|
||||
@@ -117,13 +121,18 @@ class MediaWikiConnector(LoadConnector, PollConnector):
|
||||
|
||||
# short names can only have ascii letters and digits
|
||||
|
||||
self.family = family_class_dispatch(hostname, "Wikipedia Connector")()
|
||||
self.family = family_class_dispatch(hostname, "WikipediaConnector")()
|
||||
self.site = pywikibot.Site(fam=self.family, code=language_code)
|
||||
self.categories = [
|
||||
pywikibot.Category(self.site, f"Category:{category.replace(' ', '_')}")
|
||||
for category in categories
|
||||
]
|
||||
self.pages = [pywikibot.Page(self.site, page) for page in pages]
|
||||
|
||||
self.pages = []
|
||||
for page in pages:
|
||||
if not page:
|
||||
continue
|
||||
self.pages.append(pywikibot.Page(self.site, page))
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""Load credentials for a MediaWiki site.
|
||||
@@ -169,8 +178,13 @@ class MediaWikiConnector(LoadConnector, PollConnector):
|
||||
]
|
||||
|
||||
# Since we can specify both individual pages and categories, we need to iterate over all of them.
|
||||
all_pages = itertools.chain(self.pages, *category_pages)
|
||||
all_pages: Iterator[pywikibot.Page] = itertools.chain(
|
||||
self.pages, *category_pages
|
||||
)
|
||||
for page in all_pages:
|
||||
logger.info(
|
||||
f"MediaWikiConnector: title='{page.title()}' url={page.full_url()}"
|
||||
)
|
||||
doc_batch.append(
|
||||
get_doc_from_page(page, self.site, self.document_source_type)
|
||||
)
|
||||
|
||||
@@ -113,6 +113,9 @@ class DocumentBase(BaseModel):
|
||||
# The default title is semantic_identifier though unless otherwise specified
|
||||
title: str | None = None
|
||||
from_ingestion_api: bool = False
|
||||
# Anything else that may be useful that is specific to this particular connector type that other
|
||||
# parts of the code may need. If you're unsure, this can be left as None
|
||||
additional_info: Any = None
|
||||
|
||||
def get_title_for_document_index(
|
||||
self,
|
||||
|
||||
@@ -29,6 +29,9 @@ logger = setup_logger()
|
||||
_NOTION_CALL_TIMEOUT = 30 # 30 seconds
|
||||
|
||||
|
||||
# TODO: Tables need to be ingested, Pages need to have their metadata ingested
|
||||
|
||||
|
||||
@dataclass
|
||||
class NotionPage:
|
||||
"""Represents a Notion Page object"""
|
||||
@@ -40,6 +43,8 @@ class NotionPage:
|
||||
properties: dict[str, Any]
|
||||
url: str
|
||||
|
||||
database_name: str | None # Only applicable to the database type page (wiki)
|
||||
|
||||
def __init__(self, **kwargs: dict[str, Any]) -> None:
|
||||
names = set([f.name for f in fields(self)])
|
||||
for k, v in kwargs.items():
|
||||
@@ -47,6 +52,17 @@ class NotionPage:
|
||||
setattr(self, k, v)
|
||||
|
||||
|
||||
@dataclass
|
||||
class NotionBlock:
|
||||
"""Represents a Notion Block object"""
|
||||
|
||||
id: str # Used for the URL
|
||||
text: str
|
||||
# In a plaintext representation of the page, how this block should be joined
|
||||
# with the existing text up to this point, separated out from text for clarity
|
||||
prefix: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class NotionSearchResponse:
|
||||
"""Represents the response from the Notion Search API"""
|
||||
@@ -62,7 +78,6 @@ class NotionSearchResponse:
|
||||
setattr(self, k, v)
|
||||
|
||||
|
||||
# TODO - Add the ability to optionally limit to specific Notion databases
|
||||
class NotionConnector(LoadConnector, PollConnector):
|
||||
"""Notion Page connector that reads all Notion pages
|
||||
this integration has been granted access to.
|
||||
@@ -126,21 +141,47 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _fetch_page(self, page_id: str) -> NotionPage:
|
||||
"""Fetch a page from it's ID via the Notion API."""
|
||||
"""Fetch a page from its ID via the Notion API, retry with database if page fetch fails."""
|
||||
logger.debug(f"Fetching page for ID '{page_id}'")
|
||||
block_url = f"https://api.notion.com/v1/pages/{page_id}"
|
||||
page_url = f"https://api.notion.com/v1/pages/{page_id}"
|
||||
res = rl_requests.get(
|
||||
block_url,
|
||||
page_url,
|
||||
headers=self.headers,
|
||||
timeout=_NOTION_CALL_TIMEOUT,
|
||||
)
|
||||
try:
|
||||
res.raise_for_status()
|
||||
except Exception as e:
|
||||
logger.exception(f"Error fetching page - {res.json()}")
|
||||
raise e
|
||||
logger.warning(
|
||||
f"Failed to fetch page, trying database for ID '{page_id}'. Exception: {e}"
|
||||
)
|
||||
# Try fetching as a database if page fetch fails, this happens if the page is set to a wiki
|
||||
# it becomes a database from the notion perspective
|
||||
return self._fetch_database_as_page(page_id)
|
||||
return NotionPage(**res.json())
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _fetch_database_as_page(self, database_id: str) -> NotionPage:
|
||||
"""Attempt to fetch a database as a page."""
|
||||
logger.debug(f"Fetching database for ID '{database_id}' as a page")
|
||||
database_url = f"https://api.notion.com/v1/databases/{database_id}"
|
||||
res = rl_requests.get(
|
||||
database_url,
|
||||
headers=self.headers,
|
||||
timeout=_NOTION_CALL_TIMEOUT,
|
||||
)
|
||||
try:
|
||||
res.raise_for_status()
|
||||
except Exception as e:
|
||||
logger.exception(f"Error fetching database as page - {res.json()}")
|
||||
raise e
|
||||
database_name = res.json().get("title")
|
||||
database_name = (
|
||||
database_name[0].get("text", {}).get("content") if database_name else None
|
||||
)
|
||||
|
||||
return NotionPage(**res.json(), database_name=database_name)
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _fetch_database(
|
||||
self, database_id: str, cursor: str | None = None
|
||||
@@ -171,8 +212,75 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
raise e
|
||||
return res.json()
|
||||
|
||||
def _read_pages_from_database(self, database_id: str) -> list[str]:
|
||||
"""Returns a list of all page IDs in the database"""
|
||||
@staticmethod
|
||||
def _properties_to_str(properties: dict[str, Any]) -> str:
|
||||
"""Converts Notion properties to a string"""
|
||||
|
||||
def _recurse_properties(inner_dict: dict[str, Any]) -> str | None:
|
||||
while "type" in inner_dict:
|
||||
type_name = inner_dict["type"]
|
||||
inner_dict = inner_dict[type_name]
|
||||
|
||||
# If the innermost layer is None, the value is not set
|
||||
if not inner_dict:
|
||||
return None
|
||||
|
||||
if isinstance(inner_dict, list):
|
||||
list_properties = [
|
||||
_recurse_properties(item) for item in inner_dict if item
|
||||
]
|
||||
return (
|
||||
", ".join(
|
||||
[
|
||||
list_property
|
||||
for list_property in list_properties
|
||||
if list_property
|
||||
]
|
||||
)
|
||||
or None
|
||||
)
|
||||
|
||||
# TODO there may be more types to handle here
|
||||
if "name" in inner_dict:
|
||||
return inner_dict["name"]
|
||||
if "content" in inner_dict:
|
||||
return inner_dict["content"]
|
||||
start = inner_dict.get("start")
|
||||
end = inner_dict.get("end")
|
||||
if start is not None:
|
||||
if end is not None:
|
||||
return f"{start} - {end}"
|
||||
return start
|
||||
elif end is not None:
|
||||
return f"Until {end}"
|
||||
|
||||
if "id" in inner_dict:
|
||||
# This is not useful to index, it's a reference to another Notion object
|
||||
# and this ID value in plaintext is useless outside of the Notion context
|
||||
logger.debug("Skipping Notion object id field property")
|
||||
return None
|
||||
|
||||
logger.debug(f"Unreadable property from innermost prop: {inner_dict}")
|
||||
return None
|
||||
|
||||
result = ""
|
||||
for prop_name, prop in properties.items():
|
||||
if not prop:
|
||||
continue
|
||||
|
||||
inner_value = _recurse_properties(prop)
|
||||
# Not a perfect way to format Notion database tables but there's no perfect representation
|
||||
# since this must be represented as plaintext
|
||||
if inner_value:
|
||||
result += f"{prop_name}: {inner_value}\t"
|
||||
|
||||
return result
|
||||
|
||||
def _read_pages_from_database(
|
||||
self, database_id: str
|
||||
) -> tuple[list[NotionBlock], list[str]]:
|
||||
"""Returns a list of top level blocks and all page IDs in the database"""
|
||||
result_blocks: list[NotionBlock] = []
|
||||
result_pages: list[str] = []
|
||||
cursor = None
|
||||
while True:
|
||||
@@ -181,29 +289,34 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
for result in data["results"]:
|
||||
obj_id = result["id"]
|
||||
obj_type = result["object"]
|
||||
if obj_type == "page":
|
||||
logger.debug(
|
||||
f"Found page with ID '{obj_id}' in database '{database_id}'"
|
||||
)
|
||||
result_pages.append(result["id"])
|
||||
elif obj_type == "database":
|
||||
logger.debug(
|
||||
f"Found database with ID '{obj_id}' in database '{database_id}'"
|
||||
)
|
||||
result_pages.extend(self._read_pages_from_database(obj_id))
|
||||
text = self._properties_to_str(result.get("properties", {}))
|
||||
if text:
|
||||
result_blocks.append(NotionBlock(id=obj_id, text=text, prefix="\n"))
|
||||
|
||||
if self.recursive_index_enabled:
|
||||
if obj_type == "page":
|
||||
logger.debug(
|
||||
f"Found page with ID '{obj_id}' in database '{database_id}'"
|
||||
)
|
||||
result_pages.append(result["id"])
|
||||
elif obj_type == "database":
|
||||
logger.debug(
|
||||
f"Found database with ID '{obj_id}' in database '{database_id}'"
|
||||
)
|
||||
# The inner contents are ignored at this level
|
||||
_, child_pages = self._read_pages_from_database(obj_id)
|
||||
result_pages.extend(child_pages)
|
||||
|
||||
if data["next_cursor"] is None:
|
||||
break
|
||||
|
||||
cursor = data["next_cursor"]
|
||||
|
||||
return result_pages
|
||||
return result_blocks, result_pages
|
||||
|
||||
def _read_blocks(
|
||||
self, base_block_id: str
|
||||
) -> tuple[list[tuple[str, str]], list[str]]:
|
||||
"""Reads all child blocks for the specified block"""
|
||||
result_lines: list[tuple[str, str]] = []
|
||||
def _read_blocks(self, base_block_id: str) -> tuple[list[NotionBlock], list[str]]:
|
||||
"""Reads all child blocks for the specified block, returns a list of blocks and child page ids"""
|
||||
result_blocks: list[NotionBlock] = []
|
||||
child_pages: list[str] = []
|
||||
cursor = None
|
||||
while True:
|
||||
@@ -211,7 +324,7 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
|
||||
# this happens when a block is not shared with the integration
|
||||
if data is None:
|
||||
return result_lines, child_pages
|
||||
return result_blocks, child_pages
|
||||
|
||||
for result in data["results"]:
|
||||
logger.debug(
|
||||
@@ -255,46 +368,70 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
|
||||
if result["has_children"]:
|
||||
if result_type == "child_page":
|
||||
# Child pages will not be included at this top level, it will be a separate document
|
||||
child_pages.append(result_block_id)
|
||||
else:
|
||||
logger.debug(f"Entering sub-block: {result_block_id}")
|
||||
subblock_result_lines, subblock_child_pages = self._read_blocks(
|
||||
subblocks, subblock_child_pages = self._read_blocks(
|
||||
result_block_id
|
||||
)
|
||||
logger.debug(f"Finished sub-block: {result_block_id}")
|
||||
result_lines.extend(subblock_result_lines)
|
||||
result_blocks.extend(subblocks)
|
||||
child_pages.extend(subblock_child_pages)
|
||||
|
||||
if result_type == "child_database" and self.recursive_index_enabled:
|
||||
child_pages.extend(self._read_pages_from_database(result_block_id))
|
||||
if result_type == "child_database":
|
||||
inner_blocks, inner_child_pages = self._read_pages_from_database(
|
||||
result_block_id
|
||||
)
|
||||
# A database on a page often looks like a table, we need to include it for the contents
|
||||
# of the page but the children (cells) should be processed as other Documents
|
||||
result_blocks.extend(inner_blocks)
|
||||
|
||||
cur_result_text = "\n".join(cur_result_text_arr)
|
||||
if cur_result_text:
|
||||
result_lines.append((cur_result_text, result_block_id))
|
||||
if self.recursive_index_enabled:
|
||||
child_pages.extend(inner_child_pages)
|
||||
|
||||
if cur_result_text_arr:
|
||||
new_block = NotionBlock(
|
||||
id=result_block_id,
|
||||
text="\n".join(cur_result_text_arr),
|
||||
prefix="\n",
|
||||
)
|
||||
result_blocks.append(new_block)
|
||||
|
||||
if data["next_cursor"] is None:
|
||||
break
|
||||
|
||||
cursor = data["next_cursor"]
|
||||
|
||||
return result_lines, child_pages
|
||||
return result_blocks, child_pages
|
||||
|
||||
def _read_page_title(self, page: NotionPage) -> str:
|
||||
def _read_page_title(self, page: NotionPage) -> str | None:
|
||||
"""Extracts the title from a Notion page"""
|
||||
page_title = None
|
||||
if hasattr(page, "database_name") and page.database_name:
|
||||
return page.database_name
|
||||
for _, prop in page.properties.items():
|
||||
if prop["type"] == "title" and len(prop["title"]) > 0:
|
||||
page_title = " ".join([t["plain_text"] for t in prop["title"]]).strip()
|
||||
break
|
||||
if page_title is None:
|
||||
page_title = f"Untitled Page [{page.id}]"
|
||||
|
||||
return page_title
|
||||
|
||||
def _read_pages(
|
||||
self,
|
||||
pages: list[NotionPage],
|
||||
) -> Generator[Document, None, None]:
|
||||
"""Reads pages for rich text content and generates Documents"""
|
||||
"""Reads pages for rich text content and generates Documents
|
||||
|
||||
Note that a page which is turned into a "wiki" becomes a database but both top level pages and top level databases
|
||||
do not seem to have any properties associated with them.
|
||||
|
||||
Pages that are part of a database can have properties which are like the values of the row in the "database" table
|
||||
in which they exist
|
||||
|
||||
This is not clearly outlined in the Notion API docs but it is observable empirically.
|
||||
https://developers.notion.com/docs/working-with-page-content
|
||||
"""
|
||||
all_child_page_ids: list[str] = []
|
||||
for page in pages:
|
||||
if page.id in self.indexed_pages:
|
||||
@@ -304,18 +441,23 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
logger.info(f"Reading page with ID '{page.id}', with url {page.url}")
|
||||
page_blocks, child_page_ids = self._read_blocks(page.id)
|
||||
all_child_page_ids.extend(child_page_ids)
|
||||
page_title = self._read_page_title(page)
|
||||
|
||||
if not page_blocks:
|
||||
continue
|
||||
|
||||
page_title = (
|
||||
self._read_page_title(page) or f"Untitled Page with ID {page.id}"
|
||||
)
|
||||
|
||||
yield (
|
||||
Document(
|
||||
id=page.id,
|
||||
# Will add title to the first section later in processing
|
||||
sections=[Section(link=page.url, text="")]
|
||||
+ [
|
||||
sections=[
|
||||
Section(
|
||||
link=f"{page.url}#{block_id.replace('-', '')}",
|
||||
text=block_text,
|
||||
link=f"{page.url}#{block.id.replace('-', '')}",
|
||||
text=block.prefix + block.text,
|
||||
)
|
||||
for block_text, block_id in page_blocks
|
||||
for block in page_blocks
|
||||
],
|
||||
source=DocumentSource.NOTION,
|
||||
semantic_identifier=page_title,
|
||||
|
||||
@@ -40,8 +40,8 @@ def _convert_driveitem_to_document(
|
||||
driveitem: DriveItem,
|
||||
) -> Document:
|
||||
file_text = extract_file_text(
|
||||
file_name=driveitem.name,
|
||||
file=io.BytesIO(driveitem.get_content().execute_query().value),
|
||||
file_name=driveitem.name,
|
||||
break_on_unprocessable=False,
|
||||
)
|
||||
|
||||
|
||||
@@ -8,13 +8,12 @@ from typing import cast
|
||||
|
||||
from slack_sdk import WebClient
|
||||
from slack_sdk.errors import SlackApiError
|
||||
from slack_sdk.web import SlackResponse
|
||||
|
||||
from danswer.configs.app_configs import ENABLE_EXPENSIVE_EXPERT_CALLS
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import IdConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from danswer.connectors.models import BasicExpertInfo
|
||||
@@ -23,9 +22,8 @@ from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.connectors.slack.utils import expert_info_from_slack_id
|
||||
from danswer.connectors.slack.utils import get_message_link
|
||||
from danswer.connectors.slack.utils import make_slack_api_call_logged
|
||||
from danswer.connectors.slack.utils import make_slack_api_call_paginated
|
||||
from danswer.connectors.slack.utils import make_slack_api_rate_limited
|
||||
from danswer.connectors.slack.utils import make_paginated_slack_api_call_w_retries
|
||||
from danswer.connectors.slack.utils import make_slack_api_call_w_retries
|
||||
from danswer.connectors.slack.utils import SlackTextCleaner
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
@@ -38,47 +36,18 @@ MessageType = dict[str, Any]
|
||||
# list of messages in a thread
|
||||
ThreadType = list[MessageType]
|
||||
|
||||
basic_retry_wrapper = retry_builder()
|
||||
|
||||
|
||||
def _make_paginated_slack_api_call(
|
||||
call: Callable[..., SlackResponse], **kwargs: Any
|
||||
) -> Generator[dict[str, Any], None, None]:
|
||||
return make_slack_api_call_paginated(
|
||||
basic_retry_wrapper(
|
||||
make_slack_api_rate_limited(make_slack_api_call_logged(call))
|
||||
)
|
||||
)(**kwargs)
|
||||
|
||||
|
||||
def _make_slack_api_call(
|
||||
call: Callable[..., SlackResponse], **kwargs: Any
|
||||
) -> SlackResponse:
|
||||
return basic_retry_wrapper(
|
||||
make_slack_api_rate_limited(make_slack_api_call_logged(call))
|
||||
)(**kwargs)
|
||||
|
||||
|
||||
def get_channel_info(client: WebClient, channel_id: str) -> ChannelType:
|
||||
"""Get information about a channel. Needed to convert channel ID to channel name"""
|
||||
return _make_slack_api_call(client.conversations_info, channel=channel_id)[0][
|
||||
"channel"
|
||||
]
|
||||
|
||||
|
||||
def _get_channels(
|
||||
def _collect_paginated_channels(
|
||||
client: WebClient,
|
||||
exclude_archived: bool,
|
||||
get_private: bool,
|
||||
channel_types: list[str],
|
||||
) -> list[ChannelType]:
|
||||
channels: list[dict[str, Any]] = []
|
||||
for result in _make_paginated_slack_api_call(
|
||||
for result in make_paginated_slack_api_call_w_retries(
|
||||
client.conversations_list,
|
||||
exclude_archived=exclude_archived,
|
||||
# also get private channels the bot is added to
|
||||
types=["public_channel", "private_channel"]
|
||||
if get_private
|
||||
else ["public_channel"],
|
||||
types=channel_types,
|
||||
):
|
||||
channels.extend(result["channels"])
|
||||
|
||||
@@ -88,19 +57,38 @@ def _get_channels(
|
||||
def get_channels(
|
||||
client: WebClient,
|
||||
exclude_archived: bool = True,
|
||||
get_public: bool = True,
|
||||
get_private: bool = True,
|
||||
) -> list[ChannelType]:
|
||||
"""Get all channels in the workspace"""
|
||||
channels: list[dict[str, Any]] = []
|
||||
channel_types = []
|
||||
if get_public:
|
||||
channel_types.append("public_channel")
|
||||
if get_private:
|
||||
channel_types.append("private_channel")
|
||||
# try getting private channels as well at first
|
||||
try:
|
||||
return _get_channels(
|
||||
client=client, exclude_archived=exclude_archived, get_private=True
|
||||
channels = _collect_paginated_channels(
|
||||
client=client,
|
||||
exclude_archived=exclude_archived,
|
||||
channel_types=channel_types,
|
||||
)
|
||||
except SlackApiError as e:
|
||||
logger.info(f"Unable to fetch private channels due to - {e}")
|
||||
logger.info("trying again without private channels")
|
||||
if get_public:
|
||||
channel_types = ["public_channel"]
|
||||
else:
|
||||
logger.warning("No channels to fetch")
|
||||
return []
|
||||
channels = _collect_paginated_channels(
|
||||
client=client,
|
||||
exclude_archived=exclude_archived,
|
||||
channel_types=channel_types,
|
||||
)
|
||||
|
||||
return _get_channels(
|
||||
client=client, exclude_archived=exclude_archived, get_private=False
|
||||
)
|
||||
return channels
|
||||
|
||||
|
||||
def get_channel_messages(
|
||||
@@ -112,14 +100,14 @@ def get_channel_messages(
|
||||
"""Get all messages in a channel"""
|
||||
# join so that the bot can access messages
|
||||
if not channel["is_member"]:
|
||||
_make_slack_api_call(
|
||||
make_slack_api_call_w_retries(
|
||||
client.conversations_join,
|
||||
channel=channel["id"],
|
||||
is_private=channel["is_private"],
|
||||
)
|
||||
logger.info(f"Successfully joined '{channel['name']}'")
|
||||
|
||||
for result in _make_paginated_slack_api_call(
|
||||
for result in make_paginated_slack_api_call_w_retries(
|
||||
client.conversations_history,
|
||||
channel=channel["id"],
|
||||
oldest=oldest,
|
||||
@@ -131,7 +119,7 @@ def get_channel_messages(
|
||||
def get_thread(client: WebClient, channel_id: str, thread_id: str) -> ThreadType:
|
||||
"""Get all messages in a thread"""
|
||||
threads: list[MessageType] = []
|
||||
for result in _make_paginated_slack_api_call(
|
||||
for result in make_paginated_slack_api_call_w_retries(
|
||||
client.conversations_replies, channel=channel_id, ts=thread_id
|
||||
):
|
||||
threads.extend(result["messages"])
|
||||
@@ -217,12 +205,17 @@ _DISALLOWED_MSG_SUBTYPES = {
|
||||
"group_leave",
|
||||
"group_archive",
|
||||
"group_unarchive",
|
||||
"channel_leave",
|
||||
"channel_name",
|
||||
"channel_join",
|
||||
}
|
||||
|
||||
|
||||
def _default_msg_filter(message: MessageType) -> bool:
|
||||
def default_msg_filter(message: MessageType) -> bool:
|
||||
# Don't keep messages from bots
|
||||
if message.get("bot_id") or message.get("app_id"):
|
||||
if message.get("bot_profile", {}).get("name") == "DanswerConnector":
|
||||
return False
|
||||
return True
|
||||
|
||||
# Uninformative
|
||||
@@ -266,14 +259,14 @@ def filter_channels(
|
||||
]
|
||||
|
||||
|
||||
def get_all_docs(
|
||||
def _get_all_docs(
|
||||
client: WebClient,
|
||||
workspace: str,
|
||||
channels: list[str] | None = None,
|
||||
channel_name_regex_enabled: bool = False,
|
||||
oldest: str | None = None,
|
||||
latest: str | None = None,
|
||||
msg_filter_func: Callable[[MessageType], bool] = _default_msg_filter,
|
||||
msg_filter_func: Callable[[MessageType], bool] = default_msg_filter,
|
||||
) -> Generator[Document, None, None]:
|
||||
"""Get all documents in the workspace, channel by channel"""
|
||||
slack_cleaner = SlackTextCleaner(client=client)
|
||||
@@ -328,7 +321,44 @@ def get_all_docs(
|
||||
)
|
||||
|
||||
|
||||
class SlackPollConnector(PollConnector):
|
||||
def _get_all_doc_ids(
|
||||
client: WebClient,
|
||||
channels: list[str] | None = None,
|
||||
channel_name_regex_enabled: bool = False,
|
||||
msg_filter_func: Callable[[MessageType], bool] = default_msg_filter,
|
||||
) -> set[str]:
|
||||
"""
|
||||
Get all document ids in the workspace, channel by channel
|
||||
This is pretty identical to get_all_docs, but it returns a set of ids instead of documents
|
||||
This makes it an order of magnitude faster than get_all_docs
|
||||
"""
|
||||
|
||||
all_channels = get_channels(client)
|
||||
filtered_channels = filter_channels(
|
||||
all_channels, channels, channel_name_regex_enabled
|
||||
)
|
||||
|
||||
all_doc_ids = set()
|
||||
for channel in filtered_channels:
|
||||
channel_message_batches = get_channel_messages(
|
||||
client=client,
|
||||
channel=channel,
|
||||
)
|
||||
|
||||
for message_batch in channel_message_batches:
|
||||
for message in message_batch:
|
||||
if msg_filter_func(message):
|
||||
continue
|
||||
|
||||
# The document id is the channel id and the ts of the first message in the thread
|
||||
# Since we already have the first message of the thread, we dont have to
|
||||
# fetch the thread for id retrieval, saving time and API calls
|
||||
all_doc_ids.add(f"{channel['id']}__{message['ts']}")
|
||||
|
||||
return all_doc_ids
|
||||
|
||||
|
||||
class SlackPollConnector(PollConnector, IdConnector):
|
||||
def __init__(
|
||||
self,
|
||||
workspace: str,
|
||||
@@ -349,6 +379,16 @@ class SlackPollConnector(PollConnector):
|
||||
self.client = WebClient(token=bot_token)
|
||||
return None
|
||||
|
||||
def retrieve_all_source_ids(self) -> set[str]:
|
||||
if self.client is None:
|
||||
raise ConnectorMissingCredentialError("Slack")
|
||||
|
||||
return _get_all_doc_ids(
|
||||
client=self.client,
|
||||
channels=self.channels,
|
||||
channel_name_regex_enabled=self.channel_regex_enabled,
|
||||
)
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
@@ -356,7 +396,7 @@ class SlackPollConnector(PollConnector):
|
||||
raise ConnectorMissingCredentialError("Slack")
|
||||
|
||||
documents: list[Document] = []
|
||||
for document in get_all_docs(
|
||||
for document in _get_all_docs(
|
||||
client=self.client,
|
||||
workspace=self.workspace,
|
||||
channels=self.channels,
|
||||
|
||||
@@ -10,11 +10,13 @@ from slack_sdk import WebClient
|
||||
from slack_sdk.errors import SlackApiError
|
||||
from slack_sdk.web import SlackResponse
|
||||
|
||||
from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
|
||||
from danswer.connectors.models import BasicExpertInfo
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
basic_retry_wrapper = retry_builder()
|
||||
# number of messages we request per page when fetching paginated slack messages
|
||||
_SLACK_LIMIT = 900
|
||||
|
||||
@@ -34,7 +36,7 @@ def get_message_link(
|
||||
)
|
||||
|
||||
|
||||
def make_slack_api_call_logged(
|
||||
def _make_slack_api_call_logged(
|
||||
call: Callable[..., SlackResponse],
|
||||
) -> Callable[..., SlackResponse]:
|
||||
@wraps(call)
|
||||
@@ -47,7 +49,7 @@ def make_slack_api_call_logged(
|
||||
return logged_call
|
||||
|
||||
|
||||
def make_slack_api_call_paginated(
|
||||
def _make_slack_api_call_paginated(
|
||||
call: Callable[..., SlackResponse],
|
||||
) -> Callable[..., Generator[dict[str, Any], None, None]]:
|
||||
"""Wraps calls to slack API so that they automatically handle pagination"""
|
||||
@@ -116,6 +118,24 @@ def make_slack_api_rate_limited(
|
||||
return rate_limited_call
|
||||
|
||||
|
||||
def make_slack_api_call_w_retries(
|
||||
call: Callable[..., SlackResponse], **kwargs: Any
|
||||
) -> SlackResponse:
|
||||
return basic_retry_wrapper(
|
||||
make_slack_api_rate_limited(_make_slack_api_call_logged(call))
|
||||
)(**kwargs)
|
||||
|
||||
|
||||
def make_paginated_slack_api_call_w_retries(
|
||||
call: Callable[..., SlackResponse], **kwargs: Any
|
||||
) -> Generator[dict[str, Any], None, None]:
|
||||
return _make_slack_api_call_paginated(
|
||||
basic_retry_wrapper(
|
||||
make_slack_api_rate_limited(_make_slack_api_call_logged(call))
|
||||
)
|
||||
)(**kwargs)
|
||||
|
||||
|
||||
def expert_info_from_slack_id(
|
||||
user_id: str | None,
|
||||
client: WebClient,
|
||||
|
||||
@@ -128,6 +128,9 @@ def get_internal_links(
|
||||
if not href:
|
||||
continue
|
||||
|
||||
# Account for malformed backslashes in URLs
|
||||
href = href.replace("\\", "/")
|
||||
|
||||
if should_ignore_pound and "#" in href:
|
||||
href = href.split("#")[0]
|
||||
|
||||
|
||||
244
backend/danswer/connectors/xenforo/connector.py
Normal file
244
backend/danswer/connectors/xenforo/connector.py
Normal file
@@ -0,0 +1,244 @@
|
||||
"""
|
||||
This is the XenforoConnector class. It is used to connect to a Xenforo forum and load or update documents from the forum.
|
||||
|
||||
To use this class, you need to provide the URL of the Xenforo forum board you want to connect to when creating an instance
|
||||
of the class. The URL should be a string that starts with 'http://' or 'https://', followed by the domain name of the
|
||||
forum, followed by the board name. For example:
|
||||
|
||||
base_url = 'https://www.example.com/forum/boards/some-topic/'
|
||||
|
||||
The `load_from_state` method is used to load documents from the forum. It takes an optional `state` parameter, which
|
||||
can be used to specify a state from which to start loading documents.
|
||||
"""
|
||||
import re
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import pytz
|
||||
import requests
|
||||
from bs4 import BeautifulSoup
|
||||
from bs4 import Tag
|
||||
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.cross_connector_utils.miscellaneous_utils import datetime_to_utc
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.models import BasicExpertInfo
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def get_title(soup: BeautifulSoup) -> str:
|
||||
el = soup.find("h1", "p-title-value")
|
||||
if not el:
|
||||
return ""
|
||||
title = el.text
|
||||
for char in (";", ":", "!", "*", "/", "\\", "?", '"', "<", ">", "|"):
|
||||
title = title.replace(char, "_")
|
||||
return title
|
||||
|
||||
|
||||
def get_pages(soup: BeautifulSoup, url: str) -> list[str]:
|
||||
page_tags = soup.select("li.pageNav-page")
|
||||
page_numbers = []
|
||||
for button in page_tags:
|
||||
if re.match(r"^\d+$", button.text):
|
||||
page_numbers.append(button.text)
|
||||
|
||||
max_pages = int(max(page_numbers, key=int)) if page_numbers else 1
|
||||
|
||||
all_pages = []
|
||||
for x in range(1, int(max_pages) + 1):
|
||||
all_pages.append(f"{url}page-{x}")
|
||||
return all_pages
|
||||
|
||||
|
||||
def parse_post_date(post_element: BeautifulSoup) -> datetime:
|
||||
el = post_element.find("time")
|
||||
if not isinstance(el, Tag) or "datetime" not in el.attrs:
|
||||
return datetime.utcfromtimestamp(0).replace(tzinfo=timezone.utc)
|
||||
|
||||
date_value = el["datetime"]
|
||||
|
||||
# Ensure date_value is a string (if it's a list, take the first element)
|
||||
if isinstance(date_value, list):
|
||||
date_value = date_value[0]
|
||||
|
||||
post_date = datetime.strptime(date_value, "%Y-%m-%dT%H:%M:%S%z")
|
||||
return datetime_to_utc(post_date)
|
||||
|
||||
|
||||
def scrape_page_posts(
|
||||
soup: BeautifulSoup,
|
||||
page_index: int,
|
||||
url: str,
|
||||
initial_run: bool,
|
||||
start_time: datetime,
|
||||
) -> list:
|
||||
title = get_title(soup)
|
||||
|
||||
documents = []
|
||||
for post in soup.find_all("div", class_="message-inner"):
|
||||
post_date = parse_post_date(post)
|
||||
if initial_run or post_date > start_time:
|
||||
el = post.find("div", class_="bbWrapper")
|
||||
if not el:
|
||||
continue
|
||||
post_text = el.get_text(strip=True) + "\n"
|
||||
author_tag = post.find("a", class_="username")
|
||||
if author_tag is None:
|
||||
author_tag = post.find("span", class_="username")
|
||||
author = author_tag.get_text(strip=True) if author_tag else "Deleted author"
|
||||
formatted_time = post_date.strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
# TODO: if a caller calls this for each page of a thread, it may see the
|
||||
# same post multiple times if there is a sticky post
|
||||
# that appears on each page of a thread.
|
||||
# it's important to generate unique doc id's, so page index is part of the
|
||||
# id. We may want to de-dupe this stuff inside the indexing service.
|
||||
document = Document(
|
||||
id=f"{DocumentSource.XENFORO.value}_{title}_{page_index}_{formatted_time}",
|
||||
sections=[Section(link=url, text=post_text)],
|
||||
title=title,
|
||||
source=DocumentSource.XENFORO,
|
||||
semantic_identifier=title,
|
||||
primary_owners=[BasicExpertInfo(display_name=author)],
|
||||
metadata={
|
||||
"type": "post",
|
||||
"author": author,
|
||||
"time": formatted_time,
|
||||
},
|
||||
doc_updated_at=post_date,
|
||||
)
|
||||
|
||||
documents.append(document)
|
||||
return documents
|
||||
|
||||
|
||||
class XenforoConnector(LoadConnector):
|
||||
# Class variable to track if the connector has been run before
|
||||
has_been_run_before = False
|
||||
|
||||
def __init__(self, base_url: str) -> None:
|
||||
self.base_url = base_url
|
||||
self.initial_run = not XenforoConnector.has_been_run_before
|
||||
self.start = datetime.utcnow().replace(tzinfo=pytz.utc) - timedelta(days=1)
|
||||
self.cookies: dict[str, str] = {}
|
||||
# mimic user browser to avoid being blocked by the website (see: https://www.useragents.me/)
|
||||
self.headers = {
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
|
||||
"AppleWebKit/537.36 (KHTML, like Gecko) "
|
||||
"Chrome/121.0.0.0 Safari/537.36"
|
||||
}
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
if credentials:
|
||||
logger.warning("Unexpected credentials provided for Xenforo Connector")
|
||||
return None
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
# Standardize URL to always end in /.
|
||||
if self.base_url[-1] != "/":
|
||||
self.base_url += "/"
|
||||
|
||||
# Remove all extra parameters from the end such as page, post.
|
||||
matches = ("threads/", "boards/", "forums/")
|
||||
for each in matches:
|
||||
if each in self.base_url:
|
||||
try:
|
||||
self.base_url = self.base_url[
|
||||
0 : self.base_url.index(
|
||||
"/", self.base_url.index(each) + len(each)
|
||||
)
|
||||
+ 1
|
||||
]
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
doc_batch: list[Document] = []
|
||||
all_threads = []
|
||||
|
||||
# If the URL contains "boards/" or "forums/", find all threads.
|
||||
if "boards/" in self.base_url or "forums/" in self.base_url:
|
||||
pages = get_pages(self.requestsite(self.base_url), self.base_url)
|
||||
|
||||
# Get all pages on thread_list_page
|
||||
for pre_count, thread_list_page in enumerate(pages, start=1):
|
||||
logger.info(
|
||||
f"Getting pages from thread_list_page.. Current: {pre_count}/{len(pages)}\r"
|
||||
)
|
||||
all_threads += self.get_threads(thread_list_page)
|
||||
# If the URL contains "threads/", add the thread to the list.
|
||||
elif "threads/" in self.base_url:
|
||||
all_threads.append(self.base_url)
|
||||
|
||||
# Process all threads
|
||||
for thread_count, thread_url in enumerate(all_threads, start=1):
|
||||
soup = self.requestsite(thread_url)
|
||||
if soup is None:
|
||||
logger.error(f"Failed to load page: {self.base_url}")
|
||||
continue
|
||||
pages = get_pages(soup, thread_url)
|
||||
# Getting all pages for all threads
|
||||
for page_index, page in enumerate(pages, start=1):
|
||||
logger.info(
|
||||
f"Progress: Page {page_index}/{len(pages)} - Thread {thread_count}/{len(all_threads)}\r"
|
||||
)
|
||||
soup_page = self.requestsite(page)
|
||||
doc_batch.extend(
|
||||
scrape_page_posts(
|
||||
soup_page, page_index, thread_url, self.initial_run, self.start
|
||||
)
|
||||
)
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
|
||||
# Mark the initial run finished after all threads and pages have been processed
|
||||
XenforoConnector.has_been_run_before = True
|
||||
|
||||
def get_threads(self, url: str) -> list[str]:
|
||||
soup = self.requestsite(url)
|
||||
thread_tags = soup.find_all(class_="structItem-title")
|
||||
base_url = "{uri.scheme}://{uri.netloc}".format(uri=urlparse(url))
|
||||
threads = []
|
||||
for x in thread_tags:
|
||||
y = x.find_all(href=True)
|
||||
for element in y:
|
||||
link = element["href"]
|
||||
if "threads/" in link:
|
||||
stripped = link[0 : link.rfind("/") + 1]
|
||||
if base_url + stripped not in threads:
|
||||
threads.append(base_url + stripped)
|
||||
return threads
|
||||
|
||||
def requestsite(self, url: str) -> BeautifulSoup:
|
||||
try:
|
||||
response = requests.get(
|
||||
url, cookies=self.cookies, headers=self.headers, timeout=10
|
||||
)
|
||||
if response.status_code != 200:
|
||||
logger.error(
|
||||
f"<{url}> Request Error: {response.status_code} - {response.reason}"
|
||||
)
|
||||
return BeautifulSoup(response.text, "html.parser")
|
||||
except TimeoutError:
|
||||
logger.error("Timed out Error.")
|
||||
except Exception as e:
|
||||
logger.error(f"Error on {url}")
|
||||
logger.exception(e)
|
||||
return BeautifulSoup("", "html.parser")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
connector = XenforoConnector(
|
||||
# base_url="https://cassiopaea.org/forum/threads/how-to-change-your-emotional-state.41381/"
|
||||
base_url="https://xenforo.com/community/threads/whats-new-with-enhanced-search-resource-manager-and-media-gallery-in-xenforo-2-3.220935/"
|
||||
)
|
||||
document_batches = connector.load_from_state()
|
||||
print(next(document_batches))
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user