mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-16 23:35:46 +00:00
Compare commits
393 Commits
v0.12.0
...
cloud_debu
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
09e6bd3c9c | ||
|
|
c1803cdd56 | ||
|
|
a5b9c76012 | ||
|
|
e9b10e8b41 | ||
|
|
a0fa4adb60 | ||
|
|
ca9ba925bd | ||
|
|
833cc5c97c | ||
|
|
23ecf654ed | ||
|
|
ddc6a6d2b3 | ||
|
|
571c8ece32 | ||
|
|
884bdb4b01 | ||
|
|
b3ecf0d59f | ||
|
|
f56fda27c9 | ||
|
|
b1e4d4ea8d | ||
|
|
8db6d49fe5 | ||
|
|
28598694b1 | ||
|
|
b5d0df90b9 | ||
|
|
48be6338ec | ||
|
|
ed9014f03d | ||
|
|
2dd51230ed | ||
|
|
8b249cbe63 | ||
|
|
6b50f86cd2 | ||
|
|
bd2805b6df | ||
|
|
2847ab003e | ||
|
|
1df6a506ec | ||
|
|
f1541d1fbe | ||
|
|
dd0c4b64df | ||
|
|
788b3015bc | ||
|
|
cbbf10f450 | ||
|
|
d954914a0a | ||
|
|
bee74ac360 | ||
|
|
29ef64272a | ||
|
|
01bf6ee4b7 | ||
|
|
0502417cbe | ||
|
|
d0483dd269 | ||
|
|
eefa872d60 | ||
|
|
3f3d4da611 | ||
|
|
469068052e | ||
|
|
9032b05606 | ||
|
|
334bc6be8c | ||
|
|
814f97c2c7 | ||
|
|
4f5a2b47c4 | ||
|
|
f545508268 | ||
|
|
590986ec65 | ||
|
|
531bab5409 | ||
|
|
29c44007c4 | ||
|
|
d388643a04 | ||
|
|
8a422683e3 | ||
|
|
ddc0230d68 | ||
|
|
6711e91dbf | ||
|
|
cff2346db5 | ||
|
|
8d3fad1f12 | ||
|
|
0c3dab8e8d | ||
|
|
47735e2044 | ||
|
|
1eeab8c773 | ||
|
|
e9b41bddc9 | ||
|
|
73a86b9019 | ||
|
|
12c426c87b | ||
|
|
06aeab6d59 | ||
|
|
9b7e67004c | ||
|
|
626ce74aa3 | ||
|
|
cec63465eb | ||
|
|
5f4b31d322 | ||
|
|
ab5e515a5a | ||
|
|
699a02902a | ||
|
|
c85157f734 | ||
|
|
824844bf84 | ||
|
|
a6ab8a8da4 | ||
|
|
40719eb542 | ||
|
|
e8c72f9e82 | ||
|
|
0ba77963c4 | ||
|
|
86f2892349 | ||
|
|
64f0ad8b26 | ||
|
|
616e997dad | ||
|
|
614bd378bb | ||
|
|
7064c3d06f | ||
|
|
3bb9e4bff6 | ||
|
|
3fec7a6a30 | ||
|
|
a01a9b9a99 | ||
|
|
21ec5ed795 | ||
|
|
54dcbfa288 | ||
|
|
c69b7fc941 | ||
|
|
6722e88a7b | ||
|
|
5b5e1eb7c7 | ||
|
|
87d97d13d5 | ||
|
|
4ae3b48938 | ||
|
|
dee1a0ecd7 | ||
|
|
ca172f3306 | ||
|
|
e5d0587efa | ||
|
|
a9516202fe | ||
|
|
d23fca96c4 | ||
|
|
a45724c899 | ||
|
|
34e250407a | ||
|
|
046c0fbe3e | ||
|
|
76595facef | ||
|
|
af2d548766 | ||
|
|
7c29b1e028 | ||
|
|
a52c821e78 | ||
|
|
0770a587f1 | ||
|
|
748b79b0ef | ||
|
|
9cacb373ef | ||
|
|
21967d4b6f | ||
|
|
f5d638161b | ||
|
|
0b5013b47d | ||
|
|
1b846fbf06 | ||
|
|
cae8a131a2 | ||
|
|
72b4e8e9fe | ||
|
|
c04e2f14d9 | ||
|
|
b40a12d5d7 | ||
|
|
5e7d454ebe | ||
|
|
238509c536 | ||
|
|
d7f8cf8f18 | ||
|
|
5d810d373e | ||
|
|
9455576078 | ||
|
|
71421bb782 | ||
|
|
b88cb388b7 | ||
|
|
639986001f | ||
|
|
e7a7e78969 | ||
|
|
e255ff7d23 | ||
|
|
1be2502112 | ||
|
|
f2bedb8fdd | ||
|
|
637404f482 | ||
|
|
daae146920 | ||
|
|
d95959fb41 | ||
|
|
c667d28e7a | ||
|
|
9e0b482f47 | ||
|
|
fa84eb657f | ||
|
|
264df3441b | ||
|
|
b9bad8b7a0 | ||
|
|
600ebb6432 | ||
|
|
09fe8ea868 | ||
|
|
ad6be03b4d | ||
|
|
65d2511216 | ||
|
|
113bf19c65 | ||
|
|
6026536110 | ||
|
|
056b671cd4 | ||
|
|
8d83ae2ee8 | ||
|
|
ca988f5c5f | ||
|
|
4e4214b82c | ||
|
|
fe83f676df | ||
|
|
6d6e12119b | ||
|
|
1f2b7cb9c8 | ||
|
|
878a189011 | ||
|
|
48c10271c2 | ||
|
|
c6a79d847e | ||
|
|
1bc3f8b96f | ||
|
|
7f6a6944d6 | ||
|
|
06f4146597 | ||
|
|
7ea73d5a5a | ||
|
|
30dfe6dcb4 | ||
|
|
dc5d5dfe05 | ||
|
|
0746e0be5b | ||
|
|
970320bd49 | ||
|
|
4a7bd5578e | ||
|
|
874b098a4b | ||
|
|
ce18b63eea | ||
|
|
7a919c3589 | ||
|
|
631bac4432 | ||
|
|
53428f6e9c | ||
|
|
53b3dcbace | ||
|
|
7a3c06c2d2 | ||
|
|
7a0d823c89 | ||
|
|
db69e445d6 | ||
|
|
18e63889b7 | ||
|
|
738e60c8ed | ||
|
|
8aec873e66 | ||
|
|
7c57dde8ab | ||
|
|
f30adab853 | ||
|
|
601687a522 | ||
|
|
350cf407c9 | ||
|
|
32ec4efc7a | ||
|
|
7c6981e052 | ||
|
|
c50cd20156 | ||
|
|
14772dee71 | ||
|
|
c81e704c95 | ||
|
|
3266ef6321 | ||
|
|
c89b98b4f2 | ||
|
|
e70e0ab859 | ||
|
|
69b6e9321e | ||
|
|
7e53af18b6 | ||
|
|
b9eb1ca2ba | ||
|
|
91d44c83d2 | ||
|
|
4dbc6bb4d1 | ||
|
|
4b6a4c6bbf | ||
|
|
fd1999454a | ||
|
|
0a35422d1d | ||
|
|
69b99056b2 | ||
|
|
2a55696545 | ||
|
|
ef9942b751 | ||
|
|
993acec5e9 | ||
|
|
b01a1b509a | ||
|
|
4f994124ef | ||
|
|
14863bd457 | ||
|
|
aa1c4c635a | ||
|
|
13f6e8a6b4 | ||
|
|
66f47d294c | ||
|
|
0a685bda7d | ||
|
|
23dc8b5dad | ||
|
|
cd5f2293ad | ||
|
|
6c2269e565 | ||
|
|
46315cddf1 | ||
|
|
5f28a1b0e4 | ||
|
|
9e9b7ed61d | ||
|
|
3fb2bfefec | ||
|
|
7c618c9d17 | ||
|
|
03e2789392 | ||
|
|
2783fa08a3 | ||
|
|
edeaee93a2 | ||
|
|
5385bae100 | ||
|
|
813445ab59 | ||
|
|
af814823c8 | ||
|
|
607f61eaeb | ||
|
|
de66f7adb2 | ||
|
|
3432d932d1 | ||
|
|
9bd0cb9eb5 | ||
|
|
f12eb4a5cf | ||
|
|
16863de0aa | ||
|
|
63d1eefee5 | ||
|
|
e338677896 | ||
|
|
7be80c4af9 | ||
|
|
7f1e4a02bf | ||
|
|
5be7d27285 | ||
|
|
fd84b7a768 | ||
|
|
36941ae663 | ||
|
|
212353ed4a | ||
|
|
eb8708f770 | ||
|
|
ac448956e9 | ||
|
|
634a0b9398 | ||
|
|
09d3e47c03 | ||
|
|
9c0cc94f15 | ||
|
|
07dfde2209 | ||
|
|
28e2b78b2e | ||
|
|
0553062ac6 | ||
|
|
284e375ba3 | ||
|
|
1f2f7d0ac2 | ||
|
|
2ecc28b57d | ||
|
|
77cf9b3539 | ||
|
|
076ce2ebd0 | ||
|
|
b625ee32a7 | ||
|
|
c32b93fcc3 | ||
|
|
1c8476072e | ||
|
|
7573416ca1 | ||
|
|
86d8666481 | ||
|
|
8abcde91d4 | ||
|
|
3466451d51 | ||
|
|
413891f143 | ||
|
|
7a0a4d4b79 | ||
|
|
a3439605a5 | ||
|
|
694e79f5e1 | ||
|
|
5dfafc8612 | ||
|
|
62a4aa10db | ||
|
|
a357cdc4c9 | ||
|
|
84615abfdd | ||
|
|
8ae6b1960b | ||
|
|
d9b87bbbc2 | ||
|
|
a0065b01af | ||
|
|
c5306148a3 | ||
|
|
1e17934de4 | ||
|
|
93add96ccc | ||
|
|
3a466a4b08 | ||
|
|
85cbd9caed | ||
|
|
9dc23bf3e7 | ||
|
|
e32809f7ca | ||
|
|
3e58f9f8ab | ||
|
|
2381c8d498 | ||
|
|
c6dadb24dc | ||
|
|
5dc07d4178 | ||
|
|
129c8f8faf | ||
|
|
67bfcabbc5 | ||
|
|
9819aa977a | ||
|
|
8d5b8a4028 | ||
|
|
682319d2e9 | ||
|
|
fe1400aa36 | ||
|
|
e3573b2bc1 | ||
|
|
35b5c44cc7 | ||
|
|
5eddc89b5a | ||
|
|
9a492ceb6d | ||
|
|
3c54ae9de9 | ||
|
|
13f08f3ebb | ||
|
|
bd9f15854f | ||
|
|
366aa2a8ea | ||
|
|
deee237c7e | ||
|
|
100b4a0d16 | ||
|
|
70207b4b39 | ||
|
|
50826b6bef | ||
|
|
3f648cbc31 | ||
|
|
c875a4774f | ||
|
|
049091eb01 | ||
|
|
3dac24542b | ||
|
|
194dcb593d | ||
|
|
bf291d0c0a | ||
|
|
8309f4a802 | ||
|
|
0ff2565125 | ||
|
|
e89dcd7f84 | ||
|
|
645e7e828e | ||
|
|
2a54f14195 | ||
|
|
9209fc804b | ||
|
|
b712877701 | ||
|
|
e6df32dcc3 | ||
|
|
eb81258a23 | ||
|
|
487ef4acc0 | ||
|
|
9b7cc83eae | ||
|
|
ce3124f9e4 | ||
|
|
e69303e309 | ||
|
|
6e698ac84a | ||
|
|
d69180aeb8 | ||
|
|
aa37051be9 | ||
|
|
a7d95661b3 | ||
|
|
33ee899408 | ||
|
|
954b5b2a56 | ||
|
|
521425a4f2 | ||
|
|
618bc02d54 | ||
|
|
b7de74fdf8 | ||
|
|
6e83fe3a39 | ||
|
|
259fc049b7 | ||
|
|
7015e6f2ab | ||
|
|
24be13c015 | ||
|
|
ddff7ecc3f | ||
|
|
97932dc44b | ||
|
|
637b6d9e75 | ||
|
|
54dc1ac917 | ||
|
|
21d5cc43f8 | ||
|
|
7c841051ed | ||
|
|
6e91964924 | ||
|
|
facf1d55a0 | ||
|
|
d68f8d6fbc | ||
|
|
65a205d488 | ||
|
|
485f3f72fa | ||
|
|
dcbea883ae | ||
|
|
a50a3944b3 | ||
|
|
60471b6a73 | ||
|
|
d703e694ce | ||
|
|
6066042fef | ||
|
|
eb0e20b9e4 | ||
|
|
490a68773b | ||
|
|
227aff1e47 | ||
|
|
6e29d1944c | ||
|
|
22189f02c6 | ||
|
|
fdc4811fce | ||
|
|
021d0cf314 | ||
|
|
942e47db29 | ||
|
|
f4a020b599 | ||
|
|
5166649eae | ||
|
|
ba805f766f | ||
|
|
9d57f34c34 | ||
|
|
cc2f584321 | ||
|
|
a1b95df3b8 | ||
|
|
9272d6ebfe | ||
|
|
4fb65dcf73 | ||
|
|
2bbc5d5d07 | ||
|
|
950b1c38f2 | ||
|
|
99fbfba32f | ||
|
|
0a59efe64a | ||
|
|
cf5d394d39 | ||
|
|
f6d8f5ca89 | ||
|
|
1fb4cdfcc3 | ||
|
|
ac51469bcb | ||
|
|
c25f164e28 | ||
|
|
813720905b | ||
|
|
0c45488ac6 | ||
|
|
95d9b33c1a | ||
|
|
55919f596c | ||
|
|
1d0fb6d012 | ||
|
|
2b1dbde829 | ||
|
|
2758ffd9d5 | ||
|
|
07a1b49b4f | ||
|
|
43d8daa5bc | ||
|
|
faeb9f09f0 | ||
|
|
25f5c12750 | ||
|
|
2d81710ccc | ||
|
|
187a7d2da2 | ||
|
|
4b152aa3a7 | ||
|
|
06f937cf93 | ||
|
|
5a24ed2947 | ||
|
|
2372e6a5a5 | ||
|
|
3eef4e3992 | ||
|
|
467ce4e3f3 | ||
|
|
ee4b334a0a | ||
|
|
4087292001 | ||
|
|
da6ed5b2b3 | ||
|
|
864ac2ac5c | ||
|
|
12cb77c80e | ||
|
|
583cd14bf4 | ||
|
|
001fcb3359 | ||
|
|
7ff18e0a93 | ||
|
|
9ac256e925 | ||
|
|
08600db41d | ||
|
|
6bf06ac7f7 | ||
|
|
5b06b53a3e | ||
|
|
afce57b29f | ||
|
|
257dbecd1d | ||
|
|
bd6baf39c3 | ||
|
|
ddae2346ec |
@@ -3,61 +3,61 @@ name: Build and Push Backend Image on Tag
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- '*'
|
||||
- "*"
|
||||
|
||||
env:
|
||||
REGISTRY_IMAGE: danswer/danswer-backend
|
||||
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-backend-cloud' || 'onyxdotapp/onyx-backend' }}
|
||||
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
|
||||
|
||||
|
||||
jobs:
|
||||
build-and-push:
|
||||
# TODO: investigate a matrix build like the web container
|
||||
# TODO: investigate a matrix build like the web container
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
|
||||
runs-on: [runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}"]
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Install build-essential
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y build-essential
|
||||
|
||||
- name: Backend Image Docker Build and Push
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
platforms: linux/amd64,linux/arm64
|
||||
push: true
|
||||
tags: |
|
||||
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
|
||||
build-args: |
|
||||
DANSWER_VERSION=${{ github.ref_name }}
|
||||
- name: Install build-essential
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y build-essential
|
||||
|
||||
# trivy has their own rate limiting issues causing this action to flake
|
||||
# we worked around it by hardcoding to different db repos in env
|
||||
# can re-enable when they figure it out
|
||||
# https://github.com/aquasecurity/trivy/discussions/7538
|
||||
# https://github.com/aquasecurity/trivy-action/issues/389
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@master
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
|
||||
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
|
||||
with:
|
||||
# To run locally: trivy image --severity HIGH,CRITICAL danswer/danswer-backend
|
||||
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
severity: 'CRITICAL,HIGH'
|
||||
trivyignores: ./backend/.trivyignore
|
||||
- name: Backend Image Docker Build and Push
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
platforms: linux/amd64,linux/arm64
|
||||
push: true
|
||||
tags: |
|
||||
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
|
||||
build-args: |
|
||||
ONYX_VERSION=${{ github.ref_name }}
|
||||
|
||||
# trivy has their own rate limiting issues causing this action to flake
|
||||
# we worked around it by hardcoding to different db repos in env
|
||||
# can re-enable when they figure it out
|
||||
# https://github.com/aquasecurity/trivy/discussions/7538
|
||||
# https://github.com/aquasecurity/trivy-action/issues/389
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@master
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
|
||||
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
|
||||
with:
|
||||
# To run locally: trivy image --severity HIGH,CRITICAL onyxdotapp/onyx-backend
|
||||
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
severity: "CRITICAL,HIGH"
|
||||
trivyignores: ./backend/.trivyignore
|
||||
|
||||
@@ -4,12 +4,12 @@ name: Build and Push Cloud Web Image on Tag
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- '*'
|
||||
- "*"
|
||||
|
||||
env:
|
||||
REGISTRY_IMAGE: danswer/danswer-cloud-web-server
|
||||
REGISTRY_IMAGE: onyxdotapp/onyx-web-server-cloud
|
||||
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
|
||||
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on:
|
||||
@@ -28,11 +28,11 @@ jobs:
|
||||
- name: Prepare
|
||||
run: |
|
||||
platform=${{ matrix.platform }}
|
||||
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
|
||||
|
||||
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
@@ -41,16 +41,16 @@ jobs:
|
||||
tags: |
|
||||
type=raw,value=${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
type=raw,value=${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
|
||||
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
|
||||
- name: Build and push by digest
|
||||
id: build
|
||||
uses: docker/build-push-action@v5
|
||||
@@ -60,22 +60,23 @@ jobs:
|
||||
platforms: ${{ matrix.platform }}
|
||||
push: true
|
||||
build-args: |
|
||||
DANSWER_VERSION=${{ github.ref_name }}
|
||||
ONYX_VERSION=${{ github.ref_name }}
|
||||
NEXT_PUBLIC_CLOUD_ENABLED=true
|
||||
NEXT_PUBLIC_POSTHOG_KEY=${{ secrets.POSTHOG_KEY }}
|
||||
NEXT_PUBLIC_POSTHOG_HOST=${{ secrets.POSTHOG_HOST }}
|
||||
NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
|
||||
# needed due to weird interactions with the builds for different platforms
|
||||
NEXT_PUBLIC_GTM_ENABLED=true
|
||||
# needed due to weird interactions with the builds for different platforms
|
||||
no-cache: true
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
|
||||
|
||||
- name: Export digest
|
||||
run: |
|
||||
mkdir -p /tmp/digests
|
||||
digest="${{ steps.build.outputs.digest }}"
|
||||
touch "/tmp/digests/${digest#sha256:}"
|
||||
|
||||
touch "/tmp/digests/${digest#sha256:}"
|
||||
|
||||
- name: Upload digest
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
@@ -95,42 +96,42 @@ jobs:
|
||||
path: /tmp/digests
|
||||
pattern: digests-*
|
||||
merge-multiple: true
|
||||
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ env.REGISTRY_IMAGE }}
|
||||
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
|
||||
- name: Create manifest list and push
|
||||
working-directory: /tmp/digests
|
||||
run: |
|
||||
docker buildx imagetools create $(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \
|
||||
$(printf '${{ env.REGISTRY_IMAGE }}@sha256:%s ' *)
|
||||
|
||||
$(printf '${{ env.REGISTRY_IMAGE }}@sha256:%s ' *)
|
||||
|
||||
- name: Inspect image
|
||||
run: |
|
||||
docker buildx imagetools inspect ${{ env.REGISTRY_IMAGE }}:${{ steps.meta.outputs.version }}
|
||||
|
||||
# trivy has their own rate limiting issues causing this action to flake
|
||||
# we worked around it by hardcoding to different db repos in env
|
||||
# can re-enable when they figure it out
|
||||
# https://github.com/aquasecurity/trivy/discussions/7538
|
||||
# https://github.com/aquasecurity/trivy-action/issues/389
|
||||
# trivy has their own rate limiting issues causing this action to flake
|
||||
# we worked around it by hardcoding to different db repos in env
|
||||
# can re-enable when they figure it out
|
||||
# https://github.com/aquasecurity/trivy/discussions/7538
|
||||
# https://github.com/aquasecurity/trivy-action/issues/389
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@master
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
|
||||
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
|
||||
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
|
||||
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
|
||||
with:
|
||||
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
severity: 'CRITICAL,HIGH'
|
||||
severity: "CRITICAL,HIGH"
|
||||
|
||||
@@ -3,53 +3,121 @@ name: Build and Push Model Server Image on Tag
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- '*'
|
||||
- "*"
|
||||
|
||||
env:
|
||||
REGISTRY_IMAGE: danswer/danswer-model-server
|
||||
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-model-server-cloud' || 'onyxdotapp/onyx-model-server' }}
|
||||
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
|
||||
|
||||
DOCKER_BUILDKIT: 1
|
||||
BUILDKIT_PROGRESS: plain
|
||||
|
||||
jobs:
|
||||
build-and-push:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
|
||||
|
||||
build-amd64:
|
||||
runs-on:
|
||||
[runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}-amd64"]
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
- name: System Info
|
||||
run: |
|
||||
df -h
|
||||
free -h
|
||||
docker system prune -af --volumes
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
with:
|
||||
driver-opts: |
|
||||
image=moby/buildkit:latest
|
||||
network=host
|
||||
|
||||
- name: Model Server Image Docker Build and Push
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile.model_server
|
||||
platforms: linux/amd64,linux/arm64
|
||||
push: true
|
||||
tags: |
|
||||
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
|
||||
build-args: |
|
||||
DANSWER_VERSION=${{ github.ref_name }}
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
# trivy has their own rate limiting issues causing this action to flake
|
||||
# we worked around it by hardcoding to different db repos in env
|
||||
# can re-enable when they figure it out
|
||||
# https://github.com/aquasecurity/trivy/discussions/7538
|
||||
# https://github.com/aquasecurity/trivy-action/issues/389
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@master
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
|
||||
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
|
||||
with:
|
||||
image-ref: docker.io/danswer/danswer-model-server:${{ github.ref_name }}
|
||||
severity: 'CRITICAL,HIGH'
|
||||
- name: Build and Push AMD64
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile.model_server
|
||||
platforms: linux/amd64
|
||||
push: true
|
||||
tags: ${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-amd64
|
||||
build-args: |
|
||||
DANSWER_VERSION=${{ github.ref_name }}
|
||||
outputs: type=registry
|
||||
provenance: false
|
||||
|
||||
build-arm64:
|
||||
runs-on:
|
||||
[runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}-arm64"]
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: System Info
|
||||
run: |
|
||||
df -h
|
||||
free -h
|
||||
docker system prune -af --volumes
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
with:
|
||||
driver-opts: |
|
||||
image=moby/buildkit:latest
|
||||
network=host
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and Push ARM64
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile.model_server
|
||||
platforms: linux/arm64
|
||||
push: true
|
||||
tags: ${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-arm64
|
||||
build-args: |
|
||||
DANSWER_VERSION=${{ github.ref_name }}
|
||||
outputs: type=registry
|
||||
provenance: false
|
||||
|
||||
merge-and-scan:
|
||||
needs: [build-amd64, build-arm64]
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Create and Push Multi-arch Manifest
|
||||
run: |
|
||||
docker buildx create --use
|
||||
docker buildx imagetools create -t ${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }} \
|
||||
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-amd64 \
|
||||
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-arm64
|
||||
if [[ "${{ env.LATEST_TAG }}" == "true" ]]; then
|
||||
docker buildx imagetools create -t ${{ env.REGISTRY_IMAGE }}:latest \
|
||||
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-amd64 \
|
||||
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-arm64
|
||||
fi
|
||||
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@master
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
|
||||
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
|
||||
with:
|
||||
image-ref: docker.io/onyxdotapp/onyx-model-server:${{ github.ref_name }}
|
||||
severity: "CRITICAL,HIGH"
|
||||
timeout: "10m"
|
||||
|
||||
@@ -3,12 +3,12 @@ name: Build and Push Web Image on Tag
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- '*'
|
||||
- "*"
|
||||
|
||||
env:
|
||||
REGISTRY_IMAGE: danswer/danswer-web-server
|
||||
REGISTRY_IMAGE: onyxdotapp/onyx-web-server
|
||||
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
|
||||
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on:
|
||||
@@ -27,11 +27,11 @@ jobs:
|
||||
- name: Prepare
|
||||
run: |
|
||||
platform=${{ matrix.platform }}
|
||||
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
|
||||
|
||||
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
@@ -40,16 +40,16 @@ jobs:
|
||||
tags: |
|
||||
type=raw,value=${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
type=raw,value=${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
|
||||
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
|
||||
- name: Build and push by digest
|
||||
id: build
|
||||
uses: docker/build-push-action@v5
|
||||
@@ -59,18 +59,18 @@ jobs:
|
||||
platforms: ${{ matrix.platform }}
|
||||
push: true
|
||||
build-args: |
|
||||
DANSWER_VERSION=${{ github.ref_name }}
|
||||
# needed due to weird interactions with the builds for different platforms
|
||||
ONYX_VERSION=${{ github.ref_name }}
|
||||
# needed due to weird interactions with the builds for different platforms
|
||||
no-cache: true
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
|
||||
|
||||
- name: Export digest
|
||||
run: |
|
||||
mkdir -p /tmp/digests
|
||||
digest="${{ steps.build.outputs.digest }}"
|
||||
touch "/tmp/digests/${digest#sha256:}"
|
||||
|
||||
touch "/tmp/digests/${digest#sha256:}"
|
||||
|
||||
- name: Upload digest
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
@@ -90,42 +90,42 @@ jobs:
|
||||
path: /tmp/digests
|
||||
pattern: digests-*
|
||||
merge-multiple: true
|
||||
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ env.REGISTRY_IMAGE }}
|
||||
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
|
||||
- name: Create manifest list and push
|
||||
working-directory: /tmp/digests
|
||||
run: |
|
||||
docker buildx imagetools create $(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \
|
||||
$(printf '${{ env.REGISTRY_IMAGE }}@sha256:%s ' *)
|
||||
|
||||
$(printf '${{ env.REGISTRY_IMAGE }}@sha256:%s ' *)
|
||||
|
||||
- name: Inspect image
|
||||
run: |
|
||||
docker buildx imagetools inspect ${{ env.REGISTRY_IMAGE }}:${{ steps.meta.outputs.version }}
|
||||
|
||||
# trivy has their own rate limiting issues causing this action to flake
|
||||
# we worked around it by hardcoding to different db repos in env
|
||||
# can re-enable when they figure it out
|
||||
# https://github.com/aquasecurity/trivy/discussions/7538
|
||||
# https://github.com/aquasecurity/trivy-action/issues/389
|
||||
# trivy has their own rate limiting issues causing this action to flake
|
||||
# we worked around it by hardcoding to different db repos in env
|
||||
# can re-enable when they figure it out
|
||||
# https://github.com/aquasecurity/trivy/discussions/7538
|
||||
# https://github.com/aquasecurity/trivy-action/issues/389
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@master
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
|
||||
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
|
||||
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
|
||||
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
|
||||
with:
|
||||
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
severity: 'CRITICAL,HIGH'
|
||||
severity: "CRITICAL,HIGH"
|
||||
|
||||
34
.github/workflows/docker-tag-latest.yml
vendored
34
.github/workflows/docker-tag-latest.yml
vendored
@@ -7,31 +7,31 @@ on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
version:
|
||||
description: 'The version (ie v0.0.1) to tag as latest'
|
||||
description: "The version (ie v0.0.1) to tag as latest"
|
||||
required: true
|
||||
|
||||
jobs:
|
||||
tag:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
# use a lower powered instance since this just does i/o to docker hub
|
||||
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}"]
|
||||
runs-on: [runs-on, runner=2cpu-linux-x64, "run-id=${{ github.run_id }}"]
|
||||
steps:
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v1
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v1
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v1
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v1
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Enable Docker CLI experimental features
|
||||
run: echo "DOCKER_CLI_EXPERIMENTAL=enabled" >> $GITHUB_ENV
|
||||
- name: Enable Docker CLI experimental features
|
||||
run: echo "DOCKER_CLI_EXPERIMENTAL=enabled" >> $GITHUB_ENV
|
||||
|
||||
- name: Pull, Tag and Push Web Server Image
|
||||
run: |
|
||||
docker buildx imagetools create -t danswer/danswer-web-server:latest danswer/danswer-web-server:${{ github.event.inputs.version }}
|
||||
- name: Pull, Tag and Push Web Server Image
|
||||
run: |
|
||||
docker buildx imagetools create -t onyxdotapp/onyx-web-server:latest onyxdotapp/onyx-web-server:${{ github.event.inputs.version }}
|
||||
|
||||
- name: Pull, Tag and Push API Server Image
|
||||
run: |
|
||||
docker buildx imagetools create -t danswer/danswer-backend:latest danswer/danswer-backend:${{ github.event.inputs.version }}
|
||||
- name: Pull, Tag and Push API Server Image
|
||||
run: |
|
||||
docker buildx imagetools create -t onyxdotapp/onyx-backend:latest onyxdotapp/onyx-backend:${{ github.event.inputs.version }}
|
||||
|
||||
27
.github/workflows/hotfix-release-branches.yml
vendored
27
.github/workflows/hotfix-release-branches.yml
vendored
@@ -8,43 +8,42 @@ on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
hotfix_commit:
|
||||
description: 'Hotfix commit hash'
|
||||
description: "Hotfix commit hash"
|
||||
required: true
|
||||
hotfix_suffix:
|
||||
description: 'Hotfix branch suffix (e.g. hotfix/v0.8-{suffix})'
|
||||
description: "Hotfix branch suffix (e.g. hotfix/v0.8-{suffix})"
|
||||
required: true
|
||||
release_branch_pattern:
|
||||
description: 'Release branch pattern (regex)'
|
||||
description: "Release branch pattern (regex)"
|
||||
required: true
|
||||
default: 'release/.*'
|
||||
default: "release/.*"
|
||||
auto_merge:
|
||||
description: 'Automatically merge the hotfix PRs'
|
||||
description: "Automatically merge the hotfix PRs"
|
||||
required: true
|
||||
type: choice
|
||||
default: 'true'
|
||||
default: "true"
|
||||
options:
|
||||
- true
|
||||
- false
|
||||
|
||||
|
||||
jobs:
|
||||
hotfix_release_branches:
|
||||
permissions: write-all
|
||||
# See https://runs-on.com/runners/linux/
|
||||
# use a lower powered instance since this just does i/o to docker hub
|
||||
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}"]
|
||||
runs-on: [runs-on, runner=2cpu-linux-x64, "run-id=${{ github.run_id }}"]
|
||||
steps:
|
||||
|
||||
# needs RKUO_DEPLOY_KEY for write access to merge PR's
|
||||
- name: Checkout Repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ssh-key: "${{ secrets.RKUO_DEPLOY_KEY }}"
|
||||
fetch-depth: 0
|
||||
|
||||
|
||||
- name: Set up Git user
|
||||
run: |
|
||||
git config user.name "Richard Kuo [bot]"
|
||||
git config user.email "rkuo[bot]@danswer.ai"
|
||||
git config user.email "rkuo[bot]@onyx.app"
|
||||
|
||||
- name: Fetch All Branches
|
||||
run: |
|
||||
@@ -62,10 +61,10 @@ jobs:
|
||||
echo "No release branches found matching pattern '${{ github.event.inputs.release_branch_pattern }}'."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
echo "Found release branches:"
|
||||
echo "$BRANCHES"
|
||||
|
||||
|
||||
# Join the branches into a single line separated by commas
|
||||
BRANCHES_JOINED=$(echo "$BRANCHES" | tr '\n' ',' | sed 's/,$//')
|
||||
|
||||
@@ -169,4 +168,4 @@ jobs:
|
||||
echo "Failed to merge pull request #$PR_NUMBER."
|
||||
fi
|
||||
fi
|
||||
done
|
||||
done
|
||||
|
||||
20
.github/workflows/pr-backport-autotrigger.yml
vendored
20
.github/workflows/pr-backport-autotrigger.yml
vendored
@@ -4,7 +4,7 @@ name: Backport on Merge
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [closed] # Later we check for merge so only PRs that go in can get backported
|
||||
types: [closed] # Later we check for merge so only PRs that go in can get backported
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
@@ -26,9 +26,9 @@ jobs:
|
||||
- name: Set up Git user
|
||||
run: |
|
||||
git config user.name "Richard Kuo [bot]"
|
||||
git config user.email "rkuo[bot]@danswer.ai"
|
||||
git config user.email "rkuo[bot]@onyx.app"
|
||||
git fetch --prune
|
||||
|
||||
|
||||
- name: Check for Backport Checkbox
|
||||
id: checkbox-check
|
||||
run: |
|
||||
@@ -51,14 +51,14 @@ jobs:
|
||||
# Fetch latest tags for beta and stable
|
||||
LATEST_BETA_TAG=$(git tag -l "v[0-9]*.[0-9]*.[0-9]*-beta.[0-9]*" | grep -E "^v[0-9]+\.[0-9]+\.[0-9]+-beta\.[0-9]+$" | grep -v -- "-cloud" | sort -Vr | head -n 1)
|
||||
LATEST_STABLE_TAG=$(git tag -l "v[0-9]*.[0-9]*.[0-9]*" | grep -E "^v[0-9]+\.[0-9]+\.[0-9]+$" | sort -Vr | head -n 1)
|
||||
|
||||
|
||||
# Handle case where no beta tags exist
|
||||
if [[ -z "$LATEST_BETA_TAG" ]]; then
|
||||
NEW_BETA_TAG="v1.0.0-beta.1"
|
||||
else
|
||||
NEW_BETA_TAG=$(echo $LATEST_BETA_TAG | awk -F '[.-]' '{print $1 "." $2 "." $3 "-beta." ($NF+1)}')
|
||||
fi
|
||||
|
||||
|
||||
# Increment latest stable tag
|
||||
NEW_STABLE_TAG=$(echo $LATEST_STABLE_TAG | awk -F '.' '{print $1 "." $2 "." ($3+1)}')
|
||||
echo "latest_beta_tag=$LATEST_BETA_TAG" >> $GITHUB_OUTPUT
|
||||
@@ -80,10 +80,10 @@ jobs:
|
||||
run: |
|
||||
set -e
|
||||
echo "Backporting to beta ${{ steps.list-branches.outputs.beta }} and stable ${{ steps.list-branches.outputs.stable }}"
|
||||
|
||||
|
||||
# Echo the merge commit SHA
|
||||
echo "Merge commit SHA: ${{ github.event.pull_request.merge_commit_sha }}"
|
||||
|
||||
|
||||
# Fetch all history for all branches and tags
|
||||
git fetch --prune
|
||||
|
||||
@@ -98,7 +98,7 @@ jobs:
|
||||
echo "Cherry-pick to beta failed due to conflicts."
|
||||
exit 1
|
||||
}
|
||||
|
||||
|
||||
# Create new beta branch/tag
|
||||
git tag ${{ steps.list-branches.outputs.new_beta_tag }}
|
||||
# Push the changes and tag to the beta branch using PAT
|
||||
@@ -110,13 +110,13 @@ jobs:
|
||||
echo "Last 5 commits on stable branch:"
|
||||
git log -n 5 --pretty=format:"%H"
|
||||
echo "" # Newline for formatting
|
||||
|
||||
|
||||
# Cherry-pick the merge commit from the merged PR
|
||||
git cherry-pick -m 1 ${{ github.event.pull_request.merge_commit_sha }} || {
|
||||
echo "Cherry-pick to stable failed due to conflicts."
|
||||
exit 1
|
||||
}
|
||||
|
||||
|
||||
# Create new stable branch/tag
|
||||
git tag ${{ steps.list-branches.outputs.new_stable_tag }}
|
||||
# Push the changes and tag to the stable branch using PAT
|
||||
|
||||
238
.github/workflows/pr-chromatic-tests.yml
vendored
Normal file
238
.github/workflows/pr-chromatic-tests.yml
vendored
Normal file
@@ -0,0 +1,238 @@
|
||||
name: Run Chromatic Tests
|
||||
concurrency:
|
||||
group: Run-Chromatic-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
on: push
|
||||
|
||||
env:
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
|
||||
|
||||
jobs:
|
||||
playwright-tests:
|
||||
name: Playwright Tests
|
||||
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=32cpu-linux-x64,
|
||||
disk=large,
|
||||
"run-id=${{ github.run_id }}",
|
||||
]
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
cache: "pip"
|
||||
cache-dependency-path: |
|
||||
backend/requirements/default.txt
|
||||
backend/requirements/dev.txt
|
||||
backend/requirements/model_server.txt
|
||||
- run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
|
||||
|
||||
- name: Setup node
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: 22
|
||||
|
||||
- name: Install node dependencies
|
||||
working-directory: ./web
|
||||
run: npm ci
|
||||
|
||||
- name: Install playwright browsers
|
||||
working-directory: ./web
|
||||
run: npx playwright install --with-deps
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
# tag every docker image with "test" so that we can spin up the correct set
|
||||
# of images during testing
|
||||
|
||||
# we use the runs-on cache for docker builds
|
||||
# in conjunction with runs-on runners, it has better speed and unlimited caching
|
||||
# https://runs-on.com/caching/s3-cache-for-github-actions/
|
||||
# https://runs-on.com/caching/docker/
|
||||
# https://github.com/moby/buildkit#s3-cache-experimental
|
||||
|
||||
# images are built and run locally for testing purposes. Not pushed.
|
||||
|
||||
- name: Build Web Docker image
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
context: ./web
|
||||
file: ./web/Dockerfile
|
||||
platforms: linux/amd64
|
||||
tags: onyxdotapp/onyx-web-server:test
|
||||
push: false
|
||||
load: true
|
||||
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/web-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
|
||||
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/web-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
|
||||
|
||||
- name: Build Backend Docker image
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
platforms: linux/amd64
|
||||
tags: onyxdotapp/onyx-backend:test
|
||||
push: false
|
||||
load: true
|
||||
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
|
||||
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
|
||||
|
||||
- name: Build Model Server Docker image
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile.model_server
|
||||
platforms: linux/amd64
|
||||
tags: onyxdotapp/onyx-model-server:test
|
||||
push: false
|
||||
load: true
|
||||
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
|
||||
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
|
||||
|
||||
- name: Start Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
|
||||
AUTH_TYPE=basic \
|
||||
GEN_AI_API_KEY=${{ secrets.OPENAI_API_KEY }} \
|
||||
REQUIRE_EMAIL_VERIFICATION=false \
|
||||
DISABLE_TELEMETRY=true \
|
||||
IMAGE_TAG=test \
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack up -d
|
||||
id: start_docker
|
||||
|
||||
- name: Wait for service to be ready
|
||||
run: |
|
||||
echo "Starting wait-for-service script..."
|
||||
|
||||
docker logs -f danswer-stack-api_server-1 &
|
||||
|
||||
start_time=$(date +%s)
|
||||
timeout=300 # 5 minutes in seconds
|
||||
|
||||
while true; do
|
||||
current_time=$(date +%s)
|
||||
elapsed_time=$((current_time - start_time))
|
||||
|
||||
if [ $elapsed_time -ge $timeout ]; then
|
||||
echo "Timeout reached. Service did not become ready in 5 minutes."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Use curl with error handling to ignore specific exit code 56
|
||||
response=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:8080/health || echo "curl_error")
|
||||
|
||||
if [ "$response" = "200" ]; then
|
||||
echo "Service is ready!"
|
||||
break
|
||||
elif [ "$response" = "curl_error" ]; then
|
||||
echo "Curl encountered an error, possibly exit code 56. Continuing to retry..."
|
||||
else
|
||||
echo "Service not ready yet (HTTP status $response). Retrying in 5 seconds..."
|
||||
fi
|
||||
|
||||
sleep 5
|
||||
done
|
||||
echo "Finished waiting for service."
|
||||
|
||||
- name: Run pytest playwright test init
|
||||
working-directory: ./backend
|
||||
env:
|
||||
PYTEST_IGNORE_SKIP: true
|
||||
run: pytest -s tests/integration/tests/playwright/test_playwright.py
|
||||
|
||||
- name: Run Playwright tests
|
||||
working-directory: ./web
|
||||
run: npx playwright test
|
||||
|
||||
- uses: actions/upload-artifact@v4
|
||||
if: always()
|
||||
with:
|
||||
# Chromatic automatically defaults to the test-results directory.
|
||||
# Replace with the path to your custom directory and adjust the CHROMATIC_ARCHIVE_LOCATION environment variable accordingly.
|
||||
name: test-results
|
||||
path: ./web/test-results
|
||||
retention-days: 30
|
||||
|
||||
# save before stopping the containers so the logs can be captured
|
||||
- name: Save Docker logs
|
||||
if: success() || failure()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack logs > docker-compose.log
|
||||
mv docker-compose.log ${{ github.workspace }}/docker-compose.log
|
||||
|
||||
- name: Upload logs
|
||||
if: success() || failure()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: docker-logs
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
|
||||
- name: Stop Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
|
||||
|
||||
chromatic-tests:
|
||||
name: Chromatic Tests
|
||||
|
||||
needs: playwright-tests
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=32cpu-linux-x64,
|
||||
disk=large,
|
||||
"run-id=${{ github.run_id }}",
|
||||
]
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup node
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: 22
|
||||
|
||||
- name: Install node dependencies
|
||||
working-directory: ./web
|
||||
run: npm ci
|
||||
|
||||
- name: Download Playwright test results
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: test-results
|
||||
path: ./web/test-results
|
||||
|
||||
- name: Run Chromatic
|
||||
uses: chromaui/action@latest
|
||||
with:
|
||||
playwright: true
|
||||
projectToken: ${{ secrets.CHROMATIC_PROJECT_TOKEN }}
|
||||
workingDir: ./web
|
||||
env:
|
||||
CHROMATIC_ARCHIVE_LOCATION: ./test-results
|
||||
72
.github/workflows/pr-helm-chart-testing.yml
vendored
Normal file
72
.github/workflows/pr-helm-chart-testing.yml
vendored
Normal file
@@ -0,0 +1,72 @@
|
||||
name: Helm - Lint and Test Charts
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
workflow_dispatch: # Allows manual triggering
|
||||
|
||||
jobs:
|
||||
helm-chart-check:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=8cpu-linux-x64,hdd=256,"run-id=${{ github.run_id }}"]
|
||||
|
||||
# fetch-depth 0 is required for helm/chart-testing-action
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Helm
|
||||
uses: azure/setup-helm@v4.2.0
|
||||
with:
|
||||
version: v3.14.4
|
||||
|
||||
- name: Set up chart-testing
|
||||
uses: helm/chart-testing-action@v2.6.1
|
||||
|
||||
# even though we specify chart-dirs in ct.yaml, it isn't used by ct for the list-changed command...
|
||||
- name: Run chart-testing (list-changed)
|
||||
id: list-changed
|
||||
run: |
|
||||
echo "default_branch: ${{ github.event.repository.default_branch }}"
|
||||
changed=$(ct list-changed --remote origin --target-branch ${{ github.event.repository.default_branch }} --chart-dirs deployment/helm/charts)
|
||||
echo "list-changed output: $changed"
|
||||
if [[ -n "$changed" ]]; then
|
||||
echo "changed=true" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
|
||||
# rkuo: I don't think we need python?
|
||||
# - name: Set up Python
|
||||
# uses: actions/setup-python@v5
|
||||
# with:
|
||||
# python-version: '3.11'
|
||||
# cache: 'pip'
|
||||
# cache-dependency-path: |
|
||||
# backend/requirements/default.txt
|
||||
# backend/requirements/dev.txt
|
||||
# backend/requirements/model_server.txt
|
||||
# - run: |
|
||||
# python -m pip install --upgrade pip
|
||||
# pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
|
||||
# pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
|
||||
# pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
|
||||
|
||||
# lint all charts if any changes were detected
|
||||
- name: Run chart-testing (lint)
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: ct lint --config ct.yaml --all
|
||||
# the following would lint only changed charts, but linting isn't expensive
|
||||
# run: ct lint --config ct.yaml --target-branch ${{ github.event.repository.default_branch }}
|
||||
|
||||
- name: Create kind cluster
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
uses: helm/kind-action@v1.10.0
|
||||
|
||||
- name: Run chart-testing (install)
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: ct install --all --helm-extra-set-args="--set=nginx.enabled=false" --debug --config ct.yaml
|
||||
# the following would install only changed charts, but we only have one chart so
|
||||
# don't worry about that for now
|
||||
# run: ct install --target-branch ${{ github.event.repository.default_branch }}
|
||||
@@ -1,68 +0,0 @@
|
||||
# This workflow is intentionally disabled while we're still working on it
|
||||
# It's close to ready, but a race condition needs to be fixed with
|
||||
# API server and Vespa startup, and it needs to have a way to build/test against
|
||||
# local containers
|
||||
|
||||
name: Helm - Lint and Test Charts
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
|
||||
jobs:
|
||||
lint-test:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=8cpu-linux-x64,hdd=256,"run-id=${{ github.run_id }}"]
|
||||
|
||||
# fetch-depth 0 is required for helm/chart-testing-action
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Helm
|
||||
uses: azure/setup-helm@v4.2.0
|
||||
with:
|
||||
version: v3.14.4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.11'
|
||||
cache: 'pip'
|
||||
cache-dependency-path: |
|
||||
backend/requirements/default.txt
|
||||
backend/requirements/dev.txt
|
||||
backend/requirements/model_server.txt
|
||||
- run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
|
||||
|
||||
- name: Set up chart-testing
|
||||
uses: helm/chart-testing-action@v2.6.1
|
||||
|
||||
- name: Run chart-testing (list-changed)
|
||||
id: list-changed
|
||||
run: |
|
||||
changed=$(ct list-changed --target-branch ${{ github.event.repository.default_branch }})
|
||||
if [[ -n "$changed" ]]; then
|
||||
echo "changed=true" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
|
||||
- name: Run chart-testing (lint)
|
||||
# if: steps.list-changed.outputs.changed == 'true'
|
||||
run: ct lint --all --config ct.yaml --target-branch ${{ github.event.repository.default_branch }}
|
||||
|
||||
- name: Create kind cluster
|
||||
# if: steps.list-changed.outputs.changed == 'true'
|
||||
uses: helm/kind-action@v1.10.0
|
||||
|
||||
- name: Run chart-testing (install)
|
||||
# if: steps.list-changed.outputs.changed == 'true'
|
||||
run: ct install --all --config ct.yaml
|
||||
# run: ct install --target-branch ${{ github.event.repository.default_branch }}
|
||||
|
||||
@@ -8,16 +8,19 @@ on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
- 'release/**'
|
||||
- "release/**"
|
||||
|
||||
env:
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
|
||||
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
|
||||
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
|
||||
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
|
||||
|
||||
jobs:
|
||||
integration-tests:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=8cpu-linux-x64,ram=16,"run-id=${{ github.run_id }}"]
|
||||
runs-on: [runs-on, runner=32cpu-linux-x64, "run-id=${{ github.run_id }}"]
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
@@ -33,21 +36,21 @@ jobs:
|
||||
|
||||
# tag every docker image with "test" so that we can spin up the correct set
|
||||
# of images during testing
|
||||
|
||||
|
||||
# We don't need to build the Web Docker image since it's not yet used
|
||||
# in the integration tests. We have a separate action to verify that it builds
|
||||
# in the integration tests. We have a separate action to verify that it builds
|
||||
# successfully.
|
||||
- name: Pull Web Docker image
|
||||
run: |
|
||||
docker pull danswer/danswer-web-server:latest
|
||||
docker tag danswer/danswer-web-server:latest danswer/danswer-web-server:test
|
||||
docker pull onyxdotapp/onyx-web-server:latest
|
||||
docker tag onyxdotapp/onyx-web-server:latest onyxdotapp/onyx-web-server:test
|
||||
|
||||
# we use the runs-on cache for docker builds
|
||||
# in conjunction with runs-on runners, it has better speed and unlimited caching
|
||||
# https://runs-on.com/caching/s3-cache-for-github-actions/
|
||||
# https://runs-on.com/caching/docker/
|
||||
# https://github.com/moby/buildkit#s3-cache-experimental
|
||||
|
||||
|
||||
# images are built and run locally for testing purposes. Not pushed.
|
||||
- name: Build Backend Docker image
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
@@ -55,7 +58,7 @@ jobs:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
platforms: linux/amd64
|
||||
tags: danswer/danswer-backend:test
|
||||
tags: onyxdotapp/onyx-backend:test
|
||||
push: false
|
||||
load: true
|
||||
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
|
||||
@@ -67,19 +70,19 @@ jobs:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile.model_server
|
||||
platforms: linux/amd64
|
||||
tags: danswer/danswer-model-server:test
|
||||
tags: onyxdotapp/onyx-model-server:test
|
||||
push: false
|
||||
load: true
|
||||
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
|
||||
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
|
||||
|
||||
|
||||
- name: Build integration test Docker image
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/tests/integration/Dockerfile
|
||||
platforms: linux/amd64
|
||||
tags: danswer/danswer-integration:test
|
||||
tags: onyxdotapp/onyx-integration:test
|
||||
push: false
|
||||
load: true
|
||||
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
|
||||
@@ -116,7 +119,7 @@ jobs:
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
-e AUTH_TYPE=cloud \
|
||||
-e MULTI_TENANT=true \
|
||||
danswer/danswer-integration:test \
|
||||
onyxdotapp/onyx-integration:test \
|
||||
/app/tests/integration/multitenant_tests
|
||||
continue-on-error: true
|
||||
id: run_multitenant_tests
|
||||
@@ -128,15 +131,14 @@ jobs:
|
||||
exit 1
|
||||
else
|
||||
echo "All integration tests passed successfully."
|
||||
fi
|
||||
fi
|
||||
|
||||
- name: Stop multi-tenant Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
|
||||
|
||||
|
||||
- name: Start Docker containers
|
||||
- name: Start Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
|
||||
@@ -150,12 +152,12 @@ jobs:
|
||||
- name: Wait for service to be ready
|
||||
run: |
|
||||
echo "Starting wait-for-service script..."
|
||||
|
||||
|
||||
docker logs -f danswer-stack-api_server-1 &
|
||||
|
||||
start_time=$(date +%s)
|
||||
timeout=300 # 5 minutes in seconds
|
||||
|
||||
|
||||
while true; do
|
||||
current_time=$(date +%s)
|
||||
elapsed_time=$((current_time - start_time))
|
||||
@@ -195,9 +197,13 @@ jobs:
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
|
||||
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
|
||||
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
|
||||
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
danswer/danswer-integration:test \
|
||||
/app/tests/integration/tests
|
||||
onyxdotapp/onyx-integration:test \
|
||||
/app/tests/integration/tests \
|
||||
/app/tests/integration/connector_job_tests
|
||||
continue-on-error: true
|
||||
id: run_tests
|
||||
|
||||
@@ -210,18 +216,19 @@ jobs:
|
||||
echo "All integration tests passed successfully."
|
||||
fi
|
||||
|
||||
- name: Stop Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
|
||||
|
||||
# save before stopping the containers so the logs can be captured
|
||||
- name: Save Docker logs
|
||||
if: success() || failure()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack logs > docker-compose.log
|
||||
mv docker-compose.log ${{ github.workspace }}/docker-compose.log
|
||||
|
||||
|
||||
- name: Stop Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
|
||||
|
||||
- name: Upload logs
|
||||
if: success() || failure()
|
||||
uses: actions/upload-artifact@v4
|
||||
@@ -20,9 +20,12 @@ env:
|
||||
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
|
||||
# Google
|
||||
GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR: ${{ secrets.GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR }}
|
||||
GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR_TEST_USER_1: ${{ secrets.GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR_TEST_USER_1 }}
|
||||
GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR }}
|
||||
GOOGLE_GMAIL_SERVICE_ACCOUNT_JSON_STR: ${{ secrets.GOOGLE_GMAIL_SERVICE_ACCOUNT_JSON_STR }}
|
||||
GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR }}
|
||||
# Slab
|
||||
SLAB_BOT_TOKEN: ${{ secrets.SLAB_BOT_TOKEN }}
|
||||
|
||||
jobs:
|
||||
connectors-check:
|
||||
|
||||
79
.github/workflows/tag-nightly.yml
vendored
79
.github/workflows/tag-nightly.yml
vendored
@@ -2,53 +2,52 @@ name: Nightly Tag Push
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: '0 10 * * *' # Runs every day at 2 AM PST / 3 AM PDT / 10 AM UTC
|
||||
- cron: "0 10 * * *" # Runs every day at 2 AM PST / 3 AM PDT / 10 AM UTC
|
||||
|
||||
permissions:
|
||||
contents: write # Allows pushing tags to the repository
|
||||
contents: write # Allows pushing tags to the repository
|
||||
|
||||
jobs:
|
||||
create-and-push-tag:
|
||||
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}"]
|
||||
runs-on: [runs-on, runner=2cpu-linux-x64, "run-id=${{ github.run_id }}"]
|
||||
|
||||
steps:
|
||||
# actions using GITHUB_TOKEN cannot trigger another workflow, but we do want this to trigger docker pushes
|
||||
# see https://github.com/orgs/community/discussions/27028#discussioncomment-3254367 for the workaround we
|
||||
# implement here which needs an actual user's deploy key
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ssh-key: "${{ secrets.RKUO_DEPLOY_KEY }}"
|
||||
# actions using GITHUB_TOKEN cannot trigger another workflow, but we do want this to trigger docker pushes
|
||||
# see https://github.com/orgs/community/discussions/27028#discussioncomment-3254367 for the workaround we
|
||||
# implement here which needs an actual user's deploy key
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ssh-key: "${{ secrets.RKUO_DEPLOY_KEY }}"
|
||||
|
||||
- name: Set up Git user
|
||||
run: |
|
||||
git config user.name "Richard Kuo [bot]"
|
||||
git config user.email "rkuo[bot]@danswer.ai"
|
||||
- name: Set up Git user
|
||||
run: |
|
||||
git config user.name "Richard Kuo [bot]"
|
||||
git config user.email "rkuo[bot]@onyx.app"
|
||||
|
||||
- name: Check for existing nightly tag
|
||||
id: check_tag
|
||||
run: |
|
||||
if git tag --points-at HEAD --list "nightly-latest*" | grep -q .; then
|
||||
echo "A tag starting with 'nightly-latest' already exists on HEAD."
|
||||
echo "tag_exists=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "No tag starting with 'nightly-latest' exists on HEAD."
|
||||
echo "tag_exists=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
# don't tag again if HEAD already has a nightly-latest tag on it
|
||||
- name: Create Nightly Tag
|
||||
if: steps.check_tag.outputs.tag_exists == 'false'
|
||||
env:
|
||||
DATE: ${{ github.run_id }}
|
||||
run: |
|
||||
TAG_NAME="nightly-latest-$(date +'%Y%m%d')"
|
||||
echo "Creating tag: $TAG_NAME"
|
||||
git tag $TAG_NAME
|
||||
- name: Check for existing nightly tag
|
||||
id: check_tag
|
||||
run: |
|
||||
if git tag --points-at HEAD --list "nightly-latest*" | grep -q .; then
|
||||
echo "A tag starting with 'nightly-latest' already exists on HEAD."
|
||||
echo "tag_exists=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "No tag starting with 'nightly-latest' exists on HEAD."
|
||||
echo "tag_exists=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- name: Push Tag
|
||||
if: steps.check_tag.outputs.tag_exists == 'false'
|
||||
run: |
|
||||
TAG_NAME="nightly-latest-$(date +'%Y%m%d')"
|
||||
git push origin $TAG_NAME
|
||||
|
||||
# don't tag again if HEAD already has a nightly-latest tag on it
|
||||
- name: Create Nightly Tag
|
||||
if: steps.check_tag.outputs.tag_exists == 'false'
|
||||
env:
|
||||
DATE: ${{ github.run_id }}
|
||||
run: |
|
||||
TAG_NAME="nightly-latest-$(date +'%Y%m%d')"
|
||||
echo "Creating tag: $TAG_NAME"
|
||||
git tag $TAG_NAME
|
||||
|
||||
- name: Push Tag
|
||||
if: steps.check_tag.outputs.tag_exists == 'false'
|
||||
run: |
|
||||
TAG_NAME="nightly-latest-$(date +'%Y%m%d')"
|
||||
git push origin $TAG_NAME
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -7,3 +7,4 @@
|
||||
.vscode/
|
||||
*.sw?
|
||||
/backend/tests/regression/answer_quality/search_test_config.yaml
|
||||
/web/test-results/
|
||||
22
.vscode/launch.template.jsonc
vendored
22
.vscode/launch.template.jsonc
vendored
@@ -17,7 +17,7 @@
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Run All Danswer Services",
|
||||
"name": "Run All Onyx Services",
|
||||
"configurations": [
|
||||
"Web Server",
|
||||
"Model Server",
|
||||
@@ -122,7 +122,7 @@
|
||||
"PYTHONUNBUFFERED": "1"
|
||||
},
|
||||
"args": [
|
||||
"danswer.main:app",
|
||||
"onyx.main:app",
|
||||
"--reload",
|
||||
"--port",
|
||||
"8080"
|
||||
@@ -139,7 +139,7 @@
|
||||
"consoleName": "Slack Bot",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "danswer/danswerbot/slack/listener.py",
|
||||
"program": "onyx/onyxbot/slack/listener.py",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
@@ -166,7 +166,7 @@
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"danswer.background.celery.versioned_apps.primary",
|
||||
"onyx.background.celery.versioned_apps.primary",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=4",
|
||||
@@ -195,7 +195,7 @@
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"danswer.background.celery.versioned_apps.light",
|
||||
"onyx.background.celery.versioned_apps.light",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=64",
|
||||
@@ -203,7 +203,7 @@
|
||||
"--loglevel=INFO",
|
||||
"--hostname=light@%n",
|
||||
"-Q",
|
||||
"vespa_metadata_sync,connector_deletion",
|
||||
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert",
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
@@ -224,7 +224,7 @@
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"danswer.background.celery.versioned_apps.heavy",
|
||||
"onyx.background.celery.versioned_apps.heavy",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=4",
|
||||
@@ -232,7 +232,7 @@
|
||||
"--loglevel=INFO",
|
||||
"--hostname=heavy@%n",
|
||||
"-Q",
|
||||
"connector_pruning",
|
||||
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync",
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
@@ -254,7 +254,7 @@
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"danswer.background.celery.versioned_apps.indexing",
|
||||
"onyx.background.celery.versioned_apps.indexing",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=1",
|
||||
@@ -283,7 +283,7 @@
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"danswer.background.celery.versioned_apps.beat",
|
||||
"onyx.background.celery.versioned_apps.beat",
|
||||
"beat",
|
||||
"--loglevel=INFO",
|
||||
],
|
||||
@@ -308,7 +308,7 @@
|
||||
"args": [
|
||||
"-v"
|
||||
// Specify a sepcific module/test to run or provide nothing to run all tests
|
||||
//"tests/unit/danswer/llm/answering/test_prune_and_merge.py"
|
||||
//"tests/unit/onyx/llm/answering/test_prune_and_merge.py"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
|
||||
139
CONTRIBUTING.md
139
CONTRIBUTING.md
@@ -1,105 +1,113 @@
|
||||
<!-- DANSWER_METADATA={"link": "https://github.com/danswer-ai/danswer/blob/main/CONTRIBUTING.md"} -->
|
||||
<!-- DANSWER_METADATA={"link": "https://github.com/onyx-dot-app/onyx/blob/main/CONTRIBUTING.md"} -->
|
||||
|
||||
# Contributing to Danswer
|
||||
Hey there! We are so excited that you're interested in Danswer.
|
||||
# Contributing to Onyx
|
||||
|
||||
Hey there! We are so excited that you're interested in Onyx.
|
||||
|
||||
As an open source project in a rapidly changing space, we welcome all contributions.
|
||||
|
||||
|
||||
## 💃 Guidelines
|
||||
|
||||
### Contribution Opportunities
|
||||
The [GitHub Issues](https://github.com/danswer-ai/danswer/issues) page is a great place to start for contribution ideas.
|
||||
|
||||
The [GitHub Issues](https://github.com/onyx-dot-app/onyx/issues) page is a great place to start for contribution ideas.
|
||||
|
||||
Issues that have been explicitly approved by the maintainers (aligned with the direction of the project)
|
||||
will be marked with the `approved by maintainers` label.
|
||||
Issues marked `good first issue` are an especially great place to start.
|
||||
|
||||
**Connectors** to other tools are another great place to contribute. For details on how, refer to this
|
||||
[README.md](https://github.com/danswer-ai/danswer/blob/main/backend/danswer/connectors/README.md).
|
||||
[README.md](https://github.com/onyx-dot-app/onyx/blob/main/backend/onyx/connectors/README.md).
|
||||
|
||||
If you have a new/different contribution in mind, we'd love to hear about it!
|
||||
Your input is vital to making sure that Danswer moves in the right direction.
|
||||
Your input is vital to making sure that Onyx moves in the right direction.
|
||||
Before starting on implementation, please raise a GitHub issue.
|
||||
|
||||
And always feel free to message us (Chris Weaver / Yuhong Sun) on
|
||||
[Slack](https://join.slack.com/t/danswer/shared_invite/zt-2lcmqw703-071hBuZBfNEOGUsLa5PXvQ) /
|
||||
[Discord](https://discord.gg/TDJ59cGV2X) directly about anything at all.
|
||||
|
||||
And always feel free to message us (Chris Weaver / Yuhong Sun) on
|
||||
[Slack](https://join.slack.com/t/danswer/shared_invite/zt-1w76msxmd-HJHLe3KNFIAIzk_0dSOKaQ) /
|
||||
[Discord](https://discord.gg/TDJ59cGV2X) directly about anything at all.
|
||||
|
||||
### Contributing Code
|
||||
|
||||
To contribute to this project, please follow the
|
||||
["fork and pull request"](https://docs.github.com/en/get-started/quickstart/contributing-to-projects) workflow.
|
||||
When opening a pull request, mention related issues and feel free to tag relevant maintainers.
|
||||
|
||||
Before creating a pull request please make sure that the new changes conform to the formatting and linting requirements.
|
||||
See the [Formatting and Linting](#-formatting-and-linting) section for how to run these checks locally.
|
||||
|
||||
See the [Formatting and Linting](#formatting-and-linting) section for how to run these checks locally.
|
||||
|
||||
### Getting Help 🙋
|
||||
|
||||
Our goal is to make contributing as easy as possible. If you run into any issues please don't hesitate to reach out.
|
||||
That way we can help future contributors and users can avoid the same issue.
|
||||
|
||||
We also have support channels and generally interesting discussions on our
|
||||
[Slack](https://join.slack.com/t/danswer/shared_invite/zt-2afut44lv-Rw3kSWu6_OmdAXRpCv80DQ)
|
||||
and
|
||||
[Slack](https://join.slack.com/t/danswer/shared_invite/zt-1w76msxmd-HJHLe3KNFIAIzk_0dSOKaQ)
|
||||
and
|
||||
[Discord](https://discord.gg/TDJ59cGV2X).
|
||||
|
||||
We would love to see you there!
|
||||
|
||||
|
||||
## Get Started 🚀
|
||||
Danswer being a fully functional app, relies on some external software, specifically:
|
||||
|
||||
Onyx being a fully functional app, relies on some external software, specifically:
|
||||
|
||||
- [Postgres](https://www.postgresql.org/) (Relational DB)
|
||||
- [Vespa](https://vespa.ai/) (Vector DB/Search Engine)
|
||||
- [Redis](https://redis.io/) (Cache)
|
||||
- [Nginx](https://nginx.org/) (Not needed for development flows generally)
|
||||
|
||||
|
||||
> **Note:**
|
||||
> This guide provides instructions to build and run Danswer locally from source with Docker containers providing the above external software. We believe this combination is easier for
|
||||
> development purposes. If you prefer to use pre-built container images, we provide instructions on running the full Danswer stack within Docker below.
|
||||
|
||||
> This guide provides instructions to build and run Onyx locally from source with Docker containers providing the above external software. We believe this combination is easier for
|
||||
> development purposes. If you prefer to use pre-built container images, we provide instructions on running the full Onyx stack within Docker below.
|
||||
|
||||
### Local Set Up
|
||||
|
||||
Be sure to use Python version 3.11. For instructions on installing Python 3.11 on macOS, refer to the [CONTRIBUTING_MACOS.md](./CONTRIBUTING_MACOS.md) readme.
|
||||
|
||||
If using a lower version, modifications will have to be made to the code.
|
||||
If using a higher version, sometimes some libraries will not be available (i.e. we had problems with Tensorflow in the past with higher versions of python).
|
||||
|
||||
|
||||
#### Backend: Python requirements
|
||||
|
||||
Currently, we use pip and recommend creating a virtual environment.
|
||||
|
||||
For convenience here's a command for it:
|
||||
|
||||
```bash
|
||||
python -m venv .venv
|
||||
source .venv/bin/activate
|
||||
```
|
||||
|
||||
> **Note:**
|
||||
> This virtual environment MUST NOT be set up WITHIN the danswer directory if you plan on using mypy within certain IDEs.
|
||||
> For simplicity, we recommend setting up the virtual environment outside of the danswer directory.
|
||||
> This virtual environment MUST NOT be set up WITHIN the onyx directory if you plan on using mypy within certain IDEs.
|
||||
> For simplicity, we recommend setting up the virtual environment outside of the onyx directory.
|
||||
|
||||
_For Windows, activate the virtual environment using Command Prompt:_
|
||||
|
||||
```bash
|
||||
.venv\Scripts\activate
|
||||
```
|
||||
|
||||
If using PowerShell, the command slightly differs:
|
||||
|
||||
```powershell
|
||||
.venv\Scripts\Activate.ps1
|
||||
```
|
||||
|
||||
Install the required python dependencies:
|
||||
|
||||
```bash
|
||||
pip install -r danswer/backend/requirements/default.txt
|
||||
pip install -r danswer/backend/requirements/dev.txt
|
||||
pip install -r danswer/backend/requirements/ee.txt
|
||||
pip install -r danswer/backend/requirements/model_server.txt
|
||||
pip install -r onyx/backend/requirements/default.txt
|
||||
pip install -r onyx/backend/requirements/dev.txt
|
||||
pip install -r onyx/backend/requirements/ee.txt
|
||||
pip install -r onyx/backend/requirements/model_server.txt
|
||||
```
|
||||
|
||||
Install Playwright for Python (headless browser required by the Web Connector)
|
||||
|
||||
In the activated Python virtualenv, install Playwright for Python by running:
|
||||
|
||||
```bash
|
||||
playwright install
|
||||
```
|
||||
@@ -109,42 +117,50 @@ You may have to deactivate and reactivate your virtualenv for `playwright` to ap
|
||||
#### Frontend: Node dependencies
|
||||
|
||||
Install [Node.js and npm](https://docs.npmjs.com/downloading-and-installing-node-js-and-npm) for the frontend.
|
||||
Once the above is done, navigate to `danswer/web` run:
|
||||
Once the above is done, navigate to `onyx/web` run:
|
||||
|
||||
```bash
|
||||
npm i
|
||||
```
|
||||
|
||||
#### Docker containers for external software
|
||||
|
||||
You will need Docker installed to run these containers.
|
||||
|
||||
First navigate to `danswer/deployment/docker_compose`, then start up Postgres/Vespa/Redis with:
|
||||
First navigate to `onyx/deployment/docker_compose`, then start up Postgres/Vespa/Redis with:
|
||||
|
||||
```bash
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack up -d index relational_db cache
|
||||
docker compose -f docker-compose.dev.yml -p onyx-stack up -d index relational_db cache
|
||||
```
|
||||
|
||||
(index refers to Vespa, relational_db refers to Postgres, and cache refers to Redis)
|
||||
|
||||
#### Running Onyx locally
|
||||
|
||||
To start the frontend, navigate to `onyx/web` and run:
|
||||
|
||||
#### Running Danswer locally
|
||||
To start the frontend, navigate to `danswer/web` and run:
|
||||
```bash
|
||||
npm run dev
|
||||
```
|
||||
|
||||
Next, start the model server which runs the local NLP models.
|
||||
Navigate to `danswer/backend` and run:
|
||||
Navigate to `onyx/backend` and run:
|
||||
|
||||
```bash
|
||||
uvicorn model_server.main:app --reload --port 9000
|
||||
```
|
||||
|
||||
_For Windows (for compatibility with both PowerShell and Command Prompt):_
|
||||
|
||||
```bash
|
||||
powershell -Command "uvicorn model_server.main:app --reload --port 9000"
|
||||
```
|
||||
|
||||
The first time running Danswer, you will need to run the DB migrations for Postgres.
|
||||
The first time running Onyx, you will need to run the DB migrations for Postgres.
|
||||
After the first time, this is no longer required unless the DB models change.
|
||||
|
||||
Navigate to `danswer/backend` and with the venv active, run:
|
||||
Navigate to `onyx/backend` and with the venv active, run:
|
||||
|
||||
```bash
|
||||
alembic upgrade head
|
||||
```
|
||||
@@ -152,21 +168,24 @@ alembic upgrade head
|
||||
Next, start the task queue which orchestrates the background jobs.
|
||||
Jobs that take more time are run async from the API server.
|
||||
|
||||
Still in `danswer/backend`, run:
|
||||
Still in `onyx/backend`, run:
|
||||
|
||||
```bash
|
||||
python ./scripts/dev_run_background_jobs.py
|
||||
```
|
||||
|
||||
To run the backend API server, navigate back to `danswer/backend` and run:
|
||||
To run the backend API server, navigate back to `onyx/backend` and run:
|
||||
|
||||
```bash
|
||||
AUTH_TYPE=disabled uvicorn danswer.main:app --reload --port 8080
|
||||
AUTH_TYPE=disabled uvicorn onyx.main:app --reload --port 8080
|
||||
```
|
||||
|
||||
_For Windows (for compatibility with both PowerShell and Command Prompt):_
|
||||
|
||||
```bash
|
||||
powershell -Command "
|
||||
$env:AUTH_TYPE='disabled'
|
||||
uvicorn danswer.main:app --reload --port 8080
|
||||
uvicorn onyx.main:app --reload --port 8080
|
||||
"
|
||||
```
|
||||
|
||||
@@ -182,57 +201,61 @@ You should now have 4 servers running:
|
||||
- Model server
|
||||
- Background jobs
|
||||
|
||||
Now, visit `http://localhost:3000` in your browser. You should see the Danswer onboarding wizard where you can connect your external LLM provider to Danswer.
|
||||
Now, visit `http://localhost:3000` in your browser. You should see the Onyx onboarding wizard where you can connect your external LLM provider to Onyx.
|
||||
|
||||
You've successfully set up a local Danswer instance! 🏁
|
||||
You've successfully set up a local Onyx instance! 🏁
|
||||
|
||||
#### Running the Danswer application in a container
|
||||
#### Running the Onyx application in a container
|
||||
|
||||
You can run the full Danswer application stack from pre-built images including all external software dependencies.
|
||||
You can run the full Onyx application stack from pre-built images including all external software dependencies.
|
||||
|
||||
Navigate to `danswer/deployment/docker_compose` and run:
|
||||
Navigate to `onyx/deployment/docker_compose` and run:
|
||||
|
||||
```bash
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack up -d
|
||||
docker compose -f docker-compose.dev.yml -p onyx-stack up -d
|
||||
```
|
||||
|
||||
After Docker pulls and starts these containers, navigate to `http://localhost:3000` to use Danswer.
|
||||
After Docker pulls and starts these containers, navigate to `http://localhost:3000` to use Onyx.
|
||||
|
||||
If you want to make changes to Danswer and run those changes in Docker, you can also build a local version of the Danswer container images that incorporates your changes like so:
|
||||
If you want to make changes to Onyx and run those changes in Docker, you can also build a local version of the Onyx container images that incorporates your changes like so:
|
||||
|
||||
```bash
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack up -d --build
|
||||
docker compose -f docker-compose.dev.yml -p onyx-stack up -d --build
|
||||
```
|
||||
|
||||
### Formatting and Linting
|
||||
|
||||
#### Backend
|
||||
|
||||
For the backend, you'll need to setup pre-commit hooks (black / reorder-python-imports).
|
||||
First, install pre-commit (if you don't have it already) following the instructions
|
||||
[here](https://pre-commit.com/#installation).
|
||||
|
||||
With the virtual environment active, install the pre-commit library with:
|
||||
|
||||
```bash
|
||||
pip install pre-commit
|
||||
```
|
||||
|
||||
Then, from the `danswer/backend` directory, run:
|
||||
Then, from the `onyx/backend` directory, run:
|
||||
|
||||
```bash
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
Additionally, we use `mypy` for static type checking.
|
||||
Danswer is fully type-annotated, and we want to keep it that way!
|
||||
To run the mypy checks manually, run `python -m mypy .` from the `danswer/backend` directory.
|
||||
|
||||
Onyx is fully type-annotated, and we want to keep it that way!
|
||||
To run the mypy checks manually, run `python -m mypy .` from the `onyx/backend` directory.
|
||||
|
||||
#### Web
|
||||
We use `prettier` for formatting. The desired version (2.8.8) will be installed via a `npm i` from the `danswer/web` directory.
|
||||
To run the formatter, use `npx prettier --write .` from the `danswer/web` directory.
|
||||
|
||||
We use `prettier` for formatting. The desired version (2.8.8) will be installed via a `npm i` from the `onyx/web` directory.
|
||||
To run the formatter, use `npx prettier --write .` from the `onyx/web` directory.
|
||||
Please double check that prettier passes before creating a pull request.
|
||||
|
||||
|
||||
### Release Process
|
||||
Danswer loosely follows the SemVer versioning standard.
|
||||
|
||||
Onyx loosely follows the SemVer versioning standard.
|
||||
Major changes are released with a "minor" version bump. Currently we use patch release versions to indicate small feature changes.
|
||||
A set of Docker containers will be pushed automatically to DockerHub with every tag.
|
||||
You can see the containers [here](https://hub.docker.com/search?q=danswer%2F).
|
||||
You can see the containers [here](https://hub.docker.com/search?q=onyx%2F).
|
||||
|
||||
@@ -1,15 +1,19 @@
|
||||
## Some additional notes for Mac Users
|
||||
The base instructions to set up the development environment are located in [CONTRIBUTING.md](https://github.com/danswer-ai/danswer/blob/main/CONTRIBUTING.md).
|
||||
|
||||
The base instructions to set up the development environment are located in [CONTRIBUTING.md](https://github.com/onyx-dot-app/onyx/blob/main/CONTRIBUTING.md).
|
||||
|
||||
### Setting up Python
|
||||
|
||||
Ensure [Homebrew](https://brew.sh/) is already set up.
|
||||
|
||||
Then install python 3.11.
|
||||
|
||||
```bash
|
||||
brew install python@3.11
|
||||
```
|
||||
|
||||
Add python 3.11 to your path: add the following line to ~/.zshrc
|
||||
|
||||
```
|
||||
export PATH="$(brew --prefix)/opt/python@3.11/libexec/bin:$PATH"
|
||||
```
|
||||
@@ -17,15 +21,16 @@ export PATH="$(brew --prefix)/opt/python@3.11/libexec/bin:$PATH"
|
||||
> **Note:**
|
||||
> You will need to open a new terminal for the path change above to take effect.
|
||||
|
||||
|
||||
### Setting up Docker
|
||||
On macOS, you will need to install [Docker Desktop](https://www.docker.com/products/docker-desktop/) and
|
||||
|
||||
On macOS, you will need to install [Docker Desktop](https://www.docker.com/products/docker-desktop/) and
|
||||
ensure it is running before continuing with the docker commands.
|
||||
|
||||
|
||||
### Formatting and Linting
|
||||
|
||||
MacOS will likely require you to remove some quarantine attributes on some of the hooks for them to execute properly.
|
||||
After installing pre-commit, run the following command:
|
||||
|
||||
```bash
|
||||
sudo xattr -r -d com.apple.quarantine ~/.cache/pre-commit
|
||||
```
|
||||
```
|
||||
|
||||
6
LICENSE
6
LICENSE
@@ -2,9 +2,9 @@ Copyright (c) 2023-present DanswerAI, Inc.
|
||||
|
||||
Portions of this software are licensed as follows:
|
||||
|
||||
* All content that resides under "ee" directories of this repository, if that directory exists, is licensed under the license defined in "backend/ee/LICENSE". Specifically all content under "backend/ee" and "web/src/app/ee" is licensed under the license defined in "backend/ee/LICENSE".
|
||||
* All third party components incorporated into the Danswer Software are licensed under the original license provided by the owner of the applicable component.
|
||||
* Content outside of the above mentioned directories or restrictions above is available under the "MIT Expat" license as defined below.
|
||||
- All content that resides under "ee" directories of this repository, if that directory exists, is licensed under the license defined in "backend/ee/LICENSE". Specifically all content under "backend/ee" and "web/src/app/ee" is licensed under the license defined in "backend/ee/LICENSE".
|
||||
- All third party components incorporated into the Onyx Software are licensed under the original license provided by the owner of the applicable component.
|
||||
- Content outside of the above mentioned directories or restrictions above is available under the "MIT Expat" license as defined below.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
|
||||
161
README.md
161
README.md
@@ -1,142 +1,143 @@
|
||||
<!-- DANSWER_METADATA={"link": "https://github.com/danswer-ai/danswer/blob/main/README.md"} -->
|
||||
<!-- DANSWER_METADATA={"link": "https://github.com/onyx-dot-app/onyx/blob/main/README.md"} -->
|
||||
|
||||
<a name="readme-top"></a>
|
||||
|
||||
<h2 align="center">
|
||||
<a href="https://www.danswer.ai/"> <img width="50%" src="https://github.com/danswer-owners/danswer/blob/1fabd9372d66cd54238847197c33f091a724803b/DanswerWithName.png?raw=true)" /></a>
|
||||
<a href="https://www.onyx.app/"> <img width="50%" src="https://github.com/onyx-dot-app/onyx/blob/logo/LogoOnyx.png?raw=true)" /></a>
|
||||
</h2>
|
||||
|
||||
<p align="center">
|
||||
<p align="center">Open Source Gen-AI Chat + Unified Search.</p>
|
||||
<p align="center">Open Source Gen-AI + Enterprise Search.</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://docs.danswer.dev/" target="_blank">
|
||||
<a href="https://docs.onyx.app/" target="_blank">
|
||||
<img src="https://img.shields.io/badge/docs-view-blue" alt="Documentation">
|
||||
</a>
|
||||
<a href="https://join.slack.com/t/danswer/shared_invite/zt-2lcmqw703-071hBuZBfNEOGUsLa5PXvQ" target="_blank">
|
||||
<a href="https://join.slack.com/t/danswer/shared_invite/zt-1w76msxmd-HJHLe3KNFIAIzk_0dSOKaQ" target="_blank">
|
||||
<img src="https://img.shields.io/badge/slack-join-blue.svg?logo=slack" alt="Slack">
|
||||
</a>
|
||||
<a href="https://discord.gg/TDJ59cGV2X" target="_blank">
|
||||
<img src="https://img.shields.io/badge/discord-join-blue.svg?logo=discord&logoColor=white" alt="Discord">
|
||||
</a>
|
||||
<a href="https://github.com/danswer-ai/danswer/blob/main/README.md" target="_blank">
|
||||
<a href="https://github.com/onyx-dot-app/onyx/blob/main/README.md" target="_blank">
|
||||
<img src="https://img.shields.io/static/v1?label=license&message=MIT&color=blue" alt="License">
|
||||
</a>
|
||||
</p>
|
||||
|
||||
<strong>[Danswer](https://www.danswer.ai/)</strong> is the AI Assistant connected to your company's docs, apps, and people.
|
||||
Danswer provides a Chat interface and plugs into any LLM of your choice. Danswer can be deployed anywhere and for any
|
||||
scale - on a laptop, on-premise, or to cloud. Since you own the deployment, your user data and chats are fully in your
|
||||
own control. Danswer is MIT licensed and designed to be modular and easily extensible. The system also comes fully ready
|
||||
for production usage with user authentication, role management (admin/basic users), chat persistence, and a UI for
|
||||
configuring Personas (AI Assistants) and their Prompts.
|
||||
<strong>[Onyx](https://www.onyx.app/)</strong> (Formerly Danswer) is the AI Assistant connected to your company's docs, apps, and people.
|
||||
Onyx provides a Chat interface and plugs into any LLM of your choice. Onyx can be deployed anywhere and for any
|
||||
scale - on a laptop, on-premise, or to cloud. Since you own the deployment, your user data and chats are fully in your
|
||||
own control. Onyx is dual Licensed with most of it under MIT license and designed to be modular and easily extensible. The system also comes fully ready
|
||||
for production usage with user authentication, role management (admin/basic users), chat persistence, and a UI for
|
||||
configuring AI Assistants.
|
||||
|
||||
Danswer also serves as a Unified Search across all common workplace tools such as Slack, Google Drive, Confluence, etc.
|
||||
By combining LLMs and team specific knowledge, Danswer becomes a subject matter expert for the team. Imagine ChatGPT if
|
||||
Onyx also serves as a Enterprise Search across all common workplace tools such as Slack, Google Drive, Confluence, etc.
|
||||
By combining LLMs and team specific knowledge, Onyx becomes a subject matter expert for the team. Imagine ChatGPT if
|
||||
it had access to your team's unique knowledge! It enables questions such as "A customer wants feature X, is this already
|
||||
supported?" or "Where's the pull request for feature Y?"
|
||||
|
||||
<h3>Usage</h3>
|
||||
|
||||
Danswer Web App:
|
||||
Onyx Web App:
|
||||
|
||||
https://github.com/danswer-ai/danswer/assets/32520769/563be14c-9304-47b5-bf0a-9049c2b6f410
|
||||
https://github.com/onyx-dot-app/onyx/assets/32520769/563be14c-9304-47b5-bf0a-9049c2b6f410
|
||||
|
||||
Or, plug Onyx into your existing Slack workflows (more integrations to come 😁):
|
||||
|
||||
Or, plug Danswer into your existing Slack workflows (more integrations to come 😁):
|
||||
https://github.com/onyx-dot-app/onyx/assets/25087905/3e19739b-d178-4371-9a38-011430bdec1b
|
||||
|
||||
https://github.com/danswer-ai/danswer/assets/25087905/3e19739b-d178-4371-9a38-011430bdec1b
|
||||
|
||||
|
||||
For more details on the Admin UI to manage connectors and users, check out our
|
||||
For more details on the Admin UI to manage connectors and users, check out our
|
||||
<strong><a href="https://www.youtube.com/watch?v=geNzY1nbCnU">Full Video Demo</a></strong>!
|
||||
|
||||
## Deployment
|
||||
|
||||
Danswer can easily be run locally (even on a laptop) or deployed on a virtual machine with a single
|
||||
`docker compose` command. Checkout our [docs](https://docs.danswer.dev/quickstart) to learn more.
|
||||
Onyx can easily be run locally (even on a laptop) or deployed on a virtual machine with a single
|
||||
`docker compose` command. Checkout our [docs](https://docs.onyx.app/quickstart) to learn more.
|
||||
|
||||
We also have built-in support for deployment on Kubernetes. Files for that can be found [here](https://github.com/danswer-ai/danswer/tree/main/deployment/kubernetes).
|
||||
We also have built-in support for deployment on Kubernetes. Files for that can be found [here](https://github.com/onyx-dot-app/onyx/tree/main/deployment/kubernetes).
|
||||
|
||||
## 💃 Main Features
|
||||
|
||||
## 💃 Main Features
|
||||
* Chat UI with the ability to select documents to chat with.
|
||||
* Create custom AI Assistants with different prompts and backing knowledge sets.
|
||||
* Connect Danswer with LLM of your choice (self-host for a fully airgapped solution).
|
||||
* Document Search + AI Answers for natural language queries.
|
||||
* Connectors to all common workplace tools like Google Drive, Confluence, Slack, etc.
|
||||
* Slack integration to get answers and search results directly in Slack.
|
||||
|
||||
- Chat UI with the ability to select documents to chat with.
|
||||
- Create custom AI Assistants with different prompts and backing knowledge sets.
|
||||
- Connect Onyx with LLM of your choice (self-host for a fully airgapped solution).
|
||||
- Document Search + AI Answers for natural language queries.
|
||||
- Connectors to all common workplace tools like Google Drive, Confluence, Slack, etc.
|
||||
- Slack integration to get answers and search results directly in Slack.
|
||||
|
||||
## 🚧 Roadmap
|
||||
* Chat/Prompt sharing with specific teammates and user groups.
|
||||
* Multimodal model support, chat with images, video etc.
|
||||
* Choosing between LLMs and parameters during chat session.
|
||||
* Tool calling and agent configurations options.
|
||||
* Organizational understanding and ability to locate and suggest experts from your team.
|
||||
|
||||
- Chat/Prompt sharing with specific teammates and user groups.
|
||||
- Multimodal model support, chat with images, video etc.
|
||||
- Choosing between LLMs and parameters during chat session.
|
||||
- Tool calling and agent configurations options.
|
||||
- Organizational understanding and ability to locate and suggest experts from your team.
|
||||
|
||||
## Other Notable Benefits of Danswer
|
||||
* User Authentication with document level access management.
|
||||
* Best in class Hybrid Search across all sources (BM-25 + prefix aware embedding models).
|
||||
* Admin Dashboard to configure connectors, document-sets, access, etc.
|
||||
* Custom deep learning models + learn from user feedback.
|
||||
* Easy deployment and ability to host Danswer anywhere of your choosing.
|
||||
## Other Notable Benefits of Onyx
|
||||
|
||||
- User Authentication with document level access management.
|
||||
- Best in class Hybrid Search across all sources (BM-25 + prefix aware embedding models).
|
||||
- Admin Dashboard to configure connectors, document-sets, access, etc.
|
||||
- Custom deep learning models + learn from user feedback.
|
||||
- Easy deployment and ability to host Onyx anywhere of your choosing.
|
||||
|
||||
## 🔌 Connectors
|
||||
|
||||
Efficiently pulls the latest changes from:
|
||||
* Slack
|
||||
* GitHub
|
||||
* Google Drive
|
||||
* Confluence
|
||||
* Jira
|
||||
* Zendesk
|
||||
* Gmail
|
||||
* Notion
|
||||
* Gong
|
||||
* Slab
|
||||
* Linear
|
||||
* Productboard
|
||||
* Guru
|
||||
* Bookstack
|
||||
* Document360
|
||||
* Sharepoint
|
||||
* Hubspot
|
||||
* Local Files
|
||||
* Websites
|
||||
* And more ...
|
||||
|
||||
- Slack
|
||||
- GitHub
|
||||
- Google Drive
|
||||
- Confluence
|
||||
- Jira
|
||||
- Zendesk
|
||||
- Gmail
|
||||
- Notion
|
||||
- Gong
|
||||
- Slab
|
||||
- Linear
|
||||
- Productboard
|
||||
- Guru
|
||||
- Bookstack
|
||||
- Document360
|
||||
- Sharepoint
|
||||
- Hubspot
|
||||
- Local Files
|
||||
- Websites
|
||||
- And more ...
|
||||
|
||||
## 📚 Editions
|
||||
|
||||
There are two editions of Danswer:
|
||||
There are two editions of Onyx:
|
||||
|
||||
* Danswer Community Edition (CE) is available freely under the MIT Expat license. This version has ALL the core features discussed above. This is the version of Danswer you will get if you follow the Deployment guide above.
|
||||
* Danswer Enterprise Edition (EE) includes extra features that are primarily useful for larger organizations. Specifically, this includes:
|
||||
* Single Sign-On (SSO), with support for both SAML and OIDC
|
||||
* Role-based access control
|
||||
* Document permission inheritance from connected sources
|
||||
* Usage analytics and query history accessible to admins
|
||||
* Whitelabeling
|
||||
* API key authentication
|
||||
* Encryption of secrets
|
||||
* Any many more! Checkout [our website](https://www.danswer.ai/) for the latest.
|
||||
- Onyx Community Edition (CE) is available freely under the MIT Expat license. This version has ALL the core features discussed above. This is the version of Onyx you will get if you follow the Deployment guide above.
|
||||
- Onyx Enterprise Edition (EE) includes extra features that are primarily useful for larger organizations. Specifically, this includes:
|
||||
- Single Sign-On (SSO), with support for both SAML and OIDC
|
||||
- Role-based access control
|
||||
- Document permission inheritance from connected sources
|
||||
- Usage analytics and query history accessible to admins
|
||||
- Whitelabeling
|
||||
- API key authentication
|
||||
- Encryption of secrets
|
||||
- Any many more! Checkout [our website](https://www.onyx.app/) for the latest.
|
||||
|
||||
To try the Danswer Enterprise Edition:
|
||||
To try the Onyx Enterprise Edition:
|
||||
|
||||
1. Checkout our [Cloud product](https://app.danswer.ai/signup).
|
||||
2. For self-hosting, contact us at [founders@danswer.ai](mailto:founders@danswer.ai) or book a call with us on our [Cal](https://cal.com/team/danswer/founders).
|
||||
1. Checkout our [Cloud product](https://cloud.onyx.app/signup).
|
||||
2. For self-hosting, contact us at [founders@onyx.app](mailto:founders@onyx.app) or book a call with us on our [Cal](https://cal.com/team/danswer/founders).
|
||||
|
||||
## 💡 Contributing
|
||||
|
||||
Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md) for more details.
|
||||
|
||||
## ⭐Star History
|
||||
|
||||
[](https://star-history.com/#danswer-ai/danswer&Date)
|
||||
[](https://star-history.com/#onyx-dot-app/onyx&Date)
|
||||
|
||||
## ✨Contributors
|
||||
|
||||
<a href="https://github.com/aryn-ai/sycamore/graphs/contributors">
|
||||
<img alt="contributors" src="https://contrib.rocks/image?repo=danswer-ai/danswer"/>
|
||||
<a href="https://github.com/onyx-dot-app/onyx/graphs/contributors">
|
||||
<img alt="contributors" src="https://contrib.rocks/image?repo=onyx-dot-app/onyx"/>
|
||||
</a>
|
||||
|
||||
<p align="right" style="font-size: 14px; color: #555; margin-top: 20px;">
|
||||
|
||||
@@ -1,19 +1,19 @@
|
||||
FROM python:3.11.7-slim-bookworm
|
||||
|
||||
LABEL com.danswer.maintainer="founders@danswer.ai"
|
||||
LABEL com.danswer.description="This image is the web/frontend container of Danswer which \
|
||||
contains code for both the Community and Enterprise editions of Danswer. If you do not \
|
||||
LABEL com.danswer.maintainer="founders@onyx.app"
|
||||
LABEL com.danswer.description="This image is the web/frontend container of Onyx which \
|
||||
contains code for both the Community and Enterprise editions of Onyx. If you do not \
|
||||
have a contract or agreement with DanswerAI, you are not permitted to use the Enterprise \
|
||||
Edition features outside of personal development or testing purposes. Please reach out to \
|
||||
founders@danswer.ai for more information. Please visit https://github.com/danswer-ai/danswer"
|
||||
founders@onyx.app for more information. Please visit https://github.com/onyx-dot-app/onyx"
|
||||
|
||||
# Default DANSWER_VERSION, typically overriden during builds by GitHub Actions.
|
||||
ARG DANSWER_VERSION=0.8-dev
|
||||
ENV DANSWER_VERSION=${DANSWER_VERSION} \
|
||||
# Default ONYX_VERSION, typically overriden during builds by GitHub Actions.
|
||||
ARG ONYX_VERSION=0.8-dev
|
||||
ENV ONYX_VERSION=${ONYX_VERSION} \
|
||||
DANSWER_RUNNING_IN_DOCKER="true"
|
||||
|
||||
|
||||
RUN echo "DANSWER_VERSION: ${DANSWER_VERSION}"
|
||||
RUN echo "ONYX_VERSION: ${ONYX_VERSION}"
|
||||
# Install system dependencies
|
||||
# cmake needed for psycopg (postgres)
|
||||
# libpq-dev needed for psycopg (postgres)
|
||||
@@ -56,7 +56,7 @@ RUN pip install --no-cache-dir --upgrade \
|
||||
# Cleanup for CVEs and size reduction
|
||||
# https://github.com/tornadoweb/tornado/issues/3107
|
||||
# xserver-common and xvfb included by playwright installation but not needed after
|
||||
# perl-base is part of the base Python Debian image but not needed for Danswer functionality
|
||||
# perl-base is part of the base Python Debian image but not needed for Onyx functionality
|
||||
# perl-base could only be removed with --allow-remove-essential
|
||||
RUN apt-get update && \
|
||||
apt-get remove -y --allow-remove-essential \
|
||||
@@ -73,6 +73,7 @@ RUN apt-get update && \
|
||||
rm -rf /var/lib/apt/lists/* && \
|
||||
rm -f /usr/local/lib/python3.11/site-packages/tornado/test/test.key
|
||||
|
||||
|
||||
# Pre-downloading models for setups with limited egress
|
||||
RUN python -c "from tokenizers import Tokenizer; \
|
||||
Tokenizer.from_pretrained('nomic-ai/nomic-embed-text-v1')"
|
||||
@@ -91,7 +92,7 @@ COPY ./ee /app/ee
|
||||
COPY supervisord.conf /etc/supervisor/conf.d/supervisord.conf
|
||||
|
||||
# Set up application files
|
||||
COPY ./danswer /app/danswer
|
||||
COPY ./onyx /app/onyx
|
||||
COPY ./shared_configs /app/shared_configs
|
||||
COPY ./alembic /app/alembic
|
||||
COPY ./alembic_tenants /app/alembic_tenants
|
||||
|
||||
@@ -1,18 +1,18 @@
|
||||
FROM python:3.11.7-slim-bookworm
|
||||
|
||||
LABEL com.danswer.maintainer="founders@danswer.ai"
|
||||
LABEL com.danswer.description="This image is for the Danswer model server which runs all of the \
|
||||
AI models for Danswer. This container and all the code is MIT Licensed and free for all to use. \
|
||||
You can find it at https://hub.docker.com/r/danswer/danswer-model-server. For more details, \
|
||||
visit https://github.com/danswer-ai/danswer."
|
||||
LABEL com.danswer.maintainer="founders@onyx.app"
|
||||
LABEL com.danswer.description="This image is for the Onyx model server which runs all of the \
|
||||
AI models for Onyx. This container and all the code is MIT Licensed and free for all to use. \
|
||||
You can find it at https://hub.docker.com/r/onyx/onyx-model-server. For more details, \
|
||||
visit https://github.com/onyx-dot-app/onyx."
|
||||
|
||||
# Default DANSWER_VERSION, typically overriden during builds by GitHub Actions.
|
||||
ARG DANSWER_VERSION=0.8-dev
|
||||
ENV DANSWER_VERSION=${DANSWER_VERSION} \
|
||||
# Default ONYX_VERSION, typically overriden during builds by GitHub Actions.
|
||||
ARG ONYX_VERSION=0.8-dev
|
||||
ENV ONYX_VERSION=${ONYX_VERSION} \
|
||||
DANSWER_RUNNING_IN_DOCKER="true"
|
||||
|
||||
|
||||
RUN echo "DANSWER_VERSION: ${DANSWER_VERSION}"
|
||||
RUN echo "ONYX_VERSION: ${ONYX_VERSION}"
|
||||
|
||||
COPY ./requirements/model_server.txt /tmp/requirements.txt
|
||||
RUN pip install --no-cache-dir --upgrade \
|
||||
@@ -20,11 +20,11 @@ RUN pip install --no-cache-dir --upgrade \
|
||||
--timeout 30 \
|
||||
-r /tmp/requirements.txt
|
||||
|
||||
RUN apt-get remove -y --allow-remove-essential perl-base && \
|
||||
RUN apt-get remove -y --allow-remove-essential perl-base && \
|
||||
apt-get autoremove -y
|
||||
|
||||
# Pre-downloading models for setups with limited egress
|
||||
# Download tokenizers, distilbert for the Danswer model
|
||||
# Download tokenizers, distilbert for the Onyx model
|
||||
# Download model weights
|
||||
# Run Nomic to pull in the custom architecture and have it cached locally
|
||||
RUN python -c "from transformers import AutoTokenizer; \
|
||||
@@ -38,18 +38,18 @@ from sentence_transformers import SentenceTransformer; \
|
||||
SentenceTransformer(model_name_or_path='nomic-ai/nomic-embed-text-v1', trust_remote_code=True);"
|
||||
|
||||
# In case the user has volumes mounted to /root/.cache/huggingface that they've downloaded while
|
||||
# running Danswer, don't overwrite it with the built in cache folder
|
||||
# running Onyx, don't overwrite it with the built in cache folder
|
||||
RUN mv /root/.cache/huggingface /root/.cache/temp_huggingface
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Utils used by model server
|
||||
COPY ./danswer/utils/logger.py /app/danswer/utils/logger.py
|
||||
COPY ./onyx/utils/logger.py /app/onyx/utils/logger.py
|
||||
|
||||
# Place to fetch version information
|
||||
COPY ./danswer/__init__.py /app/danswer/__init__.py
|
||||
COPY ./onyx/__init__.py /app/onyx/__init__.py
|
||||
|
||||
# Shared between Danswer Backend and Model Server
|
||||
# Shared between Onyx Backend and Model Server
|
||||
COPY ./shared_configs /app/shared_configs
|
||||
|
||||
# Model Server main code
|
||||
|
||||
@@ -1,19 +1,22 @@
|
||||
<!-- DANSWER_METADATA={"link": "https://github.com/danswer-ai/danswer/blob/main/backend/alembic/README.md"} -->
|
||||
<!-- DANSWER_METADATA={"link": "https://github.com/onyx-dot-app/onyx/blob/main/backend/alembic/README.md"} -->
|
||||
|
||||
# Alembic DB Migrations
|
||||
These files are for creating/updating the tables in the Relational DB (Postgres).
|
||||
Danswer migrations use a generic single-database configuration with an async dbapi.
|
||||
|
||||
## To generate new migrations:
|
||||
run from danswer/backend:
|
||||
These files are for creating/updating the tables in the Relational DB (Postgres).
|
||||
Onyx migrations use a generic single-database configuration with an async dbapi.
|
||||
|
||||
## To generate new migrations:
|
||||
|
||||
run from onyx/backend:
|
||||
`alembic revision --autogenerate -m <DESCRIPTION_OF_MIGRATION>`
|
||||
|
||||
More info can be found here: https://alembic.sqlalchemy.org/en/latest/autogenerate.html
|
||||
|
||||
## Running migrations
|
||||
|
||||
To run all un-applied migrations:
|
||||
`alembic upgrade head`
|
||||
|
||||
To undo migrations:
|
||||
`alembic downgrade -X`
|
||||
`alembic downgrade -X`
|
||||
where X is the number of migrations you want to undo from the current state
|
||||
|
||||
@@ -1,56 +1,70 @@
|
||||
from typing import Any, Literal
|
||||
from onyx.db.engine import get_iam_auth_token
|
||||
from onyx.configs.app_configs import USE_IAM_AUTH
|
||||
from onyx.configs.app_configs import POSTGRES_HOST
|
||||
from onyx.configs.app_configs import POSTGRES_PORT
|
||||
from onyx.configs.app_configs import POSTGRES_USER
|
||||
from onyx.configs.app_configs import AWS_REGION
|
||||
from onyx.db.engine import build_connection_string
|
||||
from onyx.db.engine import get_all_tenant_ids
|
||||
from sqlalchemy import event
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.engine.base import Connection
|
||||
from typing import Any
|
||||
import os
|
||||
import ssl
|
||||
import asyncio
|
||||
from logging.config import fileConfig
|
||||
import logging
|
||||
from logging.config import fileConfig
|
||||
|
||||
from alembic import context
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy.sql import text
|
||||
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from danswer.db.engine import build_connection_string
|
||||
from danswer.db.models import Base
|
||||
from sqlalchemy.sql.schema import SchemaItem
|
||||
from onyx.configs.constants import SSL_CERT_FILE
|
||||
from shared_configs.configs import MULTI_TENANT, POSTGRES_DEFAULT_SCHEMA
|
||||
from onyx.db.models import Base
|
||||
from celery.backends.database.session import ResultModelBase # type: ignore
|
||||
from danswer.db.engine import get_all_tenant_ids
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
# Alembic Config object
|
||||
config = context.config
|
||||
|
||||
# Interpret the config file for Python logging.
|
||||
if config.config_file_name is not None and config.attributes.get(
|
||||
"configure_logger", True
|
||||
):
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
# Add your model's MetaData object here for 'autogenerate' support
|
||||
target_metadata = [Base.metadata, ResultModelBase.metadata]
|
||||
|
||||
EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
|
||||
|
||||
# Set up logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ssl_context: ssl.SSLContext | None = None
|
||||
if USE_IAM_AUTH:
|
||||
if not os.path.exists(SSL_CERT_FILE):
|
||||
raise FileNotFoundError(f"Expected {SSL_CERT_FILE} when USE_IAM_AUTH is true.")
|
||||
ssl_context = ssl.create_default_context(cafile=SSL_CERT_FILE)
|
||||
|
||||
|
||||
def include_object(
|
||||
object: Any, name: str, type_: str, reflected: bool, compare_to: Any
|
||||
object: SchemaItem,
|
||||
name: str | None,
|
||||
type_: Literal[
|
||||
"schema",
|
||||
"table",
|
||||
"column",
|
||||
"index",
|
||||
"unique_constraint",
|
||||
"foreign_key_constraint",
|
||||
],
|
||||
reflected: bool,
|
||||
compare_to: SchemaItem | None,
|
||||
) -> bool:
|
||||
"""
|
||||
Determines whether a database object should be included in migrations.
|
||||
Excludes specified tables from migrations.
|
||||
"""
|
||||
if type_ == "table" and name in EXCLUDE_TABLES:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_schema_options() -> tuple[str, bool, bool]:
|
||||
"""
|
||||
Parses command-line options passed via '-x' in Alembic commands.
|
||||
Recognizes 'schema', 'create_schema', and 'upgrade_all_tenants' options.
|
||||
"""
|
||||
x_args_raw = context.get_x_argument()
|
||||
x_args = {}
|
||||
for arg in x_args_raw:
|
||||
@@ -78,16 +92,12 @@ def get_schema_options() -> tuple[str, bool, bool]:
|
||||
def do_run_migrations(
|
||||
connection: Connection, schema_name: str, create_schema: bool
|
||||
) -> None:
|
||||
"""
|
||||
Executes migrations in the specified schema.
|
||||
"""
|
||||
logger.info(f"About to migrate schema: {schema_name}")
|
||||
|
||||
if create_schema:
|
||||
connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema_name}"'))
|
||||
connection.execute(text("COMMIT"))
|
||||
|
||||
# Set search_path to the target schema
|
||||
connection.execute(text(f'SET search_path TO "{schema_name}"'))
|
||||
|
||||
context.configure(
|
||||
@@ -105,11 +115,25 @@ def do_run_migrations(
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def provide_iam_token_for_alembic(
|
||||
dialect: Any, conn_rec: Any, cargs: Any, cparams: Any
|
||||
) -> None:
|
||||
if USE_IAM_AUTH:
|
||||
# Database connection settings
|
||||
region = AWS_REGION
|
||||
host = POSTGRES_HOST
|
||||
port = POSTGRES_PORT
|
||||
user = POSTGRES_USER
|
||||
|
||||
# Get IAM authentication token
|
||||
token = get_iam_auth_token(host, port, user, region)
|
||||
|
||||
# For Alembic / SQLAlchemy in this context, set SSL and password
|
||||
cparams["password"] = token
|
||||
cparams["ssl"] = ssl_context
|
||||
|
||||
|
||||
async def run_async_migrations() -> None:
|
||||
"""
|
||||
Determines whether to run migrations for a single schema or all schemas,
|
||||
and executes migrations accordingly.
|
||||
"""
|
||||
schema_name, create_schema, upgrade_all_tenants = get_schema_options()
|
||||
|
||||
engine = create_async_engine(
|
||||
@@ -117,10 +141,16 @@ async def run_async_migrations() -> None:
|
||||
poolclass=pool.NullPool,
|
||||
)
|
||||
|
||||
if upgrade_all_tenants:
|
||||
# Run migrations for all tenant schemas sequentially
|
||||
tenant_schemas = get_all_tenant_ids()
|
||||
if USE_IAM_AUTH:
|
||||
|
||||
@event.listens_for(engine.sync_engine, "do_connect")
|
||||
def event_provide_iam_token_for_alembic(
|
||||
dialect: Any, conn_rec: Any, cargs: Any, cparams: Any
|
||||
) -> None:
|
||||
provide_iam_token_for_alembic(dialect, conn_rec, cargs, cparams)
|
||||
|
||||
if upgrade_all_tenants:
|
||||
tenant_schemas = get_all_tenant_ids()
|
||||
for schema in tenant_schemas:
|
||||
try:
|
||||
logger.info(f"Migrating schema: {schema}")
|
||||
@@ -150,15 +180,20 @@ async def run_async_migrations() -> None:
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
"""
|
||||
Run migrations in 'offline' mode.
|
||||
"""
|
||||
schema_name, _, upgrade_all_tenants = get_schema_options()
|
||||
url = build_connection_string()
|
||||
|
||||
if upgrade_all_tenants:
|
||||
# Run offline migrations for all tenant schemas
|
||||
engine = create_async_engine(url)
|
||||
|
||||
if USE_IAM_AUTH:
|
||||
|
||||
@event.listens_for(engine.sync_engine, "do_connect")
|
||||
def event_provide_iam_token_for_alembic_offline(
|
||||
dialect: Any, conn_rec: Any, cargs: Any, cparams: Any
|
||||
) -> None:
|
||||
provide_iam_token_for_alembic(dialect, conn_rec, cargs, cparams)
|
||||
|
||||
tenant_schemas = get_all_tenant_ids()
|
||||
engine.sync_engine.dispose()
|
||||
|
||||
@@ -195,9 +230,6 @@ def run_migrations_offline() -> None:
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
|
||||
"""
|
||||
Runs migrations in 'online' mode using an asynchronous engine.
|
||||
"""
|
||||
asyncio.run(run_async_migrations())
|
||||
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ from sqlalchemy.sql import table
|
||||
from sqlalchemy.dialects import postgresql
|
||||
import json
|
||||
|
||||
from danswer.utils.encryption import encrypt_string_to_bytes
|
||||
from onyx.utils.encryption import encrypt_string_to_bytes
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "0a98909f2757"
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Introduce Danswer APIs
|
||||
"""Introduce Onyx APIs
|
||||
|
||||
Revision ID: 15326fcec57e
|
||||
Revises: 77d07dffae64
|
||||
@@ -8,7 +8,7 @@ Create Date: 2023-11-11 20:51:24.228999
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import DocumentSource
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "15326fcec57e"
|
||||
@@ -0,0 +1,59 @@
|
||||
"""display custom llm models
|
||||
|
||||
Revision ID: 177de57c21c9
|
||||
Revises: 4ee1287bd26a
|
||||
Create Date: 2024-11-21 11:49:04.488677
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from sqlalchemy import and_
|
||||
|
||||
revision = "177de57c21c9"
|
||||
down_revision = "4ee1287bd26a"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
llm_provider = sa.table(
|
||||
"llm_provider",
|
||||
sa.column("id", sa.Integer),
|
||||
sa.column("provider", sa.String),
|
||||
sa.column("model_names", postgresql.ARRAY(sa.String)),
|
||||
sa.column("display_model_names", postgresql.ARRAY(sa.String)),
|
||||
)
|
||||
|
||||
excluded_providers = ["openai", "bedrock", "anthropic", "azure"]
|
||||
|
||||
providers_to_update = sa.select(
|
||||
llm_provider.c.id,
|
||||
llm_provider.c.model_names,
|
||||
llm_provider.c.display_model_names,
|
||||
).where(
|
||||
and_(
|
||||
~llm_provider.c.provider.in_(excluded_providers),
|
||||
llm_provider.c.model_names.isnot(None),
|
||||
)
|
||||
)
|
||||
|
||||
results = conn.execute(providers_to_update).fetchall()
|
||||
|
||||
for provider_id, model_names, display_model_names in results:
|
||||
if display_model_names is None:
|
||||
display_model_names = []
|
||||
|
||||
combined_model_names = list(set(display_model_names + model_names))
|
||||
update_stmt = (
|
||||
llm_provider.update()
|
||||
.where(llm_provider.c.id == provider_id)
|
||||
.values(display_model_names=combined_model_names)
|
||||
)
|
||||
conn.execute(update_stmt)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
pass
|
||||
@@ -10,7 +10,7 @@ from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
from danswer.configs.chat_configs import NUM_POSTPROCESSED_RESULTS
|
||||
from onyx.configs.chat_configs import NUM_POSTPROCESSED_RESULTS
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "1f60f60c3401"
|
||||
|
||||
@@ -0,0 +1,68 @@
|
||||
"""default chosen assistants to none
|
||||
|
||||
Revision ID: 26b931506ecb
|
||||
Revises: 2daa494a0851
|
||||
Create Date: 2024-11-12 13:23:29.858995
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "26b931506ecb"
|
||||
down_revision = "2daa494a0851"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"user", sa.Column("chosen_assistants_new", postgresql.JSONB(), nullable=True)
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE "user"
|
||||
SET chosen_assistants_new =
|
||||
CASE
|
||||
WHEN chosen_assistants = '[-2, -1, 0]' THEN NULL
|
||||
ELSE chosen_assistants
|
||||
END
|
||||
"""
|
||||
)
|
||||
|
||||
op.drop_column("user", "chosen_assistants")
|
||||
|
||||
op.alter_column(
|
||||
"user", "chosen_assistants_new", new_column_name="chosen_assistants"
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"chosen_assistants_old",
|
||||
postgresql.JSONB(),
|
||||
nullable=False,
|
||||
server_default="[-2, -1, 0]",
|
||||
),
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE "user"
|
||||
SET chosen_assistants_old =
|
||||
CASE
|
||||
WHEN chosen_assistants IS NULL THEN '[-2, -1, 0]'::jsonb
|
||||
ELSE chosen_assistants
|
||||
END
|
||||
"""
|
||||
)
|
||||
|
||||
op.drop_column("user", "chosen_assistants")
|
||||
|
||||
op.alter_column(
|
||||
"user", "chosen_assistants_old", new_column_name="chosen_assistants"
|
||||
)
|
||||
30
backend/alembic/versions/2daa494a0851_add_group_sync_time.py
Normal file
30
backend/alembic/versions/2daa494a0851_add_group_sync_time.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""add-group-sync-time
|
||||
|
||||
Revision ID: 2daa494a0851
|
||||
Revises: c0fd6e4da83a
|
||||
Create Date: 2024-11-11 10:57:22.991157
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "2daa494a0851"
|
||||
down_revision = "c0fd6e4da83a"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"connector_credential_pair",
|
||||
sa.Column(
|
||||
"last_time_external_group_sync",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("connector_credential_pair", "last_time_external_group_sync")
|
||||
121
backend/alembic/versions/35e518e0ddf4_properly_cascade.py
Normal file
121
backend/alembic/versions/35e518e0ddf4_properly_cascade.py
Normal file
@@ -0,0 +1,121 @@
|
||||
"""properly_cascade
|
||||
|
||||
Revision ID: 35e518e0ddf4
|
||||
Revises: 91a0a4d62b14
|
||||
Create Date: 2024-09-20 21:24:04.891018
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "35e518e0ddf4"
|
||||
down_revision = "91a0a4d62b14"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Update chat_message foreign key constraint
|
||||
op.drop_constraint(
|
||||
"chat_message_chat_session_id_fkey", "chat_message", type_="foreignkey"
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"chat_message_chat_session_id_fkey",
|
||||
"chat_message",
|
||||
"chat_session",
|
||||
["chat_session_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
# Update chat_message__search_doc foreign key constraints
|
||||
op.drop_constraint(
|
||||
"chat_message__search_doc_chat_message_id_fkey",
|
||||
"chat_message__search_doc",
|
||||
type_="foreignkey",
|
||||
)
|
||||
op.drop_constraint(
|
||||
"chat_message__search_doc_search_doc_id_fkey",
|
||||
"chat_message__search_doc",
|
||||
type_="foreignkey",
|
||||
)
|
||||
|
||||
op.create_foreign_key(
|
||||
"chat_message__search_doc_chat_message_id_fkey",
|
||||
"chat_message__search_doc",
|
||||
"chat_message",
|
||||
["chat_message_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"chat_message__search_doc_search_doc_id_fkey",
|
||||
"chat_message__search_doc",
|
||||
"search_doc",
|
||||
["search_doc_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
# Add CASCADE delete for tool_call foreign key
|
||||
op.drop_constraint("tool_call_message_id_fkey", "tool_call", type_="foreignkey")
|
||||
op.create_foreign_key(
|
||||
"tool_call_message_id_fkey",
|
||||
"tool_call",
|
||||
"chat_message",
|
||||
["message_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Revert chat_message foreign key constraint
|
||||
op.drop_constraint(
|
||||
"chat_message_chat_session_id_fkey", "chat_message", type_="foreignkey"
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"chat_message_chat_session_id_fkey",
|
||||
"chat_message",
|
||||
"chat_session",
|
||||
["chat_session_id"],
|
||||
["id"],
|
||||
)
|
||||
|
||||
# Revert chat_message__search_doc foreign key constraints
|
||||
op.drop_constraint(
|
||||
"chat_message__search_doc_chat_message_id_fkey",
|
||||
"chat_message__search_doc",
|
||||
type_="foreignkey",
|
||||
)
|
||||
op.drop_constraint(
|
||||
"chat_message__search_doc_search_doc_id_fkey",
|
||||
"chat_message__search_doc",
|
||||
type_="foreignkey",
|
||||
)
|
||||
|
||||
op.create_foreign_key(
|
||||
"chat_message__search_doc_chat_message_id_fkey",
|
||||
"chat_message__search_doc",
|
||||
"chat_message",
|
||||
["chat_message_id"],
|
||||
["id"],
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"chat_message__search_doc_search_doc_id_fkey",
|
||||
"chat_message__search_doc",
|
||||
"search_doc",
|
||||
["search_doc_id"],
|
||||
["id"],
|
||||
)
|
||||
|
||||
# Revert tool_call foreign key constraint
|
||||
op.drop_constraint("tool_call_message_id_fkey", "tool_call", type_="foreignkey")
|
||||
op.create_foreign_key(
|
||||
"tool_call_message_id_fkey",
|
||||
"tool_call",
|
||||
"chat_message",
|
||||
["message_id"],
|
||||
["id"],
|
||||
)
|
||||
@@ -17,7 +17,7 @@ depends_on: None = None
|
||||
|
||||
def upgrade() -> None:
|
||||
# At this point, we directly changed some previous migrations,
|
||||
# https://github.com/danswer-ai/danswer/pull/637
|
||||
# https://github.com/onyx-dot-app/onyx/pull/637
|
||||
# Due to using Postgres native Enums, it caused some complications for first time users.
|
||||
# To remove those complications, all Enums are only handled application side moving forward.
|
||||
# This migration exists to ensure that existing users don't run into upgrade issues.
|
||||
|
||||
@@ -0,0 +1,45 @@
|
||||
"""add persona categories
|
||||
|
||||
Revision ID: 47e5bef3a1d7
|
||||
Revises: dfbe9e93d3c7
|
||||
Create Date: 2024-11-05 18:55:02.221064
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "47e5bef3a1d7"
|
||||
down_revision = "dfbe9e93d3c7"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create the persona_category table
|
||||
op.create_table(
|
||||
"persona_category",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("name", sa.String(), nullable=False),
|
||||
sa.Column("description", sa.String(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("name"),
|
||||
)
|
||||
|
||||
# Add category_id to persona table
|
||||
op.add_column("persona", sa.Column("category_id", sa.Integer(), nullable=True))
|
||||
op.create_foreign_key(
|
||||
"fk_persona_category",
|
||||
"persona",
|
||||
"persona_category",
|
||||
["category_id"],
|
||||
["id"],
|
||||
ondelete="SET NULL",
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_constraint("fk_persona_category", "persona", type_="foreignkey")
|
||||
op.drop_column("persona", "category_id")
|
||||
op.drop_table("persona_category")
|
||||
@@ -0,0 +1,280 @@
|
||||
"""add_multiple_slack_bot_support
|
||||
|
||||
Revision ID: 4ee1287bd26a
|
||||
Revises: 47e5bef3a1d7
|
||||
Create Date: 2024-11-06 13:15:53.302644
|
||||
|
||||
"""
|
||||
import logging
|
||||
from typing import cast
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.orm import Session
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
from onyx.db.models import SlackBot
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "4ee1287bd26a"
|
||||
down_revision = "47e5bef3a1d7"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
# Configure logging
|
||||
logger = logging.getLogger("alembic.runtime.migration")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
logger.info(f"{revision}: create_table: slack_bot")
|
||||
# Create new slack_bot table
|
||||
op.create_table(
|
||||
"slack_bot",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("name", sa.String(), nullable=False),
|
||||
sa.Column("enabled", sa.Boolean(), nullable=False, server_default="true"),
|
||||
sa.Column("bot_token", sa.LargeBinary(), nullable=False),
|
||||
sa.Column("app_token", sa.LargeBinary(), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("bot_token"),
|
||||
sa.UniqueConstraint("app_token"),
|
||||
)
|
||||
|
||||
# # Create new slack_channel_config table
|
||||
op.create_table(
|
||||
"slack_channel_config",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("slack_bot_id", sa.Integer(), nullable=True),
|
||||
sa.Column("persona_id", sa.Integer(), nullable=True),
|
||||
sa.Column("channel_config", postgresql.JSONB(), nullable=False),
|
||||
sa.Column("response_type", sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
"enable_auto_filters", sa.Boolean(), nullable=False, server_default="false"
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["slack_bot_id"],
|
||||
["slack_bot.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["persona_id"],
|
||||
["persona.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
# Handle existing Slack bot tokens first
|
||||
logger.info(f"{revision}: Checking for existing Slack bot.")
|
||||
bot_token = None
|
||||
app_token = None
|
||||
first_row_id = None
|
||||
|
||||
try:
|
||||
tokens = cast(dict, get_kv_store().load("slack_bot_tokens_config_key"))
|
||||
except Exception:
|
||||
logger.warning("No existing Slack bot tokens found.")
|
||||
tokens = {}
|
||||
|
||||
bot_token = tokens.get("bot_token")
|
||||
app_token = tokens.get("app_token")
|
||||
|
||||
if bot_token and app_token:
|
||||
logger.info(f"{revision}: Found bot and app tokens.")
|
||||
|
||||
session = Session(bind=op.get_bind())
|
||||
new_slack_bot = SlackBot(
|
||||
name="Slack Bot (Migrated)",
|
||||
enabled=True,
|
||||
bot_token=bot_token,
|
||||
app_token=app_token,
|
||||
)
|
||||
session.add(new_slack_bot)
|
||||
session.commit()
|
||||
first_row_id = new_slack_bot.id
|
||||
|
||||
# Create a default bot if none exists
|
||||
# This is in case there are no slack tokens but there are channels configured
|
||||
op.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO slack_bot (name, enabled, bot_token, app_token)
|
||||
SELECT 'Default Bot', true, '', ''
|
||||
WHERE NOT EXISTS (SELECT 1 FROM slack_bot)
|
||||
RETURNING id;
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# Get the bot ID to use (either from existing migration or newly created)
|
||||
bot_id_query = sa.text(
|
||||
"""
|
||||
SELECT COALESCE(
|
||||
:first_row_id,
|
||||
(SELECT id FROM slack_bot ORDER BY id ASC LIMIT 1)
|
||||
) as bot_id;
|
||||
"""
|
||||
)
|
||||
result = op.get_bind().execute(bot_id_query, {"first_row_id": first_row_id})
|
||||
bot_id = result.scalar()
|
||||
|
||||
# CTE (Common Table Expression) that transforms the old slack_bot_config table data
|
||||
# This splits up the channel_names into their own rows
|
||||
channel_names_cte = """
|
||||
WITH channel_names AS (
|
||||
SELECT
|
||||
sbc.id as config_id,
|
||||
sbc.persona_id,
|
||||
sbc.response_type,
|
||||
sbc.enable_auto_filters,
|
||||
jsonb_array_elements_text(sbc.channel_config->'channel_names') as channel_name,
|
||||
sbc.channel_config->>'respond_tag_only' as respond_tag_only,
|
||||
sbc.channel_config->>'respond_to_bots' as respond_to_bots,
|
||||
sbc.channel_config->'respond_member_group_list' as respond_member_group_list,
|
||||
sbc.channel_config->'answer_filters' as answer_filters,
|
||||
sbc.channel_config->'follow_up_tags' as follow_up_tags
|
||||
FROM slack_bot_config sbc
|
||||
)
|
||||
"""
|
||||
|
||||
# Insert the channel names into the new slack_channel_config table
|
||||
insert_statement = """
|
||||
INSERT INTO slack_channel_config (
|
||||
slack_bot_id,
|
||||
persona_id,
|
||||
channel_config,
|
||||
response_type,
|
||||
enable_auto_filters
|
||||
)
|
||||
SELECT
|
||||
:bot_id,
|
||||
channel_name.persona_id,
|
||||
jsonb_build_object(
|
||||
'channel_name', channel_name.channel_name,
|
||||
'respond_tag_only',
|
||||
COALESCE((channel_name.respond_tag_only)::boolean, false),
|
||||
'respond_to_bots',
|
||||
COALESCE((channel_name.respond_to_bots)::boolean, false),
|
||||
'respond_member_group_list',
|
||||
COALESCE(channel_name.respond_member_group_list, '[]'::jsonb),
|
||||
'answer_filters',
|
||||
COALESCE(channel_name.answer_filters, '[]'::jsonb),
|
||||
'follow_up_tags',
|
||||
COALESCE(channel_name.follow_up_tags, '[]'::jsonb)
|
||||
),
|
||||
channel_name.response_type,
|
||||
channel_name.enable_auto_filters
|
||||
FROM channel_names channel_name;
|
||||
"""
|
||||
|
||||
op.execute(sa.text(channel_names_cte + insert_statement).bindparams(bot_id=bot_id))
|
||||
|
||||
# Clean up old tokens if they existed
|
||||
try:
|
||||
if bot_token and app_token:
|
||||
logger.info(f"{revision}: Removing old bot and app tokens.")
|
||||
get_kv_store().delete("slack_bot_tokens_config_key")
|
||||
except Exception:
|
||||
logger.warning("tried to delete tokens in dynamic config but failed")
|
||||
# Rename the table
|
||||
op.rename_table(
|
||||
"slack_bot_config__standard_answer_category",
|
||||
"slack_channel_config__standard_answer_category",
|
||||
)
|
||||
|
||||
# Rename the column
|
||||
op.alter_column(
|
||||
"slack_channel_config__standard_answer_category",
|
||||
"slack_bot_config_id",
|
||||
new_column_name="slack_channel_config_id",
|
||||
)
|
||||
|
||||
# Drop the table with CASCADE to handle dependent objects
|
||||
op.execute("DROP TABLE slack_bot_config CASCADE")
|
||||
|
||||
logger.info(f"{revision}: Migration complete.")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Recreate the old slack_bot_config table
|
||||
op.create_table(
|
||||
"slack_bot_config",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("persona_id", sa.Integer(), nullable=True),
|
||||
sa.Column("channel_config", postgresql.JSONB(), nullable=False),
|
||||
sa.Column("response_type", sa.String(), nullable=False),
|
||||
sa.Column("enable_auto_filters", sa.Boolean(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["persona_id"],
|
||||
["persona.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
# Migrate data back to the old format
|
||||
# Group by persona_id to combine channel names back into arrays
|
||||
op.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO slack_bot_config (
|
||||
persona_id,
|
||||
channel_config,
|
||||
response_type,
|
||||
enable_auto_filters
|
||||
)
|
||||
SELECT DISTINCT ON (persona_id)
|
||||
persona_id,
|
||||
jsonb_build_object(
|
||||
'channel_names', (
|
||||
SELECT jsonb_agg(c.channel_config->>'channel_name')
|
||||
FROM slack_channel_config c
|
||||
WHERE c.persona_id = scc.persona_id
|
||||
),
|
||||
'respond_tag_only', (channel_config->>'respond_tag_only')::boolean,
|
||||
'respond_to_bots', (channel_config->>'respond_to_bots')::boolean,
|
||||
'respond_member_group_list', channel_config->'respond_member_group_list',
|
||||
'answer_filters', channel_config->'answer_filters',
|
||||
'follow_up_tags', channel_config->'follow_up_tags'
|
||||
),
|
||||
response_type,
|
||||
enable_auto_filters
|
||||
FROM slack_channel_config scc
|
||||
WHERE persona_id IS NOT NULL;
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# Rename the table back
|
||||
op.rename_table(
|
||||
"slack_channel_config__standard_answer_category",
|
||||
"slack_bot_config__standard_answer_category",
|
||||
)
|
||||
|
||||
# Rename the column back
|
||||
op.alter_column(
|
||||
"slack_bot_config__standard_answer_category",
|
||||
"slack_channel_config_id",
|
||||
new_column_name="slack_bot_config_id",
|
||||
)
|
||||
|
||||
# Try to save the first bot's tokens back to KV store
|
||||
try:
|
||||
first_bot = (
|
||||
op.get_bind()
|
||||
.execute(
|
||||
sa.text(
|
||||
"SELECT bot_token, app_token FROM slack_bot ORDER BY id LIMIT 1"
|
||||
)
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if first_bot and first_bot.bot_token and first_bot.app_token:
|
||||
tokens = {
|
||||
"bot_token": first_bot.bot_token,
|
||||
"app_token": first_bot.app_token,
|
||||
}
|
||||
get_kv_store().store("slack_bot_tokens_config_key", tokens)
|
||||
except Exception:
|
||||
logger.warning("Failed to save tokens back to KV store")
|
||||
|
||||
# Drop the new tables in reverse order
|
||||
op.drop_table("slack_channel_config")
|
||||
op.drop_table("slack_bot")
|
||||
23
backend/alembic/versions/54a74a0417fc_danswerbot_onyxbot.py
Normal file
23
backend/alembic/versions/54a74a0417fc_danswerbot_onyxbot.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""danswerbot -> onyxbot
|
||||
|
||||
Revision ID: 54a74a0417fc
|
||||
Revises: 94dc3d0236f8
|
||||
Create Date: 2024-12-11 18:05:05.490737
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "54a74a0417fc"
|
||||
down_revision = "94dc3d0236f8"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.alter_column("chat_session", "danswerbot_flow", new_column_name="onyxbot_flow")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.alter_column("chat_session", "onyxbot_flow", new_column_name="danswerbot_flow")
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Track Danswerbot Explicitly
|
||||
"""Track Onyxbot Explicitly
|
||||
|
||||
Revision ID: 570282d33c49
|
||||
Revises: 7547d982db8f
|
||||
45
backend/alembic/versions/6d562f86c78b_remove_default_bot.py
Normal file
45
backend/alembic/versions/6d562f86c78b_remove_default_bot.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""remove default bot
|
||||
|
||||
Revision ID: 6d562f86c78b
|
||||
Revises: 177de57c21c9
|
||||
Create Date: 2024-11-22 11:51:29.331336
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "6d562f86c78b"
|
||||
down_revision = "177de57c21c9"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute(
|
||||
sa.text(
|
||||
"""
|
||||
DELETE FROM slack_bot
|
||||
WHERE name = 'Default Bot'
|
||||
AND bot_token = ''
|
||||
AND app_token = ''
|
||||
AND NOT EXISTS (
|
||||
SELECT 1 FROM slack_channel_config
|
||||
WHERE slack_channel_config.slack_bot_id = slack_bot.id
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO slack_bot (name, enabled, bot_token, app_token)
|
||||
SELECT 'Default Bot', true, '', ''
|
||||
WHERE NOT EXISTS (SELECT 1 FROM slack_bot)
|
||||
RETURNING id;
|
||||
"""
|
||||
)
|
||||
)
|
||||
@@ -9,7 +9,7 @@ import json
|
||||
from typing import cast
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from danswer.key_value_store.factory import get_kv_store
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "703313b75876"
|
||||
|
||||
@@ -8,9 +8,9 @@ Create Date: 2024-03-22 21:34:27.629444
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
from danswer.db.models import IndexModelStatus
|
||||
from danswer.search.enums import RecencyBiasSetting
|
||||
from danswer.search.enums import SearchType
|
||||
from onyx.db.models import IndexModelStatus
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.context.search.enums import SearchType
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "776b3bbe9092"
|
||||
|
||||
@@ -18,7 +18,7 @@ depends_on: None = None
|
||||
|
||||
def upgrade() -> None:
|
||||
# In a PR:
|
||||
# https://github.com/danswer-ai/danswer/pull/397/files#diff-f05fb341f6373790b91852579631b64ca7645797a190837156a282b67e5b19c2
|
||||
# https://github.com/onyx-dot-app/onyx/pull/397/files#diff-f05fb341f6373790b91852579631b64ca7645797a190837156a282b67e5b19c2
|
||||
# we directly changed some previous migrations. This caused some users to have native enums
|
||||
# while others wouldn't. This has caused some issues when adding new fields to these enums.
|
||||
# This migration manually changes the enum types to ensure that nobody uses native enums.
|
||||
|
||||
45
backend/alembic/versions/91a0a4d62b14_milestone.py
Normal file
45
backend/alembic/versions/91a0a4d62b14_milestone.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""Milestone
|
||||
|
||||
Revision ID: 91a0a4d62b14
|
||||
Revises: dab04867cd88
|
||||
Create Date: 2024-12-13 19:03:30.947551
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import fastapi_users_db_sqlalchemy
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "91a0a4d62b14"
|
||||
down_revision = "dab04867cd88"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"milestone",
|
||||
sa.Column("id", sa.UUID(), nullable=False),
|
||||
sa.Column("tenant_id", sa.String(), nullable=True),
|
||||
sa.Column(
|
||||
"user_id",
|
||||
fastapi_users_db_sqlalchemy.generics.GUID(),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("event_type", sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
"time_created",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("event_tracker", postgresql.JSONB(), nullable=True),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("event_type", name="uq_milestone_event_type"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("milestone")
|
||||
@@ -7,7 +7,7 @@ Create Date: 2024-03-21 12:05:23.956734
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import DocumentSource
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "91fd3b470d1a"
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
"""add web ui option to slack config
|
||||
|
||||
Revision ID: 93560ba1b118
|
||||
Revises: 6d562f86c78b
|
||||
Create Date: 2024-11-24 06:36:17.490612
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "93560ba1b118"
|
||||
down_revision = "6d562f86c78b"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add show_continue_in_web_ui with default False to all existing channel_configs
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE slack_channel_config
|
||||
SET channel_config = channel_config || '{"show_continue_in_web_ui": false}'::jsonb
|
||||
WHERE NOT channel_config ? 'show_continue_in_web_ui'
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove show_continue_in_web_ui from all channel_configs
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE slack_channel_config
|
||||
SET channel_config = channel_config - 'show_continue_in_web_ui'
|
||||
"""
|
||||
)
|
||||
@@ -7,15 +7,15 @@ Create Date: 2024-10-26 13:06:06.937969
|
||||
"""
|
||||
from alembic import op
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import text
|
||||
|
||||
# Import your models and constants
|
||||
from danswer.db.models import (
|
||||
from onyx.db.models import (
|
||||
Connector,
|
||||
ConnectorCredentialPair,
|
||||
Credential,
|
||||
IndexAttempt,
|
||||
)
|
||||
from danswer.configs.constants import DocumentSource
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
@@ -30,13 +30,11 @@ def upgrade() -> None:
|
||||
bind = op.get_bind()
|
||||
session = Session(bind=bind)
|
||||
|
||||
connectors_to_delete = (
|
||||
session.query(Connector)
|
||||
.filter(Connector.source == DocumentSource.REQUESTTRACKER)
|
||||
.all()
|
||||
# Get connectors using raw SQL
|
||||
result = bind.execute(
|
||||
text("SELECT id FROM connector WHERE source = 'requesttracker'")
|
||||
)
|
||||
|
||||
connector_ids = [connector.id for connector in connectors_to_delete]
|
||||
connector_ids = [row[0] for row in result]
|
||||
|
||||
if connector_ids:
|
||||
cc_pairs_to_delete = (
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
"""make document set description optional
|
||||
|
||||
Revision ID: 94dc3d0236f8
|
||||
Revises: bf7a81109301
|
||||
Create Date: 2024-12-11 11:26:10.616722
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "94dc3d0236f8"
|
||||
down_revision = "bf7a81109301"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Make document_set.description column nullable
|
||||
op.alter_column(
|
||||
"document_set", "description", existing_type=sa.String(), nullable=True
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Revert document_set.description column to non-nullable
|
||||
op.alter_column(
|
||||
"document_set", "description", existing_type=sa.String(), nullable=False
|
||||
)
|
||||
@@ -0,0 +1,30 @@
|
||||
"""add creator to cc pair
|
||||
|
||||
Revision ID: 9cf5c00f72fe
|
||||
Revises: 26b931506ecb
|
||||
Create Date: 2024-11-12 15:16:42.682902
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "9cf5c00f72fe"
|
||||
down_revision = "26b931506ecb"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"connector_credential_pair",
|
||||
sa.Column(
|
||||
"creator_id",
|
||||
sa.UUID(as_uuid=True),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("connector_credential_pair", "creator_id")
|
||||
@@ -0,0 +1,36 @@
|
||||
"""Combine Search and Chat
|
||||
|
||||
Revision ID: 9f696734098f
|
||||
Revises: a8c2065484e6
|
||||
Create Date: 2024-11-27 15:32:19.694972
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "9f696734098f"
|
||||
down_revision = "a8c2065484e6"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.alter_column("chat_session", "description", nullable=True)
|
||||
op.drop_column("chat_session", "one_shot")
|
||||
op.drop_column("slack_channel_config", "response_type")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute("UPDATE chat_session SET description = '' WHERE description IS NULL")
|
||||
op.alter_column("chat_session", "description", nullable=False)
|
||||
op.add_column(
|
||||
"chat_session",
|
||||
sa.Column("one_shot", sa.Boolean(), nullable=False, server_default=sa.false()),
|
||||
)
|
||||
op.add_column(
|
||||
"slack_channel_config",
|
||||
sa.Column(
|
||||
"response_type", sa.String(), nullable=False, server_default="citations"
|
||||
),
|
||||
)
|
||||
@@ -0,0 +1,27 @@
|
||||
"""add auto scroll to user model
|
||||
|
||||
Revision ID: a8c2065484e6
|
||||
Revises: abe7378b8217
|
||||
Create Date: 2024-11-22 17:34:09.690295
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "a8c2065484e6"
|
||||
down_revision = "abe7378b8217"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column("auto_scroll", sa.Boolean(), nullable=True, server_default=None),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("user", "auto_scroll")
|
||||
@@ -0,0 +1,30 @@
|
||||
"""add indexing trigger to cc_pair
|
||||
|
||||
Revision ID: abe7378b8217
|
||||
Revises: 6d562f86c78b
|
||||
Create Date: 2024-11-26 19:09:53.481171
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "abe7378b8217"
|
||||
down_revision = "93560ba1b118"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"connector_credential_pair",
|
||||
sa.Column(
|
||||
"indexing_trigger",
|
||||
sa.Enum("UPDATE", "REINDEX", name="indexingmode", native_enum=False),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("connector_credential_pair", "indexing_trigger")
|
||||
@@ -10,7 +10,7 @@ from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from sqlalchemy.dialects.postgresql import ENUM
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import DocumentSource
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "b156fa702355"
|
||||
@@ -288,6 +288,15 @@ def upgrade() -> None:
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# NOTE: you will lose all chat history. This is to satisfy the non-nullable constraints
|
||||
# below
|
||||
op.execute("DELETE FROM chat_feedback")
|
||||
op.execute("DELETE FROM chat_message__search_doc")
|
||||
op.execute("DELETE FROM document_retrieval_feedback")
|
||||
op.execute("DELETE FROM document_retrieval_feedback")
|
||||
op.execute("DELETE FROM chat_message")
|
||||
op.execute("DELETE FROM chat_session")
|
||||
|
||||
op.drop_constraint(
|
||||
"chat_feedback__chat_message_fk", "chat_feedback", type_="foreignkey"
|
||||
)
|
||||
|
||||
@@ -0,0 +1,48 @@
|
||||
"""remove description from starter messages
|
||||
|
||||
Revision ID: b72ed7a5db0e
|
||||
Revises: 33cb72ea4d80
|
||||
Create Date: 2024-11-03 15:55:28.944408
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "b72ed7a5db0e"
|
||||
down_revision = "33cb72ea4d80"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET starter_messages = (
|
||||
SELECT jsonb_agg(elem - 'description')
|
||||
FROM jsonb_array_elements(starter_messages) elem
|
||||
)
|
||||
WHERE starter_messages IS NOT NULL
|
||||
AND jsonb_typeof(starter_messages) = 'array'
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET starter_messages = (
|
||||
SELECT jsonb_agg(elem || '{"description": ""}')
|
||||
FROM jsonb_array_elements(starter_messages) elem
|
||||
)
|
||||
WHERE starter_messages IS NOT NULL
|
||||
AND jsonb_typeof(starter_messages) = 'array'
|
||||
"""
|
||||
)
|
||||
)
|
||||
@@ -0,0 +1,57 @@
|
||||
"""delete_input_prompts
|
||||
|
||||
Revision ID: bf7a81109301
|
||||
Revises: f7a894b06d02
|
||||
Create Date: 2024-12-09 12:00:49.884228
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import fastapi_users_db_sqlalchemy
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "bf7a81109301"
|
||||
down_revision = "f7a894b06d02"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.drop_table("inputprompt__user")
|
||||
op.drop_table("inputprompt")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.create_table(
|
||||
"inputprompt",
|
||||
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
|
||||
sa.Column("prompt", sa.String(), nullable=False),
|
||||
sa.Column("content", sa.String(), nullable=False),
|
||||
sa.Column("active", sa.Boolean(), nullable=False),
|
||||
sa.Column("is_public", sa.Boolean(), nullable=False),
|
||||
sa.Column(
|
||||
"user_id",
|
||||
fastapi_users_db_sqlalchemy.generics.GUID(),
|
||||
nullable=True,
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["user.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_table(
|
||||
"inputprompt__user",
|
||||
sa.Column("input_prompt_id", sa.Integer(), nullable=False),
|
||||
sa.Column("user_id", sa.Integer(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["input_prompt_id"],
|
||||
["inputprompt.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["inputprompt.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("input_prompt_id", "user_id"),
|
||||
)
|
||||
87
backend/alembic/versions/c0aab6edb6dd_delete_workspace.py
Normal file
87
backend/alembic/versions/c0aab6edb6dd_delete_workspace.py
Normal file
@@ -0,0 +1,87 @@
|
||||
"""delete workspace
|
||||
|
||||
Revision ID: c0aab6edb6dd
|
||||
Revises: 35e518e0ddf4
|
||||
Create Date: 2024-12-17 14:37:07.660631
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "c0aab6edb6dd"
|
||||
down_revision = "35e518e0ddf4"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE connector
|
||||
SET connector_specific_config = connector_specific_config - 'workspace'
|
||||
WHERE source = 'SLACK'
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
import json
|
||||
from sqlalchemy import text
|
||||
from slack_sdk import WebClient
|
||||
|
||||
conn = op.get_bind()
|
||||
|
||||
# Fetch all Slack credentials
|
||||
creds_result = conn.execute(
|
||||
text("SELECT id, credential_json FROM credential WHERE source = 'SLACK'")
|
||||
)
|
||||
all_slack_creds = creds_result.fetchall()
|
||||
if not all_slack_creds:
|
||||
return
|
||||
|
||||
for cred_row in all_slack_creds:
|
||||
credential_id, credential_json = cred_row
|
||||
|
||||
credential_json = (
|
||||
credential_json.tobytes().decode("utf-8")
|
||||
if isinstance(credential_json, memoryview)
|
||||
else credential_json.decode("utf-8")
|
||||
)
|
||||
credential_data = json.loads(credential_json)
|
||||
slack_bot_token = credential_data.get("slack_bot_token")
|
||||
if not slack_bot_token:
|
||||
print(
|
||||
f"No slack_bot_token found for credential {credential_id}. "
|
||||
"Your Slack connector will not function until you upgrade and provide a valid token."
|
||||
)
|
||||
continue
|
||||
|
||||
client = WebClient(token=slack_bot_token)
|
||||
try:
|
||||
auth_response = client.auth_test()
|
||||
workspace = auth_response["url"].split("//")[1].split(".")[0]
|
||||
|
||||
# Update only the connectors linked to this credential
|
||||
# (and which are Slack connectors).
|
||||
op.execute(
|
||||
f"""
|
||||
UPDATE connector AS c
|
||||
SET connector_specific_config = jsonb_set(
|
||||
connector_specific_config,
|
||||
'{{workspace}}',
|
||||
to_jsonb('{workspace}'::text)
|
||||
)
|
||||
FROM connector_credential_pair AS ccp
|
||||
WHERE ccp.connector_id = c.id
|
||||
AND c.source = 'SLACK'
|
||||
AND ccp.credential_id = {credential_id}
|
||||
"""
|
||||
)
|
||||
except Exception:
|
||||
print(
|
||||
f"We were unable to get the workspace url for your Slack Connector with id {credential_id}."
|
||||
)
|
||||
print("This connector will no longer work until you upgrade.")
|
||||
continue
|
||||
@@ -0,0 +1,29 @@
|
||||
"""add recent assistants
|
||||
|
||||
Revision ID: c0fd6e4da83a
|
||||
Revises: b72ed7a5db0e
|
||||
Create Date: 2024-11-03 17:28:54.916618
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "c0fd6e4da83a"
|
||||
down_revision = "b72ed7a5db0e"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"recent_assistants", postgresql.JSONB(), server_default="[]", nullable=False
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("user", "recent_assistants")
|
||||
@@ -23,6 +23,56 @@ def upgrade() -> None:
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Delete chat messages and feedback first since they reference chat sessions
|
||||
# Get chat messages from sessions with null persona_id
|
||||
chat_messages_query = """
|
||||
SELECT id
|
||||
FROM chat_message
|
||||
WHERE chat_session_id IN (
|
||||
SELECT id
|
||||
FROM chat_session
|
||||
WHERE persona_id IS NULL
|
||||
)
|
||||
"""
|
||||
|
||||
# Delete dependent records first
|
||||
op.execute(
|
||||
f"""
|
||||
DELETE FROM document_retrieval_feedback
|
||||
WHERE chat_message_id IN (
|
||||
{chat_messages_query}
|
||||
)
|
||||
"""
|
||||
)
|
||||
op.execute(
|
||||
f"""
|
||||
DELETE FROM chat_message__search_doc
|
||||
WHERE chat_message_id IN (
|
||||
{chat_messages_query}
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
# Delete chat messages
|
||||
op.execute(
|
||||
"""
|
||||
DELETE FROM chat_message
|
||||
WHERE chat_session_id IN (
|
||||
SELECT id
|
||||
FROM chat_session
|
||||
WHERE persona_id IS NULL
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
# Now we can safely delete the chat sessions
|
||||
op.execute(
|
||||
"""
|
||||
DELETE FROM chat_session
|
||||
WHERE persona_id IS NULL
|
||||
"""
|
||||
)
|
||||
|
||||
op.alter_column(
|
||||
"chat_session",
|
||||
"persona_id",
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
"""Add composite index to document_by_connector_credential_pair
|
||||
|
||||
Revision ID: dab04867cd88
|
||||
Revises: 54a74a0417fc
|
||||
Create Date: 2024-12-13 22:43:20.119990
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "dab04867cd88"
|
||||
down_revision = "54a74a0417fc"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Composite index on (connector_id, credential_id)
|
||||
op.create_index(
|
||||
"idx_document_cc_pair_connector_credential",
|
||||
"document_by_connector_credential_pair",
|
||||
["connector_id", "credential_id"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index(
|
||||
"idx_document_cc_pair_connector_credential",
|
||||
table_name="document_by_connector_credential_pair",
|
||||
)
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Danswer Custom Tool Flow
|
||||
"""Onyx Custom Tool Flow
|
||||
|
||||
Revision ID: dba7f71618f5
|
||||
Revises: d5645c915d0e
|
||||
@@ -9,12 +9,12 @@ from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import table, column, String, Integer, Boolean
|
||||
|
||||
from danswer.db.search_settings import (
|
||||
from onyx.db.search_settings import (
|
||||
get_new_default_embedding_model,
|
||||
get_old_default_embedding_model,
|
||||
user_has_overridden_embedding_model,
|
||||
)
|
||||
from danswer.db.models import IndexModelStatus
|
||||
from onyx.db.models import IndexModelStatus
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "dbaa756c2ccf"
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
"""extended_role_for_non_web
|
||||
|
||||
Revision ID: dfbe9e93d3c7
|
||||
Revises: 9cf5c00f72fe
|
||||
Create Date: 2024-11-16 07:54:18.727906
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "dfbe9e93d3c7"
|
||||
down_revision = "9cf5c00f72fe"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE "user"
|
||||
SET role = 'EXT_PERM_USER'
|
||||
WHERE has_web_login = false
|
||||
"""
|
||||
)
|
||||
op.drop_column("user", "has_web_login")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column("has_web_login", sa.Boolean(), nullable=False, server_default="true"),
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE "user"
|
||||
SET has_web_login = false,
|
||||
role = 'BASIC'
|
||||
WHERE role IN ('SLACK_USER', 'EXT_PERM_USER')
|
||||
"""
|
||||
)
|
||||
@@ -8,7 +8,7 @@ Create Date: 2024-03-14 18:06:08.523106
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import DocumentSource
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "e50154680a5c"
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
"""non-nullbale slack bot id in channel config
|
||||
|
||||
Revision ID: f7a894b06d02
|
||||
Revises: 9f696734098f
|
||||
Create Date: 2024-12-06 12:55:42.845723
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "f7a894b06d02"
|
||||
down_revision = "9f696734098f"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Delete all rows with null slack_bot_id
|
||||
op.execute("DELETE FROM slack_channel_config WHERE slack_bot_id IS NULL")
|
||||
|
||||
# Make slack_bot_id non-nullable
|
||||
op.alter_column(
|
||||
"slack_channel_config",
|
||||
"slack_bot_id",
|
||||
existing_type=sa.Integer(),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Make slack_bot_id nullable again
|
||||
op.alter_column(
|
||||
"slack_channel_config",
|
||||
"slack_bot_id",
|
||||
existing_type=sa.Integer(),
|
||||
nullable=True,
|
||||
)
|
||||
@@ -1,3 +1,3 @@
|
||||
These files are for public table migrations when operating with multi tenancy.
|
||||
|
||||
If you are not a Danswer developer, you can ignore this directory entirely.
|
||||
If you are not a Onyx developer, you can ignore this directory entirely.
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
from logging.config import fileConfig
|
||||
from typing import Literal
|
||||
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.engine import Connection
|
||||
@@ -7,8 +8,8 @@ from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy.schema import SchemaItem
|
||||
|
||||
from alembic import context
|
||||
from danswer.db.engine import build_connection_string
|
||||
from danswer.db.models import PublicBase
|
||||
from onyx.db.engine import build_connection_string
|
||||
from onyx.db.models import PublicBase
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
@@ -37,8 +38,15 @@ EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
|
||||
|
||||
def include_object(
|
||||
object: SchemaItem,
|
||||
name: str,
|
||||
type_: str,
|
||||
name: str | None,
|
||||
type_: Literal[
|
||||
"schema",
|
||||
"table",
|
||||
"column",
|
||||
"index",
|
||||
"unique_constraint",
|
||||
"foreign_key_constraint",
|
||||
],
|
||||
reflected: bool,
|
||||
compare_to: SchemaItem | None,
|
||||
) -> bool:
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
import os
|
||||
|
||||
__version__ = os.environ.get("DANSWER_VERSION", "") or "Development"
|
||||
@@ -1,289 +0,0 @@
|
||||
import logging
|
||||
import multiprocessing
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import sentry_sdk
|
||||
from celery import Task
|
||||
from celery.app import trace
|
||||
from celery.exceptions import WorkerShutdown
|
||||
from celery.states import READY_STATES
|
||||
from celery.utils.log import get_task_logger
|
||||
from celery.worker import strategy # type: ignore
|
||||
from sentry_sdk.integrations.celery import CeleryIntegration
|
||||
|
||||
from danswer.background.celery.apps.task_formatters import CeleryTaskColoredFormatter
|
||||
from danswer.background.celery.apps.task_formatters import CeleryTaskPlainFormatter
|
||||
from danswer.background.celery.celery_utils import celery_is_worker_primary
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.redis.redis_connector import RedisConnector
|
||||
from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
|
||||
from danswer.redis.redis_connector_delete import RedisConnectorDelete
|
||||
from danswer.redis.redis_connector_prune import RedisConnectorPrune
|
||||
from danswer.redis.redis_document_set import RedisDocumentSet
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.redis.redis_usergroup import RedisUserGroup
|
||||
from danswer.utils.logger import ColoredFormatter
|
||||
from danswer.utils.logger import PlainFormatter
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import SENTRY_DSN
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
task_logger = get_task_logger(__name__)
|
||||
|
||||
if SENTRY_DSN:
|
||||
sentry_sdk.init(
|
||||
dsn=SENTRY_DSN,
|
||||
integrations=[CeleryIntegration()],
|
||||
traces_sample_rate=0.1,
|
||||
)
|
||||
logger.info("Sentry initialized")
|
||||
else:
|
||||
logger.debug("Sentry DSN not provided, skipping Sentry initialization")
|
||||
|
||||
|
||||
def on_task_prerun(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple | None = None,
|
||||
kwargs: dict | None = None,
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
|
||||
def on_task_postrun(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple | None = None,
|
||||
kwargs: dict[str, Any] | None = None,
|
||||
retval: Any | None = None,
|
||||
state: str | None = None,
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
"""We handle this signal in order to remove completed tasks
|
||||
from their respective tasksets. This allows us to track the progress of document set
|
||||
and user group syncs.
|
||||
|
||||
This function runs after any task completes (both success and failure)
|
||||
Note that this signal does not fire on a task that failed to complete and is going
|
||||
to be retried.
|
||||
|
||||
This also does not fire if a worker with acks_late=False crashes (which all of our
|
||||
long running workers are)
|
||||
"""
|
||||
if not task:
|
||||
return
|
||||
|
||||
task_logger.debug(f"Task {task.name} (ID: {task_id}) completed with state: {state}")
|
||||
|
||||
if state not in READY_STATES:
|
||||
return
|
||||
|
||||
if not task_id:
|
||||
return
|
||||
|
||||
# Get tenant_id directly from kwargs- each celery task has a tenant_id kwarg
|
||||
if not kwargs:
|
||||
logger.error(f"Task {task.name} (ID: {task_id}) is missing kwargs")
|
||||
tenant_id = None
|
||||
else:
|
||||
tenant_id = kwargs.get("tenant_id")
|
||||
|
||||
task_logger.debug(
|
||||
f"Task {task.name} (ID: {task_id}) completed with state: {state} "
|
||||
f"{f'for tenant_id={tenant_id}' if tenant_id else ''}"
|
||||
)
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
if task_id.startswith(RedisConnectorCredentialPair.PREFIX):
|
||||
r.srem(RedisConnectorCredentialPair.get_taskset_key(), task_id)
|
||||
return
|
||||
|
||||
if task_id.startswith(RedisDocumentSet.PREFIX):
|
||||
document_set_id = RedisDocumentSet.get_id_from_task_id(task_id)
|
||||
if document_set_id is not None:
|
||||
rds = RedisDocumentSet(tenant_id, int(document_set_id))
|
||||
r.srem(rds.taskset_key, task_id)
|
||||
return
|
||||
|
||||
if task_id.startswith(RedisUserGroup.PREFIX):
|
||||
usergroup_id = RedisUserGroup.get_id_from_task_id(task_id)
|
||||
if usergroup_id is not None:
|
||||
rug = RedisUserGroup(tenant_id, int(usergroup_id))
|
||||
r.srem(rug.taskset_key, task_id)
|
||||
return
|
||||
|
||||
if task_id.startswith(RedisConnectorDelete.PREFIX):
|
||||
cc_pair_id = RedisConnector.get_id_from_task_id(task_id)
|
||||
if cc_pair_id is not None:
|
||||
RedisConnectorDelete.remove_from_taskset(int(cc_pair_id), task_id, r)
|
||||
return
|
||||
|
||||
if task_id.startswith(RedisConnectorPrune.SUBTASK_PREFIX):
|
||||
cc_pair_id = RedisConnector.get_id_from_task_id(task_id)
|
||||
if cc_pair_id is not None:
|
||||
RedisConnectorPrune.remove_from_taskset(int(cc_pair_id), task_id, r)
|
||||
return
|
||||
|
||||
|
||||
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
|
||||
"""The first signal sent on celery worker startup"""
|
||||
multiprocessing.set_start_method("spawn") # fork is unsafe, set to spawn
|
||||
|
||||
|
||||
def wait_for_redis(sender: Any, **kwargs: Any) -> None:
|
||||
r = get_redis_client(tenant_id=None)
|
||||
|
||||
WAIT_INTERVAL = 5
|
||||
WAIT_LIMIT = 60
|
||||
|
||||
time_start = time.monotonic()
|
||||
logger.info("Redis: Readiness check starting.")
|
||||
while True:
|
||||
try:
|
||||
if r.ping():
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
logger.info(
|
||||
f"Redis: Ping failed. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
|
||||
)
|
||||
if time_elapsed > WAIT_LIMIT:
|
||||
msg = (
|
||||
f"Redis: Readiness check did not succeed within the timeout "
|
||||
f"({WAIT_LIMIT} seconds). Exiting..."
|
||||
)
|
||||
logger.error(msg)
|
||||
raise WorkerShutdown(msg)
|
||||
|
||||
time.sleep(WAIT_INTERVAL)
|
||||
|
||||
logger.info("Redis: Readiness check succeeded. Continuing...")
|
||||
return
|
||||
|
||||
|
||||
def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
logger.info("Running as a secondary celery worker.")
|
||||
|
||||
# Exit early if multi-tenant since primary worker check not needed
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
# Set up variables for waiting on primary worker
|
||||
WAIT_INTERVAL = 5
|
||||
WAIT_LIMIT = 60
|
||||
r = get_redis_client(tenant_id=None)
|
||||
time_start = time.monotonic()
|
||||
|
||||
logger.info("Waiting for primary worker to be ready...")
|
||||
while True:
|
||||
if r.exists(DanswerRedisLocks.PRIMARY_WORKER):
|
||||
break
|
||||
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
logger.info(
|
||||
f"Primary worker is not ready yet. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
|
||||
)
|
||||
if time_elapsed > WAIT_LIMIT:
|
||||
msg = (
|
||||
f"Primary worker was not ready within the timeout. "
|
||||
f"({WAIT_LIMIT} seconds). Exiting..."
|
||||
)
|
||||
logger.error(msg)
|
||||
raise WorkerShutdown(msg)
|
||||
|
||||
time.sleep(WAIT_INTERVAL)
|
||||
|
||||
logger.info("Wait for primary worker completed successfully. Continuing...")
|
||||
return
|
||||
|
||||
|
||||
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
|
||||
task_logger.info("worker_ready signal received.")
|
||||
|
||||
|
||||
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
|
||||
if not celery_is_worker_primary(sender):
|
||||
return
|
||||
|
||||
if not sender.primary_worker_lock:
|
||||
return
|
||||
|
||||
logger.info("Releasing primary worker lock.")
|
||||
lock = sender.primary_worker_lock
|
||||
try:
|
||||
if lock.owned():
|
||||
try:
|
||||
lock.release()
|
||||
sender.primary_worker_lock = None
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to release primary worker lock: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check if primary worker lock is owned: {e}")
|
||||
|
||||
|
||||
def on_setup_logging(
|
||||
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
|
||||
) -> None:
|
||||
# TODO: could unhardcode format and colorize and accept these as options from
|
||||
# celery's config
|
||||
|
||||
# reformats the root logger
|
||||
root_logger = logging.getLogger()
|
||||
|
||||
root_handler = logging.StreamHandler() # Set up a handler for the root logger
|
||||
root_formatter = ColoredFormatter(
|
||||
"%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
|
||||
datefmt="%m/%d/%Y %I:%M:%S %p",
|
||||
)
|
||||
root_handler.setFormatter(root_formatter)
|
||||
root_logger.addHandler(root_handler) # Apply the handler to the root logger
|
||||
|
||||
if logfile:
|
||||
root_file_handler = logging.FileHandler(logfile)
|
||||
root_file_formatter = PlainFormatter(
|
||||
"%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
|
||||
datefmt="%m/%d/%Y %I:%M:%S %p",
|
||||
)
|
||||
root_file_handler.setFormatter(root_file_formatter)
|
||||
root_logger.addHandler(root_file_handler)
|
||||
|
||||
root_logger.setLevel(loglevel)
|
||||
|
||||
# reformats celery's task logger
|
||||
task_formatter = CeleryTaskColoredFormatter(
|
||||
"%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
|
||||
datefmt="%m/%d/%Y %I:%M:%S %p",
|
||||
)
|
||||
task_handler = logging.StreamHandler() # Set up a handler for the task logger
|
||||
task_handler.setFormatter(task_formatter)
|
||||
task_logger.addHandler(task_handler) # Apply the handler to the task logger
|
||||
|
||||
if logfile:
|
||||
task_file_handler = logging.FileHandler(logfile)
|
||||
task_file_formatter = CeleryTaskPlainFormatter(
|
||||
"%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
|
||||
datefmt="%m/%d/%Y %I:%M:%S %p",
|
||||
)
|
||||
task_file_handler.setFormatter(task_file_formatter)
|
||||
task_logger.addHandler(task_file_handler)
|
||||
|
||||
task_logger.setLevel(loglevel)
|
||||
task_logger.propagate = False
|
||||
|
||||
# hide celery task received spam
|
||||
# e.g. "Task check_for_pruning[a1e96171-0ba8-4e00-887b-9fbf7442eab3] received"
|
||||
strategy.logger.setLevel(logging.WARNING)
|
||||
|
||||
# hide celery task succeeded/failed spam
|
||||
# e.g. "Task check_for_pruning[a1e96171-0ba8-4e00-887b-9fbf7442eab3] succeeded in 0.03137450001668185s: None"
|
||||
trace.logger.setLevel(logging.WARNING)
|
||||
@@ -1,96 +0,0 @@
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
|
||||
from celery.beat import PersistentScheduler # type: ignore
|
||||
from celery.utils.log import get_task_logger
|
||||
|
||||
from danswer.db.engine import get_all_tenant_ids
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
|
||||
logger = get_task_logger(__name__)
|
||||
|
||||
|
||||
class DynamicTenantScheduler(PersistentScheduler):
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self._reload_interval = timedelta(minutes=1)
|
||||
self._last_reload = self.app.now() - self._reload_interval
|
||||
|
||||
def setup_schedule(self) -> None:
|
||||
super().setup_schedule()
|
||||
|
||||
def tick(self) -> float:
|
||||
retval = super().tick()
|
||||
now = self.app.now()
|
||||
if (
|
||||
self._last_reload is None
|
||||
or (now - self._last_reload) > self._reload_interval
|
||||
):
|
||||
logger.info("Reloading schedule to check for new tenants...")
|
||||
self._update_tenant_tasks()
|
||||
self._last_reload = now
|
||||
return retval
|
||||
|
||||
def _update_tenant_tasks(self) -> None:
|
||||
logger.info("Checking for tenant task updates...")
|
||||
try:
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
tasks_to_schedule = fetch_versioned_implementation(
|
||||
"danswer.background.celery.tasks.beat_schedule", "get_tasks_to_schedule"
|
||||
)
|
||||
|
||||
new_beat_schedule: dict[str, dict[str, Any]] = {}
|
||||
|
||||
current_schedule = getattr(self, "_store", {"entries": {}}).get(
|
||||
"entries", {}
|
||||
)
|
||||
|
||||
existing_tenants = set()
|
||||
for task_name in current_schedule.keys():
|
||||
if "-" in task_name:
|
||||
existing_tenants.add(task_name.split("-")[-1])
|
||||
|
||||
for tenant_id in tenant_ids:
|
||||
if tenant_id not in existing_tenants:
|
||||
logger.info(f"Found new tenant: {tenant_id}")
|
||||
|
||||
for task in tasks_to_schedule():
|
||||
task_name = f"{task['name']}-{tenant_id}"
|
||||
new_task = {
|
||||
"task": task["task"],
|
||||
"schedule": task["schedule"],
|
||||
"kwargs": {"tenant_id": tenant_id},
|
||||
}
|
||||
if options := task.get("options"):
|
||||
new_task["options"] = options
|
||||
new_beat_schedule[task_name] = new_task
|
||||
|
||||
if self._should_update_schedule(current_schedule, new_beat_schedule):
|
||||
logger.info(
|
||||
"Updating schedule",
|
||||
extra={
|
||||
"new_tasks": len(new_beat_schedule),
|
||||
"current_tasks": len(current_schedule),
|
||||
},
|
||||
)
|
||||
if not hasattr(self, "_store"):
|
||||
self._store: dict[str, dict] = {"entries": {}}
|
||||
self.update_from_dict(new_beat_schedule)
|
||||
logger.info(f"New schedule: {new_beat_schedule}")
|
||||
|
||||
logger.info("Tenant tasks updated successfully")
|
||||
else:
|
||||
logger.debug("No schedule updates needed")
|
||||
|
||||
except (AttributeError, KeyError):
|
||||
logger.exception("Failed to process task configuration")
|
||||
except Exception:
|
||||
logger.exception("Unexpected error updating tenant tasks")
|
||||
|
||||
def _should_update_schedule(
|
||||
self, current_schedule: dict, new_schedule: dict
|
||||
) -> bool:
|
||||
"""Compare schedules to determine if an update is needed."""
|
||||
current_tasks = set(current_schedule.keys())
|
||||
new_tasks = set(new_schedule.keys())
|
||||
return current_tasks != new_tasks
|
||||
@@ -1,25 +0,0 @@
|
||||
# These are helper objects for tracking the keys we need to write in redis
|
||||
from typing import cast
|
||||
|
||||
from redis import Redis
|
||||
|
||||
from danswer.background.celery.configs.base import CELERY_SEPARATOR
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
|
||||
|
||||
def celery_get_queue_length(queue: str, r: Redis) -> int:
|
||||
"""This is a redis specific way to get the length of a celery queue.
|
||||
It is priority aware and knows how to count across the multiple redis lists
|
||||
used to implement task prioritization.
|
||||
This operation is not atomic."""
|
||||
total_length = 0
|
||||
for i in range(len(DanswerCeleryPriority)):
|
||||
queue_name = queue
|
||||
if i > 0:
|
||||
queue_name += CELERY_SEPARATOR
|
||||
queue_name += str(i)
|
||||
|
||||
length = r.llen(queue_name)
|
||||
total_length += cast(int, length)
|
||||
|
||||
return total_length
|
||||
@@ -1,48 +0,0 @@
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
|
||||
|
||||
tasks_to_schedule = [
|
||||
{
|
||||
"name": "check-for-vespa-sync",
|
||||
"task": "check_for_vespa_sync_task",
|
||||
"schedule": timedelta(seconds=5),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-connector-deletion",
|
||||
"task": "check_for_connector_deletion_task",
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-indexing",
|
||||
"task": "check_for_indexing",
|
||||
"schedule": timedelta(seconds=10),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-prune",
|
||||
"task": "check_for_pruning",
|
||||
"schedule": timedelta(seconds=10),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "kombu-message-cleanup",
|
||||
"task": "kombu_message_cleanup_task",
|
||||
"schedule": timedelta(seconds=3600),
|
||||
"options": {"priority": DanswerCeleryPriority.LOWEST},
|
||||
},
|
||||
{
|
||||
"name": "monitor-vespa-sync",
|
||||
"task": "monitor_vespa_sync",
|
||||
"schedule": timedelta(seconds=5),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def get_tasks_to_schedule() -> list[dict[str, Any]]:
|
||||
return tasks_to_schedule
|
||||
@@ -1,628 +0,0 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from http import HTTPStatus
|
||||
from time import sleep
|
||||
|
||||
import redis
|
||||
from celery import Celery
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from redis import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
from danswer.background.indexing.job_client import SimpleJobClient
|
||||
from danswer.background.indexing.run_indexing import run_indexing_entrypoint
|
||||
from danswer.background.indexing.run_indexing import RunIndexingCallbackInterface
|
||||
from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
|
||||
from danswer.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.db.connector_credential_pair import fetch_connector_credential_pairs
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.engine import get_db_current_time
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.enums import IndexingStatus
|
||||
from danswer.db.enums import IndexModelStatus
|
||||
from danswer.db.index_attempt import create_index_attempt
|
||||
from danswer.db.index_attempt import get_index_attempt
|
||||
from danswer.db.index_attempt import get_last_attempt_for_cc_pair
|
||||
from danswer.db.index_attempt import mark_attempt_failed
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.models import IndexAttempt
|
||||
from danswer.db.models import SearchSettings
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.db.search_settings import get_secondary_search_settings
|
||||
from danswer.db.swap_index import check_index_swap
|
||||
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
|
||||
from danswer.redis.redis_connector import RedisConnector
|
||||
from danswer.redis.redis_connector_index import RedisConnectorIndexingFenceData
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import global_version
|
||||
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
|
||||
from shared_configs.configs import INDEXING_MODEL_SERVER_PORT
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class RunIndexingCallback(RunIndexingCallbackInterface):
|
||||
def __init__(
|
||||
self,
|
||||
stop_key: str,
|
||||
generator_progress_key: str,
|
||||
redis_lock: redis.lock.Lock,
|
||||
redis_client: Redis,
|
||||
):
|
||||
super().__init__()
|
||||
self.redis_lock: redis.lock.Lock = redis_lock
|
||||
self.stop_key: str = stop_key
|
||||
self.generator_progress_key: str = generator_progress_key
|
||||
self.redis_client = redis_client
|
||||
|
||||
def should_stop(self) -> bool:
|
||||
if self.redis_client.exists(self.stop_key):
|
||||
return True
|
||||
return False
|
||||
|
||||
def progress(self, amount: int) -> None:
|
||||
self.redis_lock.reacquire()
|
||||
self.redis_client.incrby(self.generator_progress_key, amount)
|
||||
|
||||
|
||||
@shared_task(
|
||||
name="check_for_indexing",
|
||||
soft_time_limit=300,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
tasks_created = 0
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat = r.lock(
|
||||
DanswerRedisLocks.CHECK_INDEXING_BEAT_LOCK,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
try:
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return None
|
||||
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
old_search_settings = check_index_swap(db_session=db_session)
|
||||
current_search_settings = get_current_search_settings(db_session)
|
||||
# So that the first time users aren't surprised by really slow speed of first
|
||||
# batch of documents indexed
|
||||
if current_search_settings.provider_type is None and not MULTI_TENANT:
|
||||
if old_search_settings:
|
||||
embedding_model = EmbeddingModel.from_db_model(
|
||||
search_settings=current_search_settings,
|
||||
server_host=INDEXING_MODEL_SERVER_HOST,
|
||||
server_port=INDEXING_MODEL_SERVER_PORT,
|
||||
)
|
||||
|
||||
# only warm up if search settings were changed
|
||||
warm_up_bi_encoder(
|
||||
embedding_model=embedding_model,
|
||||
)
|
||||
|
||||
cc_pair_ids: list[int] = []
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pairs = fetch_connector_credential_pairs(db_session)
|
||||
for cc_pair_entry in cc_pairs:
|
||||
cc_pair_ids.append(cc_pair_entry.id)
|
||||
|
||||
for cc_pair_id in cc_pair_ids:
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
# Get the primary search settings
|
||||
primary_search_settings = get_current_search_settings(db_session)
|
||||
search_settings = [primary_search_settings]
|
||||
|
||||
# Check for secondary search settings
|
||||
secondary_search_settings = get_secondary_search_settings(db_session)
|
||||
if secondary_search_settings is not None:
|
||||
# If secondary settings exist, add them to the list
|
||||
search_settings.append(secondary_search_settings)
|
||||
|
||||
for search_settings_instance in search_settings:
|
||||
redis_connector_index = redis_connector.new_index(
|
||||
search_settings_instance.id
|
||||
)
|
||||
if redis_connector_index.fenced:
|
||||
continue
|
||||
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
cc_pair_id, db_session
|
||||
)
|
||||
if not cc_pair:
|
||||
continue
|
||||
|
||||
last_attempt = get_last_attempt_for_cc_pair(
|
||||
cc_pair.id, search_settings_instance.id, db_session
|
||||
)
|
||||
if not _should_index(
|
||||
cc_pair=cc_pair,
|
||||
last_index=last_attempt,
|
||||
search_settings_instance=search_settings_instance,
|
||||
secondary_index_building=len(search_settings) > 1,
|
||||
db_session=db_session,
|
||||
):
|
||||
continue
|
||||
|
||||
# using a task queue and only allowing one task per cc_pair/search_setting
|
||||
# prevents us from starving out certain attempts
|
||||
attempt_id = try_creating_indexing_task(
|
||||
self.app,
|
||||
cc_pair,
|
||||
search_settings_instance,
|
||||
False,
|
||||
db_session,
|
||||
r,
|
||||
tenant_id,
|
||||
)
|
||||
if attempt_id:
|
||||
task_logger.info(
|
||||
f"Indexing queued: cc_pair={cc_pair.id} index_attempt={attempt_id}"
|
||||
)
|
||||
tasks_created += 1
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
|
||||
finally:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
|
||||
return tasks_created
|
||||
|
||||
|
||||
def _should_index(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
last_index: IndexAttempt | None,
|
||||
search_settings_instance: SearchSettings,
|
||||
secondary_index_building: bool,
|
||||
db_session: Session,
|
||||
) -> bool:
|
||||
"""Checks various global settings and past indexing attempts to determine if
|
||||
we should try to start indexing the cc pair / search setting combination.
|
||||
|
||||
Note that tactical checks such as preventing overlap with a currently running task
|
||||
are not handled here.
|
||||
|
||||
Return True if we should try to index, False if not.
|
||||
"""
|
||||
connector = cc_pair.connector
|
||||
|
||||
# uncomment for debugging
|
||||
# task_logger.info(f"_should_index: "
|
||||
# f"cc_pair={cc_pair.id} "
|
||||
# f"connector={cc_pair.connector_id} "
|
||||
# f"refresh_freq={connector.refresh_freq}")
|
||||
|
||||
# don't kick off indexing for `NOT_APPLICABLE` sources
|
||||
if connector.source == DocumentSource.NOT_APPLICABLE:
|
||||
return False
|
||||
|
||||
# User can still manually create single indexing attempts via the UI for the
|
||||
# currently in use index
|
||||
if DISABLE_INDEX_UPDATE_ON_SWAP:
|
||||
if (
|
||||
search_settings_instance.status == IndexModelStatus.PRESENT
|
||||
and secondary_index_building
|
||||
):
|
||||
return False
|
||||
|
||||
# When switching over models, always index at least once
|
||||
if search_settings_instance.status == IndexModelStatus.FUTURE:
|
||||
if last_index:
|
||||
# No new index if the last index attempt succeeded
|
||||
# Once is enough. The model will never be able to swap otherwise.
|
||||
if last_index.status == IndexingStatus.SUCCESS:
|
||||
return False
|
||||
|
||||
# No new index if the last index attempt is waiting to start
|
||||
if last_index.status == IndexingStatus.NOT_STARTED:
|
||||
return False
|
||||
|
||||
# No new index if the last index attempt is running
|
||||
if last_index.status == IndexingStatus.IN_PROGRESS:
|
||||
return False
|
||||
else:
|
||||
if (
|
||||
connector.id == 0 or connector.source == DocumentSource.INGESTION_API
|
||||
): # Ingestion API
|
||||
return False
|
||||
return True
|
||||
|
||||
# If the connector is paused or is the ingestion API, don't index
|
||||
# NOTE: during an embedding model switch over, the following logic
|
||||
# is bypassed by the above check for a future model
|
||||
if (
|
||||
not cc_pair.status.is_active()
|
||||
or connector.id == 0
|
||||
or connector.source == DocumentSource.INGESTION_API
|
||||
):
|
||||
return False
|
||||
|
||||
# if no attempt has ever occurred, we should index regardless of refresh_freq
|
||||
if not last_index:
|
||||
return True
|
||||
|
||||
if connector.refresh_freq is None:
|
||||
return False
|
||||
|
||||
current_db_time = get_db_current_time(db_session)
|
||||
time_since_index = current_db_time - last_index.time_updated
|
||||
if time_since_index.total_seconds() < connector.refresh_freq:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def try_creating_indexing_task(
|
||||
celery_app: Celery,
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
search_settings: SearchSettings,
|
||||
reindex: bool,
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
"""Checks for any conditions that should block the indexing task from being
|
||||
created, then creates the task.
|
||||
|
||||
Does not check for scheduling related conditions as this function
|
||||
is used to trigger indexing immediately.
|
||||
"""
|
||||
|
||||
LOCK_TIMEOUT = 30
|
||||
|
||||
# we need to serialize any attempt to trigger indexing since it can be triggered
|
||||
# either via celery beat or manually (API call)
|
||||
lock = r.lock(
|
||||
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_creating_indexing_task",
|
||||
timeout=LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
acquired = lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2)
|
||||
if not acquired:
|
||||
return None
|
||||
|
||||
try:
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair.id)
|
||||
redis_connector_index = redis_connector.new_index(search_settings.id)
|
||||
|
||||
# skip if already indexing
|
||||
if redis_connector_index.fenced:
|
||||
return None
|
||||
|
||||
# skip indexing if the cc_pair is deleting
|
||||
if redis_connector.delete.fenced:
|
||||
return None
|
||||
|
||||
db_session.refresh(cc_pair)
|
||||
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
|
||||
return None
|
||||
|
||||
# add a long running generator task to the queue
|
||||
redis_connector_index.generator_clear()
|
||||
|
||||
# set a basic fence to start
|
||||
payload = RedisConnectorIndexingFenceData(
|
||||
index_attempt_id=None,
|
||||
started=None,
|
||||
submitted=datetime.now(timezone.utc),
|
||||
celery_task_id=None,
|
||||
)
|
||||
|
||||
redis_connector_index.set_fence(payload)
|
||||
|
||||
# create the index attempt for tracking purposes
|
||||
# code elsewhere checks for index attempts without an associated redis key
|
||||
# and cleans them up
|
||||
# therefore we must create the attempt and the task after the fence goes up
|
||||
index_attempt_id = create_index_attempt(
|
||||
cc_pair.id,
|
||||
search_settings.id,
|
||||
from_beginning=reindex,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
custom_task_id = redis_connector_index.generate_generator_task_id()
|
||||
|
||||
result = celery_app.send_task(
|
||||
"connector_indexing_proxy_task",
|
||||
kwargs=dict(
|
||||
index_attempt_id=index_attempt_id,
|
||||
cc_pair_id=cc_pair.id,
|
||||
search_settings_id=search_settings.id,
|
||||
tenant_id=tenant_id,
|
||||
),
|
||||
queue=DanswerCeleryQueues.CONNECTOR_INDEXING,
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.MEDIUM,
|
||||
)
|
||||
if not result:
|
||||
raise RuntimeError("send_task for connector_indexing_proxy_task failed.")
|
||||
|
||||
# now fill out the fence with the rest of the data
|
||||
payload.index_attempt_id = index_attempt_id
|
||||
payload.celery_task_id = result.id
|
||||
redis_connector_index.set_fence(payload)
|
||||
|
||||
except Exception:
|
||||
redis_connector_index.set_fence(payload)
|
||||
task_logger.exception(
|
||||
f"Unexpected exception: "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair.id} "
|
||||
f"search_settings={search_settings.id}"
|
||||
)
|
||||
return None
|
||||
finally:
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
|
||||
return index_attempt_id
|
||||
|
||||
|
||||
@shared_task(name="connector_indexing_proxy_task", acks_late=False, track_started=True)
|
||||
def connector_indexing_proxy_task(
|
||||
index_attempt_id: int,
|
||||
cc_pair_id: int,
|
||||
search_settings_id: int,
|
||||
tenant_id: str | None,
|
||||
) -> None:
|
||||
"""celery tasks are forked, but forking is unstable. This proxies work to a spawned task."""
|
||||
task_logger.info(
|
||||
f"Indexing proxy - starting: attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
client = SimpleJobClient()
|
||||
|
||||
job = client.submit(
|
||||
connector_indexing_task,
|
||||
index_attempt_id,
|
||||
cc_pair_id,
|
||||
search_settings_id,
|
||||
tenant_id,
|
||||
global_version.is_ee_version(),
|
||||
pure=False,
|
||||
)
|
||||
|
||||
if not job:
|
||||
task_logger.info(
|
||||
f"Indexing proxy - spawn failed: attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
return
|
||||
|
||||
task_logger.info(
|
||||
f"Indexing proxy - spawn succeeded: attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
while True:
|
||||
sleep(10)
|
||||
|
||||
# do nothing for ongoing jobs that haven't been stopped
|
||||
if not job.done():
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
index_attempt = get_index_attempt(
|
||||
db_session=db_session, index_attempt_id=index_attempt_id
|
||||
)
|
||||
|
||||
if not index_attempt:
|
||||
continue
|
||||
|
||||
if not index_attempt.is_finished():
|
||||
continue
|
||||
|
||||
if job.status == "error":
|
||||
task_logger.error(
|
||||
f"Indexing proxy - spawned task exceptioned: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"error={job.exception()}"
|
||||
)
|
||||
|
||||
job.release()
|
||||
break
|
||||
|
||||
task_logger.info(
|
||||
f"Indexing proxy - finished: attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
def connector_indexing_task(
|
||||
index_attempt_id: int,
|
||||
cc_pair_id: int,
|
||||
search_settings_id: int,
|
||||
tenant_id: str | None,
|
||||
is_ee: bool,
|
||||
) -> int | None:
|
||||
"""Indexing task. For a cc pair, this task pulls all document IDs from the source
|
||||
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
|
||||
from the most recently pulled document ID list
|
||||
|
||||
acks_late must be set to False. Otherwise, celery's visibility timeout will
|
||||
cause any task that runs longer than the timeout to be redispatched by the broker.
|
||||
There appears to be no good workaround for this, so we need to handle redispatching
|
||||
manually.
|
||||
|
||||
Returns None if the task did not run (possibly due to a conflict).
|
||||
Otherwise, returns an int >= 0 representing the number of indexed docs.
|
||||
|
||||
NOTE: if an exception is raised out of this task, the primary worker will detect
|
||||
that the task transitioned to a "READY" state but the generator_complete_key doesn't exist.
|
||||
This will cause the primary worker to abort the indexing attempt and clean up.
|
||||
"""
|
||||
logger.info(
|
||||
f"Indexing spawned task starting: attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
attempt = None
|
||||
n_final_progress: int | None = None
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
redis_connector_index = redis_connector.new_index(search_settings_id)
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
if redis_connector.delete.fenced:
|
||||
raise RuntimeError(
|
||||
f"Indexing will not start because connector deletion is in progress: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"fence={redis_connector.delete.fence_key}"
|
||||
)
|
||||
|
||||
if redis_connector.stop.fenced:
|
||||
raise RuntimeError(
|
||||
f"Indexing will not start because a connector stop signal was detected: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"fence={redis_connector.stop.fence_key}"
|
||||
)
|
||||
|
||||
while True:
|
||||
# wait for the fence to come up
|
||||
if not redis_connector_index.fenced:
|
||||
raise ValueError(
|
||||
f"connector_indexing_task - fence not found: fence={redis_connector_index.fence_key}"
|
||||
)
|
||||
|
||||
payload = redis_connector_index.payload
|
||||
if not payload:
|
||||
raise ValueError("connector_indexing_task: payload invalid or not found")
|
||||
|
||||
if payload.index_attempt_id is None or payload.celery_task_id is None:
|
||||
logger.info(
|
||||
f"connector_indexing_task - Waiting for fence: fence={redis_connector_index.fence_key}"
|
||||
)
|
||||
sleep(1)
|
||||
continue
|
||||
|
||||
logger.info(
|
||||
f"connector_indexing_task - Fence found, continuing...: fence={redis_connector_index.fence_key}"
|
||||
)
|
||||
break
|
||||
|
||||
lock = r.lock(
|
||||
redis_connector_index.generator_lock_key,
|
||||
timeout=CELERY_INDEXING_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
acquired = lock.acquire(blocking=False)
|
||||
if not acquired:
|
||||
logger.warning(
|
||||
f"Indexing task already running, exiting...: "
|
||||
f"cc_pair={cc_pair_id} search_settings={search_settings_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
payload.started = datetime.now(timezone.utc)
|
||||
redis_connector_index.set_fence(payload)
|
||||
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
attempt = get_index_attempt(db_session, index_attempt_id)
|
||||
if not attempt:
|
||||
raise ValueError(
|
||||
f"Index attempt not found: index_attempt={index_attempt_id}"
|
||||
)
|
||||
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
cc_pair_id=cc_pair_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
if not cc_pair:
|
||||
raise ValueError(f"cc_pair not found: cc_pair={cc_pair_id}")
|
||||
|
||||
if not cc_pair.connector:
|
||||
raise ValueError(
|
||||
f"Connector not found: cc_pair={cc_pair_id} connector={cc_pair.connector_id}"
|
||||
)
|
||||
|
||||
if not cc_pair.credential:
|
||||
raise ValueError(
|
||||
f"Credential not found: cc_pair={cc_pair_id} credential={cc_pair.credential_id}"
|
||||
)
|
||||
|
||||
# define a callback class
|
||||
callback = RunIndexingCallback(
|
||||
redis_connector.stop.fence_key,
|
||||
redis_connector_index.generator_progress_key,
|
||||
lock,
|
||||
r,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Indexing spawned task running entrypoint: attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
run_indexing_entrypoint(
|
||||
index_attempt_id,
|
||||
tenant_id,
|
||||
cc_pair_id,
|
||||
is_ee,
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
# get back the total number of indexed docs and return it
|
||||
n_final_progress = redis_connector_index.get_progress()
|
||||
redis_connector_index.set_generator_complete(HTTPStatus.OK.value)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Indexing spawned task failed: attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
if attempt:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
mark_attempt_failed(attempt, db_session, failure_reason=str(e))
|
||||
|
||||
redis_connector_index.reset()
|
||||
raise e
|
||||
finally:
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
|
||||
logger.info(
|
||||
f"Indexing spawned task finished: attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
return n_final_progress
|
||||
@@ -1,6 +0,0 @@
|
||||
"""Factory stub for running celery worker / celery beat."""
|
||||
from danswer.background.celery.apps.beat import celery_app
|
||||
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
|
||||
set_is_ee_based_on_env_variable()
|
||||
app = celery_app
|
||||
@@ -1,8 +0,0 @@
|
||||
"""Factory stub for running celery worker / celery beat."""
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
|
||||
set_is_ee_based_on_env_variable()
|
||||
app = fetch_versioned_implementation(
|
||||
"danswer.background.celery.apps.primary", "celery_app"
|
||||
)
|
||||
@@ -1,24 +0,0 @@
|
||||
input_prompts:
|
||||
- id: -5
|
||||
prompt: "Elaborate"
|
||||
content: "Elaborate on the above, give me a more in depth explanation."
|
||||
active: true
|
||||
is_public: true
|
||||
|
||||
- id: -4
|
||||
prompt: "Reword"
|
||||
content: "Help me rewrite the following politely and concisely for professional communication:\n"
|
||||
active: true
|
||||
is_public: true
|
||||
|
||||
- id: -3
|
||||
prompt: "Email"
|
||||
content: "Write a professional email for me including a subject line, signature, etc. Template the parts that need editing with [ ]. The email should cover the following points:\n"
|
||||
active: true
|
||||
is_public: true
|
||||
|
||||
- id: -2
|
||||
prompt: "Debug"
|
||||
content: "Provide step-by-step troubleshooting instructions for the following issue:\n"
|
||||
active: true
|
||||
is_public: true
|
||||
@@ -1,185 +0,0 @@
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.search.enums import QueryFlow
|
||||
from danswer.search.enums import SearchType
|
||||
from danswer.search.models import RetrievalDocs
|
||||
from danswer.search.models import SearchResponse
|
||||
from danswer.tools.tool_implementations.custom.base_tool_types import ToolResultType
|
||||
|
||||
|
||||
class LlmDoc(BaseModel):
|
||||
"""This contains the minimal set information for the LLM portion including citations"""
|
||||
|
||||
document_id: str
|
||||
content: str
|
||||
blurb: str
|
||||
semantic_identifier: str
|
||||
source_type: DocumentSource
|
||||
metadata: dict[str, str | list[str]]
|
||||
updated_at: datetime | None
|
||||
link: str | None
|
||||
source_links: dict[int, str] | None
|
||||
|
||||
|
||||
# First chunk of info for streaming QA
|
||||
class QADocsResponse(RetrievalDocs):
|
||||
rephrased_query: str | None = None
|
||||
predicted_flow: QueryFlow | None
|
||||
predicted_search: SearchType | None
|
||||
applied_source_filters: list[DocumentSource] | None
|
||||
applied_time_cutoff: datetime | None
|
||||
recency_bias_multiplier: float
|
||||
|
||||
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
|
||||
initial_dict = super().model_dump(mode="json", *args, **kwargs) # type: ignore
|
||||
initial_dict["applied_time_cutoff"] = (
|
||||
self.applied_time_cutoff.isoformat() if self.applied_time_cutoff else None
|
||||
)
|
||||
|
||||
return initial_dict
|
||||
|
||||
|
||||
class StreamStopReason(Enum):
|
||||
CONTEXT_LENGTH = "context_length"
|
||||
CANCELLED = "cancelled"
|
||||
|
||||
|
||||
class StreamStopInfo(BaseModel):
|
||||
stop_reason: StreamStopReason
|
||||
|
||||
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
|
||||
data = super().model_dump(mode="json", *args, **kwargs) # type: ignore
|
||||
data["stop_reason"] = self.stop_reason.name
|
||||
return data
|
||||
|
||||
|
||||
class LLMRelevanceFilterResponse(BaseModel):
|
||||
llm_selected_doc_indices: list[int]
|
||||
|
||||
|
||||
class FinalUsedContextDocsResponse(BaseModel):
|
||||
final_context_docs: list[LlmDoc]
|
||||
|
||||
|
||||
class RelevanceAnalysis(BaseModel):
|
||||
relevant: bool
|
||||
content: str | None = None
|
||||
|
||||
|
||||
class SectionRelevancePiece(RelevanceAnalysis):
|
||||
"""LLM analysis mapped to an Inference Section"""
|
||||
|
||||
document_id: str
|
||||
chunk_id: int # ID of the center chunk for a given inference section
|
||||
|
||||
|
||||
class DocumentRelevance(BaseModel):
|
||||
"""Contains all relevance information for a given search"""
|
||||
|
||||
relevance_summaries: dict[str, RelevanceAnalysis]
|
||||
|
||||
|
||||
class DanswerAnswerPiece(BaseModel):
|
||||
# A small piece of a complete answer. Used for streaming back answers.
|
||||
answer_piece: str | None # if None, specifies the end of an Answer
|
||||
|
||||
|
||||
# An intermediate representation of citations, later translated into
|
||||
# a mapping of the citation [n] number to SearchDoc
|
||||
class CitationInfo(BaseModel):
|
||||
citation_num: int
|
||||
document_id: str
|
||||
|
||||
|
||||
class AllCitations(BaseModel):
|
||||
citations: list[CitationInfo]
|
||||
|
||||
|
||||
# This is a mapping of the citation number to the document index within
|
||||
# the result search doc set
|
||||
class MessageSpecificCitations(BaseModel):
|
||||
citation_map: dict[int, int]
|
||||
|
||||
|
||||
class MessageResponseIDInfo(BaseModel):
|
||||
user_message_id: int | None
|
||||
reserved_assistant_message_id: int
|
||||
|
||||
|
||||
class StreamingError(BaseModel):
|
||||
error: str
|
||||
stack_trace: str | None = None
|
||||
|
||||
|
||||
class DanswerQuote(BaseModel):
|
||||
# This is during inference so everything is a string by this point
|
||||
quote: str
|
||||
document_id: str
|
||||
link: str | None
|
||||
source_type: str
|
||||
semantic_identifier: str
|
||||
blurb: str
|
||||
|
||||
|
||||
class DanswerQuotes(BaseModel):
|
||||
quotes: list[DanswerQuote]
|
||||
|
||||
|
||||
class DanswerContext(BaseModel):
|
||||
content: str
|
||||
document_id: str
|
||||
semantic_identifier: str
|
||||
blurb: str
|
||||
|
||||
|
||||
class DanswerContexts(BaseModel):
|
||||
contexts: list[DanswerContext]
|
||||
|
||||
|
||||
class DanswerAnswer(BaseModel):
|
||||
answer: str | None
|
||||
|
||||
|
||||
class QAResponse(SearchResponse, DanswerAnswer):
|
||||
quotes: list[DanswerQuote] | None
|
||||
contexts: list[DanswerContexts] | None
|
||||
predicted_flow: QueryFlow
|
||||
predicted_search: SearchType
|
||||
eval_res_valid: bool | None = None
|
||||
llm_selected_doc_indices: list[int] | None = None
|
||||
error_msg: str | None = None
|
||||
|
||||
|
||||
class FileChatDisplay(BaseModel):
|
||||
file_ids: list[str]
|
||||
|
||||
|
||||
class CustomToolResponse(BaseModel):
|
||||
response: ToolResultType
|
||||
tool_name: str
|
||||
|
||||
|
||||
AnswerQuestionPossibleReturn = (
|
||||
DanswerAnswerPiece
|
||||
| DanswerQuotes
|
||||
| CitationInfo
|
||||
| DanswerContexts
|
||||
| FileChatDisplay
|
||||
| CustomToolResponse
|
||||
| StreamingError
|
||||
| StreamStopInfo
|
||||
)
|
||||
|
||||
|
||||
AnswerQuestionStreamReturn = Iterator[AnswerQuestionPossibleReturn]
|
||||
|
||||
|
||||
class LLMMetricsContainer(BaseModel):
|
||||
prompt_tokens: int
|
||||
response_tokens: int
|
||||
@@ -1,115 +0,0 @@
|
||||
from typing_extensions import TypedDict # noreorder
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from danswer.prompts.chat_tools import DANSWER_TOOL_DESCRIPTION
|
||||
from danswer.prompts.chat_tools import DANSWER_TOOL_NAME
|
||||
from danswer.prompts.chat_tools import TOOL_FOLLOWUP
|
||||
from danswer.prompts.chat_tools import TOOL_LESS_FOLLOWUP
|
||||
from danswer.prompts.chat_tools import TOOL_LESS_PROMPT
|
||||
from danswer.prompts.chat_tools import TOOL_TEMPLATE
|
||||
from danswer.prompts.chat_tools import USER_INPUT
|
||||
|
||||
|
||||
class ToolInfo(TypedDict):
|
||||
name: str
|
||||
description: str
|
||||
|
||||
|
||||
class DanswerChatModelOut(BaseModel):
|
||||
model_raw: str
|
||||
action: str
|
||||
action_input: str
|
||||
|
||||
|
||||
def call_tool(
|
||||
model_actions: DanswerChatModelOut,
|
||||
) -> str:
|
||||
raise NotImplementedError("There are no additional tool integrations right now")
|
||||
|
||||
|
||||
def form_user_prompt_text(
|
||||
query: str,
|
||||
tool_text: str | None,
|
||||
hint_text: str | None,
|
||||
user_input_prompt: str = USER_INPUT,
|
||||
tool_less_prompt: str = TOOL_LESS_PROMPT,
|
||||
) -> str:
|
||||
user_prompt = tool_text or tool_less_prompt
|
||||
|
||||
user_prompt += user_input_prompt.format(user_input=query)
|
||||
|
||||
if hint_text:
|
||||
if user_prompt[-1] != "\n":
|
||||
user_prompt += "\n"
|
||||
user_prompt += "\nHint: " + hint_text
|
||||
|
||||
return user_prompt.strip()
|
||||
|
||||
|
||||
def form_tool_section_text(
|
||||
tools: list[ToolInfo] | None, retrieval_enabled: bool, template: str = TOOL_TEMPLATE
|
||||
) -> str | None:
|
||||
if not tools and not retrieval_enabled:
|
||||
return None
|
||||
|
||||
if retrieval_enabled and tools:
|
||||
tools.append(
|
||||
{"name": DANSWER_TOOL_NAME, "description": DANSWER_TOOL_DESCRIPTION}
|
||||
)
|
||||
|
||||
tools_intro = []
|
||||
if tools:
|
||||
num_tools = len(tools)
|
||||
for tool in tools:
|
||||
description_formatted = tool["description"].replace("\n", " ")
|
||||
tools_intro.append(f"> {tool['name']}: {description_formatted}")
|
||||
|
||||
prefix = "Must be one of " if num_tools > 1 else "Must be "
|
||||
|
||||
tools_intro_text = "\n".join(tools_intro)
|
||||
tool_names_text = prefix + ", ".join([tool["name"] for tool in tools])
|
||||
|
||||
else:
|
||||
return None
|
||||
|
||||
return template.format(
|
||||
tool_overviews=tools_intro_text, tool_names=tool_names_text
|
||||
).strip()
|
||||
|
||||
|
||||
def form_tool_followup_text(
|
||||
tool_output: str,
|
||||
query: str,
|
||||
hint_text: str | None,
|
||||
tool_followup_prompt: str = TOOL_FOLLOWUP,
|
||||
ignore_hint: bool = False,
|
||||
) -> str:
|
||||
# If multi-line query, it likely confuses the model more than helps
|
||||
if "\n" not in query:
|
||||
optional_reminder = f"\nAs a reminder, my query was: {query}\n"
|
||||
else:
|
||||
optional_reminder = ""
|
||||
|
||||
if not ignore_hint and hint_text:
|
||||
hint_text_spaced = f"\nHint: {hint_text}\n"
|
||||
else:
|
||||
hint_text_spaced = ""
|
||||
|
||||
return tool_followup_prompt.format(
|
||||
tool_output=tool_output,
|
||||
optional_reminder=optional_reminder,
|
||||
hint=hint_text_spaced,
|
||||
).strip()
|
||||
|
||||
|
||||
def form_tool_less_followup_text(
|
||||
tool_output: str,
|
||||
query: str,
|
||||
hint_text: str | None,
|
||||
tool_followup_prompt: str = TOOL_LESS_FOLLOWUP,
|
||||
) -> str:
|
||||
hint = f"Hint: {hint_text}" if hint_text else ""
|
||||
return tool_followup_prompt.format(
|
||||
context_str=tool_output, user_query=query, hint_text=hint
|
||||
).strip()
|
||||
@@ -1,226 +0,0 @@
|
||||
import math
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import TypeVar
|
||||
from urllib.parse import quote
|
||||
|
||||
from atlassian import Confluence # type:ignore
|
||||
from requests import HTTPError
|
||||
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
|
||||
RATE_LIMIT_MESSAGE_LOWERCASE = "Rate limit exceeded".lower()
|
||||
|
||||
|
||||
class ConfluenceRateLimitError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def _handle_http_error(e: HTTPError, attempt: int) -> int:
|
||||
MIN_DELAY = 2
|
||||
MAX_DELAY = 60
|
||||
STARTING_DELAY = 5
|
||||
BACKOFF = 2
|
||||
|
||||
# Check if the response or headers are None to avoid potential AttributeError
|
||||
if e.response is None or e.response.headers is None:
|
||||
logger.warning("HTTPError with `None` as response or as headers")
|
||||
raise e
|
||||
|
||||
if (
|
||||
e.response.status_code != 429
|
||||
and RATE_LIMIT_MESSAGE_LOWERCASE not in e.response.text.lower()
|
||||
):
|
||||
raise e
|
||||
|
||||
retry_after = None
|
||||
|
||||
retry_after_header = e.response.headers.get("Retry-After")
|
||||
if retry_after_header is not None:
|
||||
try:
|
||||
retry_after = int(retry_after_header)
|
||||
if retry_after > MAX_DELAY:
|
||||
logger.warning(
|
||||
f"Clamping retry_after from {retry_after} to {MAX_DELAY} seconds..."
|
||||
)
|
||||
retry_after = MAX_DELAY
|
||||
if retry_after < MIN_DELAY:
|
||||
retry_after = MIN_DELAY
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
if retry_after is not None:
|
||||
logger.warning(
|
||||
f"Rate limiting with retry header. Retrying after {retry_after} seconds..."
|
||||
)
|
||||
delay = retry_after
|
||||
else:
|
||||
logger.warning(
|
||||
"Rate limiting without retry header. Retrying with exponential backoff..."
|
||||
)
|
||||
delay = min(STARTING_DELAY * (BACKOFF**attempt), MAX_DELAY)
|
||||
|
||||
delay_until = math.ceil(time.monotonic() + delay)
|
||||
return delay_until
|
||||
|
||||
|
||||
# https://developer.atlassian.com/cloud/confluence/rate-limiting/
|
||||
# this uses the native rate limiting option provided by the
|
||||
# confluence client and otherwise applies a simpler set of error handling
|
||||
def handle_confluence_rate_limit(confluence_call: F) -> F:
|
||||
def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
|
||||
MAX_RETRIES = 5
|
||||
|
||||
TIMEOUT = 3600
|
||||
timeout_at = time.monotonic() + TIMEOUT
|
||||
|
||||
for attempt in range(MAX_RETRIES):
|
||||
if time.monotonic() > timeout_at:
|
||||
raise TimeoutError(
|
||||
f"Confluence call attempts took longer than {TIMEOUT} seconds."
|
||||
)
|
||||
|
||||
try:
|
||||
# we're relying more on the client to rate limit itself
|
||||
# and applying our own retries in a more specific set of circumstances
|
||||
return confluence_call(*args, **kwargs)
|
||||
except HTTPError as e:
|
||||
delay_until = _handle_http_error(e, attempt)
|
||||
while time.monotonic() < delay_until:
|
||||
# in the future, check a signal here to exit
|
||||
time.sleep(1)
|
||||
except AttributeError as e:
|
||||
# Some error within the Confluence library, unclear why it fails.
|
||||
# Users reported it to be intermittent, so just retry
|
||||
if attempt == MAX_RETRIES - 1:
|
||||
raise e
|
||||
|
||||
logger.exception(
|
||||
"Confluence Client raised an AttributeError. Retrying..."
|
||||
)
|
||||
time.sleep(5)
|
||||
|
||||
return cast(F, wrapped_call)
|
||||
|
||||
|
||||
_DEFAULT_PAGINATION_LIMIT = 100
|
||||
|
||||
|
||||
class OnyxConfluence(Confluence):
|
||||
"""
|
||||
This is a custom Confluence class that overrides the default Confluence class to add a custom CQL method.
|
||||
This is necessary because the default Confluence class does not properly support cql expansions.
|
||||
All methods are automatically wrapped with handle_confluence_rate_limit.
|
||||
"""
|
||||
|
||||
def __init__(self, url: str, *args: Any, **kwargs: Any) -> None:
|
||||
super(OnyxConfluence, self).__init__(url, *args, **kwargs)
|
||||
self._wrap_methods()
|
||||
|
||||
def _wrap_methods(self) -> None:
|
||||
"""
|
||||
For each attribute that is callable (i.e., a method) and doesn't start with an underscore,
|
||||
wrap it with handle_confluence_rate_limit.
|
||||
"""
|
||||
for attr_name in dir(self):
|
||||
if callable(getattr(self, attr_name)) and not attr_name.startswith("_"):
|
||||
setattr(
|
||||
self,
|
||||
attr_name,
|
||||
handle_confluence_rate_limit(getattr(self, attr_name)),
|
||||
)
|
||||
|
||||
def _paginate_url(
|
||||
self, url_suffix: str, limit: int | None = None
|
||||
) -> Iterator[list[dict[str, Any]]]:
|
||||
"""
|
||||
This will paginate through the top level query.
|
||||
"""
|
||||
if not limit:
|
||||
limit = _DEFAULT_PAGINATION_LIMIT
|
||||
|
||||
connection_char = "&" if "?" in url_suffix else "?"
|
||||
url_suffix += f"{connection_char}limit={limit}"
|
||||
|
||||
while url_suffix:
|
||||
try:
|
||||
next_response = self.get(url_suffix)
|
||||
except Exception as e:
|
||||
logger.exception("Error in danswer_cql: \n")
|
||||
raise e
|
||||
yield next_response.get("results", [])
|
||||
url_suffix = next_response.get("_links", {}).get("next")
|
||||
|
||||
def paginated_groups_retrieval(
|
||||
self,
|
||||
limit: int | None = None,
|
||||
) -> Iterator[list[dict[str, Any]]]:
|
||||
return self._paginate_url("rest/api/group", limit)
|
||||
|
||||
def paginated_group_members_retrieval(
|
||||
self,
|
||||
group_name: str,
|
||||
limit: int | None = None,
|
||||
) -> Iterator[list[dict[str, Any]]]:
|
||||
group_name = quote(group_name)
|
||||
return self._paginate_url(f"rest/api/group/{group_name}/member", limit)
|
||||
|
||||
def paginated_cql_user_retrieval(
|
||||
self,
|
||||
cql: str,
|
||||
expand: str | None = None,
|
||||
limit: int | None = None,
|
||||
) -> Iterator[list[dict[str, Any]]]:
|
||||
expand_string = f"&expand={expand}" if expand else ""
|
||||
return self._paginate_url(
|
||||
f"rest/api/search/user?cql={cql}{expand_string}", limit
|
||||
)
|
||||
|
||||
def paginated_cql_page_retrieval(
|
||||
self,
|
||||
cql: str,
|
||||
expand: str | None = None,
|
||||
limit: int | None = None,
|
||||
) -> Iterator[list[dict[str, Any]]]:
|
||||
expand_string = f"&expand={expand}" if expand else ""
|
||||
return self._paginate_url(
|
||||
f"rest/api/content/search?cql={cql}{expand_string}", limit
|
||||
)
|
||||
|
||||
def cql_paginate_all_expansions(
|
||||
self,
|
||||
cql: str,
|
||||
expand: str | None = None,
|
||||
limit: int | None = None,
|
||||
) -> Iterator[list[dict[str, Any]]]:
|
||||
"""
|
||||
This function will paginate through the top level query first, then
|
||||
paginate through all of the expansions.
|
||||
The limit only applies to the top level query.
|
||||
All expansion paginations use default pagination limit (defined by Atlassian).
|
||||
"""
|
||||
|
||||
def _traverse_and_update(data: dict | list) -> None:
|
||||
if isinstance(data, dict):
|
||||
next_url = data.get("_links", {}).get("next")
|
||||
if next_url and "results" in data:
|
||||
data["results"].extend(self._paginate_url(next_url))
|
||||
|
||||
for value in data.values():
|
||||
_traverse_and_update(value)
|
||||
elif isinstance(data, list):
|
||||
for item in data:
|
||||
_traverse_and_update(item)
|
||||
|
||||
for results in self.paginated_cql_page_retrieval(cql, expand, limit):
|
||||
_traverse_and_update(results)
|
||||
yield results
|
||||
@@ -1,321 +0,0 @@
|
||||
import os
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from jira import JIRA
|
||||
from jira.resources import Issue
|
||||
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.app_configs import JIRA_CONNECTOR_LABELS_TO_SKIP
|
||||
from danswer.configs.app_configs import JIRA_CONNECTOR_MAX_TICKET_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from danswer.connectors.models import BasicExpertInfo
|
||||
from danswer.connectors.models import ConnectorMissingCredentialError
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
PROJECT_URL_PAT = "projects"
|
||||
JIRA_API_VERSION = os.environ.get("JIRA_API_VERSION") or "2"
|
||||
|
||||
|
||||
def extract_jira_project(url: str) -> tuple[str, str]:
|
||||
parsed_url = urlparse(url)
|
||||
jira_base = parsed_url.scheme + "://" + parsed_url.netloc
|
||||
|
||||
# Split the path by '/' and find the position of 'projects' to get the project name
|
||||
split_path = parsed_url.path.split("/")
|
||||
if PROJECT_URL_PAT in split_path:
|
||||
project_pos = split_path.index(PROJECT_URL_PAT)
|
||||
if len(split_path) > project_pos + 1:
|
||||
jira_project = split_path[project_pos + 1]
|
||||
else:
|
||||
raise ValueError("No project name found in the URL")
|
||||
else:
|
||||
raise ValueError("'projects' not found in the URL")
|
||||
|
||||
return jira_base, jira_project
|
||||
|
||||
|
||||
def extract_text_from_adf(adf: dict | None) -> str:
|
||||
"""Extracts plain text from Atlassian Document Format:
|
||||
https://developer.atlassian.com/cloud/jira/platform/apis/document/structure/
|
||||
|
||||
WARNING: This function is incomplete and will e.g. skip lists!
|
||||
"""
|
||||
texts = []
|
||||
if adf is not None and "content" in adf:
|
||||
for block in adf["content"]:
|
||||
if "content" in block:
|
||||
for item in block["content"]:
|
||||
if item["type"] == "text":
|
||||
texts.append(item["text"])
|
||||
return " ".join(texts)
|
||||
|
||||
|
||||
def best_effort_get_field_from_issue(jira_issue: Issue, field: str) -> Any:
|
||||
if hasattr(jira_issue.fields, field):
|
||||
return getattr(jira_issue.fields, field)
|
||||
|
||||
try:
|
||||
return jira_issue.raw["fields"][field]
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _get_comment_strs(
|
||||
jira: Issue, comment_email_blacklist: tuple[str, ...] = ()
|
||||
) -> list[str]:
|
||||
comment_strs = []
|
||||
for comment in jira.fields.comment.comments:
|
||||
try:
|
||||
body_text = (
|
||||
comment.body
|
||||
if JIRA_API_VERSION == "2"
|
||||
else extract_text_from_adf(comment.raw["body"])
|
||||
)
|
||||
|
||||
if (
|
||||
hasattr(comment, "author")
|
||||
and hasattr(comment.author, "emailAddress")
|
||||
and comment.author.emailAddress in comment_email_blacklist
|
||||
):
|
||||
continue # Skip adding comment if author's email is in blacklist
|
||||
|
||||
comment_strs.append(body_text)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process comment due to an error: {e}")
|
||||
continue
|
||||
|
||||
return comment_strs
|
||||
|
||||
|
||||
def fetch_jira_issues_batch(
|
||||
jql: str,
|
||||
start_index: int,
|
||||
jira_client: JIRA,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
comment_email_blacklist: tuple[str, ...] = (),
|
||||
labels_to_skip: set[str] | None = None,
|
||||
) -> tuple[list[Document], int]:
|
||||
doc_batch = []
|
||||
|
||||
batch = jira_client.search_issues(
|
||||
jql,
|
||||
startAt=start_index,
|
||||
maxResults=batch_size,
|
||||
)
|
||||
|
||||
for jira in batch:
|
||||
if type(jira) != Issue:
|
||||
logger.warning(f"Found Jira object not of type Issue {jira}")
|
||||
continue
|
||||
|
||||
if labels_to_skip and any(
|
||||
label in jira.fields.labels for label in labels_to_skip
|
||||
):
|
||||
logger.info(
|
||||
f"Skipping {jira.key} because it has a label to skip. Found "
|
||||
f"labels: {jira.fields.labels}. Labels to skip: {labels_to_skip}."
|
||||
)
|
||||
continue
|
||||
|
||||
description = (
|
||||
jira.fields.description
|
||||
if JIRA_API_VERSION == "2"
|
||||
else extract_text_from_adf(jira.raw["fields"]["description"])
|
||||
)
|
||||
comments = _get_comment_strs(jira, comment_email_blacklist)
|
||||
ticket_content = f"{description}\n" + "\n".join(
|
||||
[f"Comment: {comment}" for comment in comments if comment]
|
||||
)
|
||||
|
||||
# Check ticket size
|
||||
if len(ticket_content.encode("utf-8")) > JIRA_CONNECTOR_MAX_TICKET_SIZE:
|
||||
logger.info(
|
||||
f"Skipping {jira.key} because it exceeds the maximum size of "
|
||||
f"{JIRA_CONNECTOR_MAX_TICKET_SIZE} bytes."
|
||||
)
|
||||
continue
|
||||
|
||||
page_url = f"{jira_client.client_info()}/browse/{jira.key}"
|
||||
|
||||
people = set()
|
||||
try:
|
||||
people.add(
|
||||
BasicExpertInfo(
|
||||
display_name=jira.fields.creator.displayName,
|
||||
email=jira.fields.creator.emailAddress,
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
# Author should exist but if not, doesn't matter
|
||||
pass
|
||||
|
||||
try:
|
||||
people.add(
|
||||
BasicExpertInfo(
|
||||
display_name=jira.fields.assignee.displayName, # type: ignore
|
||||
email=jira.fields.assignee.emailAddress, # type: ignore
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
# Author should exist but if not, doesn't matter
|
||||
pass
|
||||
|
||||
metadata_dict = {}
|
||||
priority = best_effort_get_field_from_issue(jira, "priority")
|
||||
if priority:
|
||||
metadata_dict["priority"] = priority.name
|
||||
status = best_effort_get_field_from_issue(jira, "status")
|
||||
if status:
|
||||
metadata_dict["status"] = status.name
|
||||
resolution = best_effort_get_field_from_issue(jira, "resolution")
|
||||
if resolution:
|
||||
metadata_dict["resolution"] = resolution.name
|
||||
labels = best_effort_get_field_from_issue(jira, "labels")
|
||||
if labels:
|
||||
metadata_dict["label"] = labels
|
||||
|
||||
doc_batch.append(
|
||||
Document(
|
||||
id=page_url,
|
||||
sections=[Section(link=page_url, text=ticket_content)],
|
||||
source=DocumentSource.JIRA,
|
||||
semantic_identifier=jira.fields.summary,
|
||||
doc_updated_at=time_str_to_utc(jira.fields.updated),
|
||||
primary_owners=list(people) or None,
|
||||
# TODO add secondary_owners (commenters) if needed
|
||||
metadata=metadata_dict,
|
||||
)
|
||||
)
|
||||
return doc_batch, len(batch)
|
||||
|
||||
|
||||
class JiraConnector(LoadConnector, PollConnector):
|
||||
def __init__(
|
||||
self,
|
||||
jira_project_url: str,
|
||||
comment_email_blacklist: list[str] | None = None,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
# if a ticket has one of the labels specified in this list, we will just
|
||||
# skip it. This is generally used to avoid indexing extra sensitive
|
||||
# tickets.
|
||||
labels_to_skip: list[str] = JIRA_CONNECTOR_LABELS_TO_SKIP,
|
||||
) -> None:
|
||||
self.batch_size = batch_size
|
||||
self.jira_base, self.jira_project = extract_jira_project(jira_project_url)
|
||||
self.jira_client: JIRA | None = None
|
||||
self._comment_email_blacklist = comment_email_blacklist or []
|
||||
|
||||
self.labels_to_skip = set(labels_to_skip)
|
||||
|
||||
@property
|
||||
def comment_email_blacklist(self) -> tuple:
|
||||
return tuple(email.strip() for email in self._comment_email_blacklist)
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
api_token = credentials["jira_api_token"]
|
||||
# if user provide an email we assume it's cloud
|
||||
if "jira_user_email" in credentials:
|
||||
email = credentials["jira_user_email"]
|
||||
self.jira_client = JIRA(
|
||||
basic_auth=(email, api_token),
|
||||
server=self.jira_base,
|
||||
options={"rest_api_version": JIRA_API_VERSION},
|
||||
)
|
||||
else:
|
||||
self.jira_client = JIRA(
|
||||
token_auth=api_token,
|
||||
server=self.jira_base,
|
||||
options={"rest_api_version": JIRA_API_VERSION},
|
||||
)
|
||||
return None
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
if self.jira_client is None:
|
||||
raise ConnectorMissingCredentialError("Jira")
|
||||
|
||||
# Quote the project name to handle reserved words
|
||||
quoted_project = f'"{self.jira_project}"'
|
||||
start_ind = 0
|
||||
while True:
|
||||
doc_batch, fetched_batch_size = fetch_jira_issues_batch(
|
||||
jql=f"project = {quoted_project}",
|
||||
start_index=start_ind,
|
||||
jira_client=self.jira_client,
|
||||
batch_size=self.batch_size,
|
||||
comment_email_blacklist=self.comment_email_blacklist,
|
||||
labels_to_skip=self.labels_to_skip,
|
||||
)
|
||||
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
|
||||
start_ind += fetched_batch_size
|
||||
if fetched_batch_size < self.batch_size:
|
||||
break
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
if self.jira_client is None:
|
||||
raise ConnectorMissingCredentialError("Jira")
|
||||
|
||||
start_date_str = datetime.fromtimestamp(start, tz=timezone.utc).strftime(
|
||||
"%Y-%m-%d %H:%M"
|
||||
)
|
||||
end_date_str = datetime.fromtimestamp(end, tz=timezone.utc).strftime(
|
||||
"%Y-%m-%d %H:%M"
|
||||
)
|
||||
|
||||
# Quote the project name to handle reserved words
|
||||
quoted_project = f'"{self.jira_project}"'
|
||||
jql = (
|
||||
f"project = {quoted_project} AND "
|
||||
f"updated >= '{start_date_str}' AND "
|
||||
f"updated <= '{end_date_str}'"
|
||||
)
|
||||
|
||||
start_ind = 0
|
||||
while True:
|
||||
doc_batch, fetched_batch_size = fetch_jira_issues_batch(
|
||||
jql=jql,
|
||||
start_index=start_ind,
|
||||
jira_client=self.jira_client,
|
||||
batch_size=self.batch_size,
|
||||
comment_email_blacklist=self.comment_email_blacklist,
|
||||
labels_to_skip=self.labels_to_skip,
|
||||
)
|
||||
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
|
||||
start_ind += fetched_batch_size
|
||||
if fetched_batch_size < self.batch_size:
|
||||
break
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
|
||||
connector = JiraConnector(
|
||||
os.environ["JIRA_PROJECT_URL"], comment_email_blacklist=[]
|
||||
)
|
||||
connector.load_credentials(
|
||||
{
|
||||
"jira_user_email": os.environ["JIRA_USER_EMAIL"],
|
||||
"jira_api_token": os.environ["JIRA_API_TOKEN"],
|
||||
}
|
||||
)
|
||||
document_batches = connector.load_from_state()
|
||||
print(next(document_batches))
|
||||
@@ -1,92 +0,0 @@
|
||||
"""Module with custom fields processing functions"""
|
||||
from typing import Any
|
||||
from typing import List
|
||||
|
||||
from jira import JIRA
|
||||
from jira.resources import CustomFieldOption
|
||||
from jira.resources import Issue
|
||||
from jira.resources import User
|
||||
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class CustomFieldExtractor:
|
||||
@staticmethod
|
||||
def _process_custom_field_value(value: Any) -> str:
|
||||
"""
|
||||
Process a custom field value to a string
|
||||
"""
|
||||
try:
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
elif isinstance(value, CustomFieldOption):
|
||||
return value.value
|
||||
elif isinstance(value, User):
|
||||
return value.displayName
|
||||
elif isinstance(value, List):
|
||||
return " ".join(
|
||||
[CustomFieldExtractor._process_custom_field_value(v) for v in value]
|
||||
)
|
||||
else:
|
||||
return str(value)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing custom field value {value}: {e}")
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def get_issue_custom_fields(
|
||||
jira: Issue, custom_fields: dict, max_value_length: int = 250
|
||||
) -> dict:
|
||||
"""
|
||||
Process all custom fields of an issue to a dictionary of strings
|
||||
:param jira: jira_issue, bug or similar
|
||||
:param custom_fields: custom fields dictionary
|
||||
:param max_value_length: maximum length of the value to be processed, if exceeded, it will be truncated
|
||||
"""
|
||||
|
||||
issue_custom_fields = {
|
||||
custom_fields[key]: value
|
||||
for key, value in jira.fields.__dict__.items()
|
||||
if value and key in custom_fields.keys()
|
||||
}
|
||||
|
||||
processed_fields = {}
|
||||
|
||||
if issue_custom_fields:
|
||||
for key, value in issue_custom_fields.items():
|
||||
processed = CustomFieldExtractor._process_custom_field_value(value)
|
||||
# We need max length parameter, because there are some plugins that often has very long description
|
||||
# and there is just a technical information so we just avoid long values
|
||||
if len(processed) < max_value_length:
|
||||
processed_fields[key] = processed
|
||||
|
||||
return processed_fields
|
||||
|
||||
@staticmethod
|
||||
def get_all_custom_fields(jira_client: JIRA) -> dict:
|
||||
"""Get all custom fields from Jira"""
|
||||
fields = jira_client.fields()
|
||||
fields_dct = {
|
||||
field["id"]: field["name"] for field in fields if field["custom"] is True
|
||||
}
|
||||
return fields_dct
|
||||
|
||||
|
||||
class CommonFieldExtractor:
|
||||
@staticmethod
|
||||
def get_issue_common_fields(jira: Issue) -> dict:
|
||||
return {
|
||||
"Priority": jira.fields.priority.name if jira.fields.priority else None,
|
||||
"Reporter": jira.fields.reporter.displayName
|
||||
if jira.fields.reporter
|
||||
else None,
|
||||
"Assignee": jira.fields.assignee.displayName
|
||||
if jira.fields.assignee
|
||||
else None,
|
||||
"Status": jira.fields.status.name if jira.fields.status else None,
|
||||
"Resolution": jira.fields.resolution.name
|
||||
if jira.fields.resolution
|
||||
else None,
|
||||
}
|
||||
@@ -1,333 +0,0 @@
|
||||
from collections.abc import Iterator
|
||||
from typing import Any
|
||||
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
|
||||
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
|
||||
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.google_drive.doc_conversion import (
|
||||
convert_drive_item_to_document,
|
||||
)
|
||||
from danswer.connectors.google_drive.file_retrieval import crawl_folders_for_files
|
||||
from danswer.connectors.google_drive.file_retrieval import get_files_in_my_drive
|
||||
from danswer.connectors.google_drive.file_retrieval import get_files_in_shared_drive
|
||||
from danswer.connectors.google_drive.models import GoogleDriveFileType
|
||||
from danswer.connectors.google_utils.google_auth import get_google_creds
|
||||
from danswer.connectors.google_utils.google_utils import execute_paginated_retrieval
|
||||
from danswer.connectors.google_utils.resources import get_admin_service
|
||||
from danswer.connectors.google_utils.resources import get_drive_service
|
||||
from danswer.connectors.google_utils.resources import get_google_docs_service
|
||||
from danswer.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
|
||||
)
|
||||
from danswer.connectors.google_utils.shared_constants import MISSING_SCOPES_ERROR_STR
|
||||
from danswer.connectors.google_utils.shared_constants import ONYX_SCOPE_INSTRUCTIONS
|
||||
from danswer.connectors.google_utils.shared_constants import SCOPE_DOC_URL
|
||||
from danswer.connectors.google_utils.shared_constants import SLIM_BATCH_SIZE
|
||||
from danswer.connectors.google_utils.shared_constants import USER_FIELDS
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from danswer.connectors.interfaces import SlimConnector
|
||||
from danswer.connectors.models import SlimDocument
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _extract_str_list_from_comma_str(string: str | None) -> list[str]:
|
||||
if not string:
|
||||
return []
|
||||
return [s.strip() for s in string.split(",") if s.strip()]
|
||||
|
||||
|
||||
def _extract_ids_from_urls(urls: list[str]) -> list[str]:
|
||||
return [url.split("/")[-1] for url in urls]
|
||||
|
||||
|
||||
class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
def __init__(
|
||||
self,
|
||||
include_shared_drives: bool = True,
|
||||
shared_drive_urls: str | None = None,
|
||||
include_my_drives: bool = True,
|
||||
my_drive_emails: str | None = None,
|
||||
shared_folder_urls: str | None = None,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
# OLD PARAMETERS
|
||||
folder_paths: list[str] | None = None,
|
||||
include_shared: bool | None = None,
|
||||
follow_shortcuts: bool | None = None,
|
||||
only_org_public: bool | None = None,
|
||||
continue_on_failure: bool | None = None,
|
||||
) -> None:
|
||||
# Check for old input parameters
|
||||
if (
|
||||
folder_paths is not None
|
||||
or include_shared is not None
|
||||
or follow_shortcuts is not None
|
||||
or only_org_public is not None
|
||||
or continue_on_failure is not None
|
||||
):
|
||||
logger.exception(
|
||||
"Google Drive connector received old input parameters. "
|
||||
"Please visit the docs for help with the new setup: "
|
||||
f"{SCOPE_DOC_URL}"
|
||||
)
|
||||
raise ValueError(
|
||||
"Google Drive connector received old input parameters. "
|
||||
"Please visit the docs for help with the new setup: "
|
||||
f"{SCOPE_DOC_URL}"
|
||||
)
|
||||
|
||||
if (
|
||||
not include_shared_drives
|
||||
and not include_my_drives
|
||||
and not shared_folder_urls
|
||||
):
|
||||
raise ValueError(
|
||||
"At least one of include_shared_drives, include_my_drives,"
|
||||
" or shared_folder_urls must be true"
|
||||
)
|
||||
|
||||
self.batch_size = batch_size
|
||||
|
||||
self.include_shared_drives = include_shared_drives
|
||||
shared_drive_url_list = _extract_str_list_from_comma_str(shared_drive_urls)
|
||||
self.shared_drive_ids = _extract_ids_from_urls(shared_drive_url_list)
|
||||
|
||||
self.include_my_drives = include_my_drives
|
||||
self.my_drive_emails = _extract_str_list_from_comma_str(my_drive_emails)
|
||||
|
||||
shared_folder_url_list = _extract_str_list_from_comma_str(shared_folder_urls)
|
||||
self.shared_folder_ids = _extract_ids_from_urls(shared_folder_url_list)
|
||||
|
||||
self._primary_admin_email: str | None = None
|
||||
|
||||
self._creds: OAuthCredentials | ServiceAccountCredentials | None = None
|
||||
|
||||
self._TRAVERSED_PARENT_IDS: set[str] = set()
|
||||
|
||||
@property
|
||||
def primary_admin_email(self) -> str:
|
||||
if self._primary_admin_email is None:
|
||||
raise RuntimeError(
|
||||
"Primary admin email missing, "
|
||||
"should not call this property "
|
||||
"before calling load_credentials"
|
||||
)
|
||||
return self._primary_admin_email
|
||||
|
||||
@property
|
||||
def google_domain(self) -> str:
|
||||
if self._primary_admin_email is None:
|
||||
raise RuntimeError(
|
||||
"Primary admin email missing, "
|
||||
"should not call this property "
|
||||
"before calling load_credentials"
|
||||
)
|
||||
return self._primary_admin_email.split("@")[-1]
|
||||
|
||||
@property
|
||||
def creds(self) -> OAuthCredentials | ServiceAccountCredentials:
|
||||
if self._creds is None:
|
||||
raise RuntimeError(
|
||||
"Creds missing, "
|
||||
"should not call this property "
|
||||
"before calling load_credentials"
|
||||
)
|
||||
return self._creds
|
||||
|
||||
def _update_traversed_parent_ids(self, folder_id: str) -> None:
|
||||
self._TRAVERSED_PARENT_IDS.add(folder_id)
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, str] | None:
|
||||
primary_admin_email = credentials[DB_CREDENTIALS_PRIMARY_ADMIN_KEY]
|
||||
self._primary_admin_email = primary_admin_email
|
||||
|
||||
self._creds, new_creds_dict = get_google_creds(
|
||||
credentials=credentials,
|
||||
source=DocumentSource.GOOGLE_DRIVE,
|
||||
)
|
||||
return new_creds_dict
|
||||
|
||||
def _get_all_user_emails(self) -> list[str]:
|
||||
admin_service = get_admin_service(
|
||||
creds=self.creds,
|
||||
user_email=self.primary_admin_email,
|
||||
)
|
||||
emails = []
|
||||
for user in execute_paginated_retrieval(
|
||||
retrieval_function=admin_service.users().list,
|
||||
list_key="users",
|
||||
fields=USER_FIELDS,
|
||||
domain=self.google_domain,
|
||||
):
|
||||
if email := user.get("primaryEmail"):
|
||||
emails.append(email)
|
||||
return emails
|
||||
|
||||
def _fetch_drive_items(
|
||||
self,
|
||||
is_slim: bool,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> Iterator[GoogleDriveFileType]:
|
||||
primary_drive_service = get_drive_service(
|
||||
creds=self.creds,
|
||||
user_email=self.primary_admin_email,
|
||||
)
|
||||
|
||||
if self.include_shared_drives:
|
||||
shared_drive_urls = self.shared_drive_ids
|
||||
if not shared_drive_urls:
|
||||
# if no parent ids are specified, get all shared drives using the admin account
|
||||
for drive in execute_paginated_retrieval(
|
||||
retrieval_function=primary_drive_service.drives().list,
|
||||
list_key="drives",
|
||||
useDomainAdminAccess=True,
|
||||
fields="drives(id)",
|
||||
):
|
||||
shared_drive_urls.append(drive["id"])
|
||||
|
||||
# For each shared drive, retrieve all files
|
||||
for shared_drive_id in shared_drive_urls:
|
||||
for file in get_files_in_shared_drive(
|
||||
service=primary_drive_service,
|
||||
drive_id=shared_drive_id,
|
||||
is_slim=is_slim,
|
||||
cache_folders=bool(self.shared_folder_ids),
|
||||
update_traversed_ids_func=self._update_traversed_parent_ids,
|
||||
start=start,
|
||||
end=end,
|
||||
):
|
||||
yield file
|
||||
|
||||
if self.shared_folder_ids:
|
||||
# Crawl all the shared parent ids for files
|
||||
for folder_id in self.shared_folder_ids:
|
||||
yield from crawl_folders_for_files(
|
||||
service=primary_drive_service,
|
||||
parent_id=folder_id,
|
||||
personal_drive=False,
|
||||
traversed_parent_ids=self._TRAVERSED_PARENT_IDS,
|
||||
update_traversed_ids_func=self._update_traversed_parent_ids,
|
||||
start=start,
|
||||
end=end,
|
||||
)
|
||||
|
||||
all_user_emails = []
|
||||
# get all personal docs from each users' personal drive
|
||||
if self.include_my_drives:
|
||||
if isinstance(self.creds, ServiceAccountCredentials):
|
||||
all_user_emails = self.my_drive_emails or []
|
||||
|
||||
# If using service account and no emails specified, fetch all users
|
||||
if not all_user_emails:
|
||||
all_user_emails = self._get_all_user_emails()
|
||||
|
||||
elif self.primary_admin_email:
|
||||
# If using OAuth, only fetch the primary admin email
|
||||
all_user_emails = [self.primary_admin_email]
|
||||
|
||||
for email in all_user_emails:
|
||||
logger.info(f"Fetching personal files for user: {email}")
|
||||
user_drive_service = get_drive_service(self.creds, user_email=email)
|
||||
|
||||
yield from get_files_in_my_drive(
|
||||
service=user_drive_service,
|
||||
email=email,
|
||||
is_slim=is_slim,
|
||||
start=start,
|
||||
end=end,
|
||||
)
|
||||
|
||||
def _extract_docs_from_google_drive(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> GenerateDocumentsOutput:
|
||||
doc_batch = []
|
||||
for file in self._fetch_drive_items(
|
||||
is_slim=False,
|
||||
start=start,
|
||||
end=end,
|
||||
):
|
||||
user_email = (
|
||||
file.get("owners", [{}])[0].get("emailAddress")
|
||||
or self.primary_admin_email
|
||||
)
|
||||
user_drive_service = get_drive_service(self.creds, user_email=user_email)
|
||||
docs_service = get_google_docs_service(self.creds, user_email=user_email)
|
||||
if doc := convert_drive_item_to_document(
|
||||
file=file,
|
||||
drive_service=user_drive_service,
|
||||
docs_service=docs_service,
|
||||
):
|
||||
doc_batch.append(doc)
|
||||
if len(doc_batch) >= self.batch_size:
|
||||
yield doc_batch
|
||||
doc_batch = []
|
||||
|
||||
yield doc_batch
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
try:
|
||||
yield from self._extract_docs_from_google_drive()
|
||||
except Exception as e:
|
||||
if MISSING_SCOPES_ERROR_STR in str(e):
|
||||
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
|
||||
raise e
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
try:
|
||||
yield from self._extract_docs_from_google_drive(start, end)
|
||||
except Exception as e:
|
||||
if MISSING_SCOPES_ERROR_STR in str(e):
|
||||
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
|
||||
raise e
|
||||
|
||||
def _extract_slim_docs_from_google_drive(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
slim_batch = []
|
||||
for file in self._fetch_drive_items(
|
||||
is_slim=True,
|
||||
start=start,
|
||||
end=end,
|
||||
):
|
||||
slim_batch.append(
|
||||
SlimDocument(
|
||||
id=file["webViewLink"],
|
||||
perm_sync_data={
|
||||
"doc_id": file.get("id"),
|
||||
"permissions": file.get("permissions", []),
|
||||
"permission_ids": file.get("permissionIds", []),
|
||||
"name": file.get("name"),
|
||||
"owner_email": file.get("owners", [{}])[0].get("emailAddress"),
|
||||
},
|
||||
)
|
||||
)
|
||||
if len(slim_batch) >= SLIM_BATCH_SIZE:
|
||||
yield slim_batch
|
||||
slim_batch = []
|
||||
yield slim_batch
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
try:
|
||||
yield from self._extract_slim_docs_from_google_drive(start, end)
|
||||
except Exception as e:
|
||||
if MISSING_SCOPES_ERROR_STR in str(e):
|
||||
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
|
||||
raise e
|
||||
@@ -1,107 +0,0 @@
|
||||
import json
|
||||
from typing import cast
|
||||
|
||||
from google.auth.transport.requests import Request # type: ignore
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
|
||||
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
|
||||
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
|
||||
)
|
||||
from danswer.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_DICT_TOKEN_KEY,
|
||||
)
|
||||
from danswer.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
|
||||
)
|
||||
from danswer.connectors.google_utils.shared_constants import (
|
||||
GOOGLE_SCOPES,
|
||||
)
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def get_google_oauth_creds(
|
||||
token_json_str: str, source: DocumentSource
|
||||
) -> OAuthCredentials | None:
|
||||
creds_json = json.loads(token_json_str)
|
||||
creds = OAuthCredentials.from_authorized_user_info(
|
||||
info=creds_json,
|
||||
scopes=GOOGLE_SCOPES[source],
|
||||
)
|
||||
if creds.valid:
|
||||
return creds
|
||||
|
||||
if creds.expired and creds.refresh_token:
|
||||
try:
|
||||
creds.refresh(Request())
|
||||
if creds.valid:
|
||||
logger.notice("Refreshed Google Drive tokens.")
|
||||
return creds
|
||||
except Exception:
|
||||
logger.exception("Failed to refresh google drive access token due to:")
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_google_creds(
|
||||
credentials: dict[str, str],
|
||||
source: DocumentSource,
|
||||
) -> tuple[ServiceAccountCredentials | OAuthCredentials, dict[str, str] | None]:
|
||||
"""Checks for two different types of credentials.
|
||||
(1) A credential which holds a token acquired via a user going thorough
|
||||
the Google OAuth flow.
|
||||
(2) A credential which holds a service account key JSON file, which
|
||||
can then be used to impersonate any user in the workspace.
|
||||
"""
|
||||
oauth_creds = None
|
||||
service_creds = None
|
||||
new_creds_dict = None
|
||||
if DB_CREDENTIALS_DICT_TOKEN_KEY in credentials:
|
||||
# OAUTH
|
||||
access_token_json_str = cast(str, credentials[DB_CREDENTIALS_DICT_TOKEN_KEY])
|
||||
oauth_creds = get_google_oauth_creds(
|
||||
token_json_str=access_token_json_str, source=source
|
||||
)
|
||||
|
||||
# tell caller to update token stored in DB if it has changed
|
||||
# (e.g. the token has been refreshed)
|
||||
new_creds_json_str = oauth_creds.to_json() if oauth_creds else ""
|
||||
if new_creds_json_str != access_token_json_str:
|
||||
new_creds_dict = {
|
||||
DB_CREDENTIALS_DICT_TOKEN_KEY: new_creds_json_str,
|
||||
DB_CREDENTIALS_PRIMARY_ADMIN_KEY: credentials[
|
||||
DB_CREDENTIALS_PRIMARY_ADMIN_KEY
|
||||
],
|
||||
}
|
||||
elif DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY in credentials:
|
||||
# SERVICE ACCOUNT
|
||||
service_account_key_json_str = credentials[
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY
|
||||
]
|
||||
service_account_key = json.loads(service_account_key_json_str)
|
||||
|
||||
service_creds = ServiceAccountCredentials.from_service_account_info(
|
||||
service_account_key, scopes=GOOGLE_SCOPES[source]
|
||||
)
|
||||
|
||||
if not service_creds.valid or not service_creds.expired:
|
||||
service_creds.refresh(Request())
|
||||
|
||||
if not service_creds.valid:
|
||||
raise PermissionError(
|
||||
f"Unable to access {source} - service account credentials are invalid."
|
||||
)
|
||||
|
||||
creds: ServiceAccountCredentials | OAuthCredentials | None = (
|
||||
oauth_creds or service_creds
|
||||
)
|
||||
if creds is None:
|
||||
raise PermissionError(
|
||||
f"Unable to access {source} - unknown credential structure."
|
||||
)
|
||||
|
||||
return creds, new_creds_dict
|
||||
@@ -1,140 +0,0 @@
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.connectors.slack.connector import filter_channels
|
||||
from danswer.connectors.slack.utils import get_message_link
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def get_event_time(event: dict[str, Any]) -> datetime | None:
|
||||
ts = event.get("ts")
|
||||
if not ts:
|
||||
return None
|
||||
return datetime.fromtimestamp(float(ts), tz=timezone.utc)
|
||||
|
||||
|
||||
class SlackLoadConnector(LoadConnector):
|
||||
# WARNING: DEPRECATED, DO NOT USE
|
||||
def __init__(
|
||||
self,
|
||||
workspace: str,
|
||||
export_path_str: str,
|
||||
channels: list[str] | None = None,
|
||||
# if specified, will treat the specified channel strings as
|
||||
# regexes, and will only index channels that fully match the regexes
|
||||
channel_regex_enabled: bool = False,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
) -> None:
|
||||
self.workspace = workspace
|
||||
self.channels = channels
|
||||
self.channel_regex_enabled = channel_regex_enabled
|
||||
self.export_path_str = export_path_str
|
||||
self.batch_size = batch_size
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
if credentials:
|
||||
logger.warning("Unexpected credentials provided for Slack Load Connector")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _process_batch_event(
|
||||
slack_event: dict[str, Any],
|
||||
channel: dict[str, Any],
|
||||
matching_doc: Document | None,
|
||||
workspace: str,
|
||||
) -> Document | None:
|
||||
if (
|
||||
slack_event["type"] == "message"
|
||||
and slack_event.get("subtype") != "channel_join"
|
||||
):
|
||||
if matching_doc:
|
||||
return Document(
|
||||
id=matching_doc.id,
|
||||
sections=matching_doc.sections
|
||||
+ [
|
||||
Section(
|
||||
link=get_message_link(
|
||||
event=slack_event,
|
||||
workspace=workspace,
|
||||
channel_id=channel["id"],
|
||||
),
|
||||
text=slack_event["text"],
|
||||
)
|
||||
],
|
||||
source=matching_doc.source,
|
||||
semantic_identifier=matching_doc.semantic_identifier,
|
||||
title="", # slack docs don't really have a "title"
|
||||
doc_updated_at=get_event_time(slack_event),
|
||||
metadata=matching_doc.metadata,
|
||||
)
|
||||
|
||||
return Document(
|
||||
id=slack_event["ts"],
|
||||
sections=[
|
||||
Section(
|
||||
link=get_message_link(
|
||||
event=slack_event,
|
||||
workspace=workspace,
|
||||
channel_id=channel["id"],
|
||||
),
|
||||
text=slack_event["text"],
|
||||
)
|
||||
],
|
||||
source=DocumentSource.SLACK,
|
||||
semantic_identifier=channel["name"],
|
||||
title="", # slack docs don't really have a "title"
|
||||
doc_updated_at=get_event_time(slack_event),
|
||||
metadata={},
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
export_path = Path(self.export_path_str)
|
||||
|
||||
with open(export_path / "channels.json") as f:
|
||||
all_channels = json.load(f)
|
||||
|
||||
filtered_channels = filter_channels(
|
||||
all_channels, self.channels, self.channel_regex_enabled
|
||||
)
|
||||
|
||||
document_batch: dict[str, Document] = {}
|
||||
for channel_info in filtered_channels:
|
||||
channel_dir_path = export_path / cast(str, channel_info["name"])
|
||||
channel_file_paths = [
|
||||
channel_dir_path / file_name
|
||||
for file_name in os.listdir(channel_dir_path)
|
||||
]
|
||||
for path in channel_file_paths:
|
||||
with open(path) as f:
|
||||
events = cast(list[dict[str, Any]], json.load(f))
|
||||
for slack_event in events:
|
||||
doc = self._process_batch_event(
|
||||
slack_event=slack_event,
|
||||
channel=channel_info,
|
||||
matching_doc=document_batch.get(
|
||||
slack_event.get("thread_ts", "")
|
||||
),
|
||||
workspace=self.workspace,
|
||||
)
|
||||
if doc:
|
||||
document_batch[doc.id] = doc
|
||||
if len(document_batch) >= self.batch_size:
|
||||
yield list(document_batch.values())
|
||||
|
||||
yield list(document_batch.values())
|
||||
@@ -1,50 +0,0 @@
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.db.models import SlackBotConfig
|
||||
from danswer.db.slack_bot_config import fetch_slack_bot_configs
|
||||
|
||||
|
||||
VALID_SLACK_FILTERS = [
|
||||
"answerable_prefilter",
|
||||
"well_answered_postfilter",
|
||||
"questionmark_prefilter",
|
||||
]
|
||||
|
||||
|
||||
def get_slack_bot_config_for_channel(
|
||||
channel_name: str | None, db_session: Session
|
||||
) -> SlackBotConfig | None:
|
||||
if not channel_name:
|
||||
return None
|
||||
|
||||
slack_bot_configs = fetch_slack_bot_configs(db_session=db_session)
|
||||
for config in slack_bot_configs:
|
||||
if channel_name in config.channel_config["channel_names"]:
|
||||
return config
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def validate_channel_names(
|
||||
channel_names: list[str],
|
||||
current_slack_bot_config_id: int | None,
|
||||
db_session: Session,
|
||||
) -> list[str]:
|
||||
"""Make sure that these channel_names don't exist in other slack bot configs.
|
||||
Returns a list of cleaned up channel names (e.g. '#' removed if present)"""
|
||||
slack_bot_configs = fetch_slack_bot_configs(db_session=db_session)
|
||||
cleaned_channel_names = [
|
||||
channel_name.lstrip("#").lower() for channel_name in channel_names
|
||||
]
|
||||
for slack_bot_config in slack_bot_configs:
|
||||
if slack_bot_config.id == current_slack_bot_config_id:
|
||||
continue
|
||||
|
||||
for channel_name in cleaned_channel_names:
|
||||
if channel_name in slack_bot_config.channel_config["channel_names"]:
|
||||
raise ValueError(
|
||||
f"Channel name '{channel_name}' already exists in "
|
||||
"another slack bot config"
|
||||
)
|
||||
|
||||
return cleaned_channel_names
|
||||
@@ -1,499 +0,0 @@
|
||||
import functools
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Optional
|
||||
from typing import TypeVar
|
||||
|
||||
from retry import retry
|
||||
from slack_sdk import WebClient
|
||||
from slack_sdk.models.blocks import DividerBlock
|
||||
from slack_sdk.models.blocks import SectionBlock
|
||||
|
||||
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_ANSWER_GENERATION_TIMEOUT
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_COT
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISPLAY_ERROR_MSGS
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_NUM_RETRIES
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_TARGET_CHUNK_PERCENTAGE
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_USE_QUOTES
|
||||
from danswer.configs.danswerbot_configs import DANSWER_FOLLOWUP_EMOJI
|
||||
from danswer.configs.danswerbot_configs import DANSWER_REACT_EMOJI
|
||||
from danswer.configs.danswerbot_configs import ENABLE_DANSWERBOT_REFLEXION
|
||||
from danswer.danswerbot.slack.blocks import build_documents_blocks
|
||||
from danswer.danswerbot.slack.blocks import build_follow_up_block
|
||||
from danswer.danswerbot.slack.blocks import build_qa_response_blocks
|
||||
from danswer.danswerbot.slack.blocks import build_sources_blocks
|
||||
from danswer.danswerbot.slack.blocks import get_restate_blocks
|
||||
from danswer.danswerbot.slack.formatting import format_slack_message
|
||||
from danswer.danswerbot.slack.handlers.utils import send_team_member_message
|
||||
from danswer.danswerbot.slack.models import SlackMessageInfo
|
||||
from danswer.danswerbot.slack.utils import respond_in_thread
|
||||
from danswer.danswerbot.slack.utils import SlackRateLimiter
|
||||
from danswer.danswerbot.slack.utils import update_emote_react
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.models import SlackBotConfig
|
||||
from danswer.db.models import SlackBotResponseType
|
||||
from danswer.db.persona import fetch_persona_by_id
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.db.users import get_user_by_email
|
||||
from danswer.llm.answering.prompts.citations_prompt import (
|
||||
compute_max_document_tokens_for_persona,
|
||||
)
|
||||
from danswer.llm.factory import get_llms_for_persona
|
||||
from danswer.llm.utils import check_number_of_tokens
|
||||
from danswer.llm.utils import get_max_input_tokens
|
||||
from danswer.one_shot_answer.answer_question import get_search_answer
|
||||
from danswer.one_shot_answer.models import DirectQARequest
|
||||
from danswer.one_shot_answer.models import OneShotQAResponse
|
||||
from danswer.search.enums import OptionalSearchSetting
|
||||
from danswer.search.models import BaseFilters
|
||||
from danswer.search.models import RerankingDetails
|
||||
from danswer.search.models import RetrievalDetails
|
||||
from danswer.utils.logger import DanswerLoggingAdapter
|
||||
|
||||
|
||||
srl = SlackRateLimiter()
|
||||
|
||||
RT = TypeVar("RT") # return type
|
||||
|
||||
|
||||
def rate_limits(
|
||||
client: WebClient, channel: str, thread_ts: Optional[str]
|
||||
) -> Callable[[Callable[..., RT]], Callable[..., RT]]:
|
||||
def decorator(func: Callable[..., RT]) -> Callable[..., RT]:
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args: Any, **kwargs: Any) -> RT:
|
||||
if not srl.is_available():
|
||||
func_randid, position = srl.init_waiter()
|
||||
srl.notify(client, channel, position, thread_ts)
|
||||
while not srl.is_available():
|
||||
srl.waiter(func_randid)
|
||||
srl.acquire_slot()
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def handle_regular_answer(
|
||||
message_info: SlackMessageInfo,
|
||||
slack_bot_config: SlackBotConfig | None,
|
||||
receiver_ids: list[str] | None,
|
||||
client: WebClient,
|
||||
channel: str,
|
||||
logger: DanswerLoggingAdapter,
|
||||
feedback_reminder_id: str | None,
|
||||
tenant_id: str | None,
|
||||
num_retries: int = DANSWER_BOT_NUM_RETRIES,
|
||||
answer_generation_timeout: int = DANSWER_BOT_ANSWER_GENERATION_TIMEOUT,
|
||||
thread_context_percent: float = DANSWER_BOT_TARGET_CHUNK_PERCENTAGE,
|
||||
should_respond_with_error_msgs: bool = DANSWER_BOT_DISPLAY_ERROR_MSGS,
|
||||
disable_docs_only_answer: bool = DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER,
|
||||
disable_cot: bool = DANSWER_BOT_DISABLE_COT,
|
||||
reflexion: bool = ENABLE_DANSWERBOT_REFLEXION,
|
||||
) -> bool:
|
||||
channel_conf = slack_bot_config.channel_config if slack_bot_config else None
|
||||
|
||||
messages = message_info.thread_messages
|
||||
message_ts_to_respond_to = message_info.msg_to_respond
|
||||
is_bot_msg = message_info.is_bot_msg
|
||||
user = None
|
||||
if message_info.is_bot_dm:
|
||||
if message_info.email:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
user = get_user_by_email(message_info.email, db_session)
|
||||
|
||||
document_set_names: list[str] | None = None
|
||||
persona = slack_bot_config.persona if slack_bot_config else None
|
||||
prompt = None
|
||||
if persona:
|
||||
document_set_names = [
|
||||
document_set.name for document_set in persona.document_sets
|
||||
]
|
||||
prompt = persona.prompts[0] if persona.prompts else None
|
||||
|
||||
should_respond_even_with_no_docs = persona.num_chunks == 0 if persona else False
|
||||
|
||||
bypass_acl = False
|
||||
if (
|
||||
slack_bot_config
|
||||
and slack_bot_config.persona
|
||||
and slack_bot_config.persona.document_sets
|
||||
):
|
||||
# For Slack channels, use the full document set, admin will be warned when configuring it
|
||||
# with non-public document sets
|
||||
bypass_acl = True
|
||||
|
||||
# figure out if we want to use citations or quotes
|
||||
use_citations = (
|
||||
not DANSWER_BOT_USE_QUOTES
|
||||
if slack_bot_config is None
|
||||
else slack_bot_config.response_type == SlackBotResponseType.CITATIONS
|
||||
)
|
||||
|
||||
if not message_ts_to_respond_to and not is_bot_msg:
|
||||
# if the message is not "/danswer" command, then it should have a message ts to respond to
|
||||
raise RuntimeError(
|
||||
"No message timestamp to respond to in `handle_message`. This should never happen."
|
||||
)
|
||||
|
||||
@retry(
|
||||
tries=num_retries,
|
||||
delay=0.25,
|
||||
backoff=2,
|
||||
)
|
||||
@rate_limits(client=client, channel=channel, thread_ts=message_ts_to_respond_to)
|
||||
def _get_answer(new_message_request: DirectQARequest) -> OneShotQAResponse | None:
|
||||
max_document_tokens: int | None = None
|
||||
max_history_tokens: int | None = None
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
if len(new_message_request.messages) > 1:
|
||||
if new_message_request.persona_config:
|
||||
raise RuntimeError("Slack bot does not support persona config")
|
||||
elif new_message_request.persona_id is not None:
|
||||
persona = cast(
|
||||
Persona,
|
||||
fetch_persona_by_id(
|
||||
db_session,
|
||||
new_message_request.persona_id,
|
||||
user=None,
|
||||
get_editable=False,
|
||||
),
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"No persona id provided, this should never happen."
|
||||
)
|
||||
|
||||
llm, _ = get_llms_for_persona(persona)
|
||||
|
||||
# In cases of threads, split the available tokens between docs and thread context
|
||||
input_tokens = get_max_input_tokens(
|
||||
model_name=llm.config.model_name,
|
||||
model_provider=llm.config.model_provider,
|
||||
)
|
||||
max_history_tokens = int(input_tokens * thread_context_percent)
|
||||
|
||||
remaining_tokens = input_tokens - max_history_tokens
|
||||
|
||||
query_text = new_message_request.messages[0].message
|
||||
if persona:
|
||||
max_document_tokens = compute_max_document_tokens_for_persona(
|
||||
persona=persona,
|
||||
actual_user_input=query_text,
|
||||
max_llm_token_override=remaining_tokens,
|
||||
)
|
||||
else:
|
||||
max_document_tokens = (
|
||||
remaining_tokens
|
||||
- 512 # Needs to be more than any of the QA prompts
|
||||
- check_number_of_tokens(query_text)
|
||||
)
|
||||
|
||||
if DISABLE_GENERATIVE_AI:
|
||||
return None
|
||||
|
||||
# This also handles creating the query event in postgres
|
||||
answer = get_search_answer(
|
||||
query_req=new_message_request,
|
||||
user=user,
|
||||
max_document_tokens=max_document_tokens,
|
||||
max_history_tokens=max_history_tokens,
|
||||
db_session=db_session,
|
||||
answer_generation_timeout=answer_generation_timeout,
|
||||
enable_reflexion=reflexion,
|
||||
bypass_acl=bypass_acl,
|
||||
use_citations=use_citations,
|
||||
danswerbot_flow=True,
|
||||
)
|
||||
|
||||
if not answer.error_msg:
|
||||
return answer
|
||||
else:
|
||||
raise RuntimeError(answer.error_msg)
|
||||
|
||||
try:
|
||||
# By leaving time_cutoff and favor_recent as None, and setting enable_auto_detect_filters
|
||||
# it allows the slack flow to extract out filters from the user query
|
||||
filters = BaseFilters(
|
||||
source_type=None,
|
||||
document_set=document_set_names,
|
||||
time_cutoff=None,
|
||||
)
|
||||
|
||||
# Default True because no other ways to apply filters in Slack (no nice UI)
|
||||
# Commenting this out because this is only available to the slackbot for now
|
||||
# later we plan to implement this at the persona level where this will get
|
||||
# commented back in
|
||||
# auto_detect_filters = (
|
||||
# persona.llm_filter_extraction if persona is not None else True
|
||||
# )
|
||||
auto_detect_filters = (
|
||||
slack_bot_config.enable_auto_filters
|
||||
if slack_bot_config is not None
|
||||
else False
|
||||
)
|
||||
retrieval_details = RetrievalDetails(
|
||||
run_search=OptionalSearchSetting.ALWAYS,
|
||||
real_time=False,
|
||||
filters=filters,
|
||||
enable_auto_detect_filters=auto_detect_filters,
|
||||
)
|
||||
|
||||
# Always apply reranking settings if it exists, this is the non-streaming flow
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
saved_search_settings = get_current_search_settings(db_session)
|
||||
|
||||
# This includes throwing out answer via reflexion
|
||||
answer = _get_answer(
|
||||
DirectQARequest(
|
||||
messages=messages,
|
||||
multilingual_query_expansion=saved_search_settings.multilingual_expansion
|
||||
if saved_search_settings
|
||||
else None,
|
||||
prompt_id=prompt.id if prompt else None,
|
||||
persona_id=persona.id if persona is not None else 0,
|
||||
retrieval_options=retrieval_details,
|
||||
chain_of_thought=not disable_cot,
|
||||
rerank_settings=RerankingDetails.from_db_model(saved_search_settings)
|
||||
if saved_search_settings
|
||||
else None,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Unable to process message - did not successfully answer "
|
||||
f"in {num_retries} attempts"
|
||||
)
|
||||
# Optionally, respond in thread with the error message, Used primarily
|
||||
# for debugging purposes
|
||||
if should_respond_with_error_msgs:
|
||||
respond_in_thread(
|
||||
client=client,
|
||||
channel=channel,
|
||||
receiver_ids=None,
|
||||
text=f"Encountered exception when trying to answer: \n\n```{e}```",
|
||||
thread_ts=message_ts_to_respond_to,
|
||||
)
|
||||
|
||||
# In case of failures, don't keep the reaction there permanently
|
||||
update_emote_react(
|
||||
emoji=DANSWER_REACT_EMOJI,
|
||||
channel=message_info.channel_to_respond,
|
||||
message_ts=message_info.msg_to_respond,
|
||||
remove=True,
|
||||
client=client,
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
# Edge case handling, for tracking down the Slack usage issue
|
||||
if answer is None:
|
||||
assert DISABLE_GENERATIVE_AI is True
|
||||
try:
|
||||
respond_in_thread(
|
||||
client=client,
|
||||
channel=channel,
|
||||
receiver_ids=receiver_ids,
|
||||
text="Hello! Danswer has some results for you!",
|
||||
blocks=[
|
||||
SectionBlock(
|
||||
text="Danswer is down for maintenance.\nWe're working hard on recharging the AI!"
|
||||
)
|
||||
],
|
||||
thread_ts=message_ts_to_respond_to,
|
||||
# don't unfurl, since otherwise we will have 5+ previews which makes the message very long
|
||||
unfurl=False,
|
||||
)
|
||||
|
||||
# For DM (ephemeral message), we need to create a thread via a normal message so the user can see
|
||||
# the ephemeral message. This also will give the user a notification which ephemeral message does not.
|
||||
if receiver_ids:
|
||||
respond_in_thread(
|
||||
client=client,
|
||||
channel=channel,
|
||||
text=(
|
||||
"👋 Hi, we've just gathered and forwarded the relevant "
|
||||
+ "information to the team. They'll get back to you shortly!"
|
||||
),
|
||||
thread_ts=message_ts_to_respond_to,
|
||||
)
|
||||
|
||||
return False
|
||||
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Unable to process message - could not respond in slack in {num_retries} attempts"
|
||||
)
|
||||
return True
|
||||
|
||||
# Got an answer at this point, can remove reaction and give results
|
||||
update_emote_react(
|
||||
emoji=DANSWER_REACT_EMOJI,
|
||||
channel=message_info.channel_to_respond,
|
||||
message_ts=message_info.msg_to_respond,
|
||||
remove=True,
|
||||
client=client,
|
||||
)
|
||||
|
||||
if answer.answer_valid is False:
|
||||
logger.notice(
|
||||
"Answer was evaluated to be invalid, throwing it away without responding."
|
||||
)
|
||||
update_emote_react(
|
||||
emoji=DANSWER_FOLLOWUP_EMOJI,
|
||||
channel=message_info.channel_to_respond,
|
||||
message_ts=message_info.msg_to_respond,
|
||||
remove=False,
|
||||
client=client,
|
||||
)
|
||||
|
||||
if answer.answer:
|
||||
logger.debug(answer.answer)
|
||||
return True
|
||||
|
||||
retrieval_info = answer.docs
|
||||
if not retrieval_info:
|
||||
# This should not happen, even with no docs retrieved, there is still info returned
|
||||
raise RuntimeError("Failed to retrieve docs, cannot answer question.")
|
||||
|
||||
top_docs = retrieval_info.top_documents
|
||||
if not top_docs and not should_respond_even_with_no_docs:
|
||||
logger.error(
|
||||
f"Unable to answer question: '{answer.rephrase}' - no documents found"
|
||||
)
|
||||
# Optionally, respond in thread with the error message
|
||||
# Used primarily for debugging purposes
|
||||
if should_respond_with_error_msgs:
|
||||
respond_in_thread(
|
||||
client=client,
|
||||
channel=channel,
|
||||
receiver_ids=None,
|
||||
text="Found no documents when trying to answer. Did you index any documents?",
|
||||
thread_ts=message_ts_to_respond_to,
|
||||
)
|
||||
return True
|
||||
|
||||
if not answer.answer and disable_docs_only_answer:
|
||||
logger.notice(
|
||||
"Unable to find answer - not responding since the "
|
||||
"`DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER` env variable is set"
|
||||
)
|
||||
return True
|
||||
|
||||
only_respond_with_citations_or_quotes = (
|
||||
channel_conf
|
||||
and "well_answered_postfilter" in channel_conf.get("answer_filters", [])
|
||||
)
|
||||
has_citations_or_quotes = bool(answer.citations or answer.quotes)
|
||||
if (
|
||||
only_respond_with_citations_or_quotes
|
||||
and not has_citations_or_quotes
|
||||
and not message_info.bypass_filters
|
||||
):
|
||||
logger.error(
|
||||
f"Unable to find citations or quotes to answer: '{answer.rephrase}' - not answering!"
|
||||
)
|
||||
# Optionally, respond in thread with the error message
|
||||
# Used primarily for debugging purposes
|
||||
if should_respond_with_error_msgs:
|
||||
respond_in_thread(
|
||||
client=client,
|
||||
channel=channel,
|
||||
receiver_ids=None,
|
||||
text="Found no citations or quotes when trying to answer.",
|
||||
thread_ts=message_ts_to_respond_to,
|
||||
)
|
||||
return True
|
||||
|
||||
# If called with the DanswerBot slash command, the question is lost so we have to reshow it
|
||||
restate_question_block = get_restate_blocks(messages[-1].message, is_bot_msg)
|
||||
formatted_answer = format_slack_message(answer.answer) if answer.answer else None
|
||||
|
||||
answer_blocks = build_qa_response_blocks(
|
||||
message_id=answer.chat_message_id,
|
||||
answer=formatted_answer,
|
||||
quotes=answer.quotes.quotes if answer.quotes else None,
|
||||
source_filters=retrieval_info.applied_source_filters,
|
||||
time_cutoff=retrieval_info.applied_time_cutoff,
|
||||
favor_recent=retrieval_info.recency_bias_multiplier > 1,
|
||||
# currently Personas don't support quotes
|
||||
# if citations are enabled, also don't use quotes
|
||||
skip_quotes=persona is not None or use_citations,
|
||||
process_message_for_citations=use_citations,
|
||||
feedback_reminder_id=feedback_reminder_id,
|
||||
)
|
||||
|
||||
# Get the chunks fed to the LLM only, then fill with other docs
|
||||
llm_doc_inds = answer.llm_selected_doc_indices or []
|
||||
llm_docs = [top_docs[i] for i in llm_doc_inds]
|
||||
remaining_docs = [
|
||||
doc for idx, doc in enumerate(top_docs) if idx not in llm_doc_inds
|
||||
]
|
||||
priority_ordered_docs = llm_docs + remaining_docs
|
||||
|
||||
document_blocks = []
|
||||
citations_block = []
|
||||
# if citations are enabled, only show cited documents
|
||||
if use_citations:
|
||||
citations = answer.citations or []
|
||||
cited_docs = []
|
||||
for citation in citations:
|
||||
matching_doc = next(
|
||||
(d for d in top_docs if d.document_id == citation.document_id),
|
||||
None,
|
||||
)
|
||||
if matching_doc:
|
||||
cited_docs.append((citation.citation_num, matching_doc))
|
||||
|
||||
cited_docs.sort()
|
||||
citations_block = build_sources_blocks(cited_documents=cited_docs)
|
||||
elif priority_ordered_docs:
|
||||
document_blocks = build_documents_blocks(
|
||||
documents=priority_ordered_docs,
|
||||
message_id=answer.chat_message_id,
|
||||
)
|
||||
document_blocks = [DividerBlock()] + document_blocks
|
||||
|
||||
all_blocks = (
|
||||
restate_question_block + answer_blocks + citations_block + document_blocks
|
||||
)
|
||||
|
||||
if channel_conf and channel_conf.get("follow_up_tags") is not None:
|
||||
all_blocks.append(build_follow_up_block(message_id=answer.chat_message_id))
|
||||
|
||||
try:
|
||||
respond_in_thread(
|
||||
client=client,
|
||||
channel=channel,
|
||||
receiver_ids=receiver_ids,
|
||||
text="Hello! Danswer has some results for you!",
|
||||
blocks=all_blocks,
|
||||
thread_ts=message_ts_to_respond_to,
|
||||
# don't unfurl, since otherwise we will have 5+ previews which makes the message very long
|
||||
unfurl=False,
|
||||
)
|
||||
|
||||
# For DM (ephemeral message), we need to create a thread via a normal message so the user can see
|
||||
# the ephemeral message. This also will give the user a notification which ephemeral message does not.
|
||||
# if there is no message_ts_to_respond_to, and we have made it this far, then this is a /danswer message
|
||||
# so we shouldn't send_team_member_message
|
||||
if receiver_ids and message_ts_to_respond_to is not None:
|
||||
send_team_member_message(
|
||||
client=client,
|
||||
channel=channel,
|
||||
thread_ts=message_ts_to_respond_to,
|
||||
)
|
||||
|
||||
return False
|
||||
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Unable to process message - could not respond in slack in {num_retries} attempts"
|
||||
)
|
||||
return True
|
||||
@@ -1,19 +0,0 @@
|
||||
from slack_sdk import WebClient
|
||||
|
||||
from danswer.danswerbot.slack.utils import respond_in_thread
|
||||
|
||||
|
||||
def send_team_member_message(
|
||||
client: WebClient,
|
||||
channel: str,
|
||||
thread_ts: str,
|
||||
) -> None:
|
||||
respond_in_thread(
|
||||
client=client,
|
||||
channel=channel,
|
||||
text=(
|
||||
"👋 Hi, we've just gathered and forwarded the relevant "
|
||||
+ "information to the team. They'll get back to you shortly!"
|
||||
),
|
||||
thread_ts=thread_ts,
|
||||
)
|
||||
@@ -1,58 +0,0 @@
|
||||
from danswer.configs.constants import DocumentSource
|
||||
|
||||
|
||||
def source_to_github_img_link(source: DocumentSource) -> str | None:
|
||||
# TODO: store these images somewhere better
|
||||
if source == DocumentSource.WEB.value:
|
||||
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/backend/slackbot_images/Web.png"
|
||||
if source == DocumentSource.FILE.value:
|
||||
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/backend/slackbot_images/File.png"
|
||||
if source == DocumentSource.GOOGLE_SITES.value:
|
||||
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/web/public/GoogleSites.png"
|
||||
if source == DocumentSource.SLACK.value:
|
||||
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/web/public/Slack.png"
|
||||
if source == DocumentSource.GMAIL.value:
|
||||
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/web/public/Gmail.png"
|
||||
if source == DocumentSource.GOOGLE_DRIVE.value:
|
||||
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/web/public/GoogleDrive.png"
|
||||
if source == DocumentSource.GITHUB.value:
|
||||
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/web/public/Github.png"
|
||||
if source == DocumentSource.GITLAB.value:
|
||||
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/web/public/Gitlab.png"
|
||||
if source == DocumentSource.CONFLUENCE.value:
|
||||
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/backend/slackbot_images/Confluence.png"
|
||||
if source == DocumentSource.JIRA.value:
|
||||
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/backend/slackbot_images/Jira.png"
|
||||
if source == DocumentSource.NOTION.value:
|
||||
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/web/public/Notion.png"
|
||||
if source == DocumentSource.ZENDESK.value:
|
||||
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/backend/slackbot_images/Zendesk.png"
|
||||
if source == DocumentSource.GONG.value:
|
||||
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/web/public/Gong.png"
|
||||
if source == DocumentSource.LINEAR.value:
|
||||
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/web/public/Linear.png"
|
||||
if source == DocumentSource.PRODUCTBOARD.value:
|
||||
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/web/public/Productboard.webp"
|
||||
if source == DocumentSource.SLAB.value:
|
||||
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/web/public/SlabLogo.png"
|
||||
if source == DocumentSource.ZULIP.value:
|
||||
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/web/public/Zulip.png"
|
||||
if source == DocumentSource.GURU.value:
|
||||
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/backend/slackbot_images/Guru.png"
|
||||
if source == DocumentSource.HUBSPOT.value:
|
||||
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/web/public/HubSpot.png"
|
||||
if source == DocumentSource.DOCUMENT360.value:
|
||||
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/web/public/Document360.png"
|
||||
if source == DocumentSource.BOOKSTACK.value:
|
||||
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/web/public/Bookstack.png"
|
||||
if source == DocumentSource.LOOPIO.value:
|
||||
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/web/public/Loopio.png"
|
||||
if source == DocumentSource.SHAREPOINT.value:
|
||||
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/web/public/Sharepoint.png"
|
||||
if source == DocumentSource.REQUESTTRACKER.value:
|
||||
# just use file icon for now
|
||||
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/backend/slackbot_images/File.png"
|
||||
if source == DocumentSource.INGESTION_API.value:
|
||||
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/backend/slackbot_images/File.png"
|
||||
|
||||
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/backend/slackbot_images/File.png"
|
||||
@@ -1,570 +0,0 @@
|
||||
import time
|
||||
from threading import Event
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from slack_sdk import WebClient
|
||||
from slack_sdk.socket_mode.request import SocketModeRequest
|
||||
from slack_sdk.socket_mode.response import SocketModeResponse
|
||||
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_REPHRASE_MESSAGE
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_RESPOND_EVERY_CHANNEL
|
||||
from danswer.configs.danswerbot_configs import NOTIFY_SLACKBOT_NO_ANSWER
|
||||
from danswer.connectors.slack.utils import expert_info_from_slack_id
|
||||
from danswer.danswerbot.slack.config import get_slack_bot_config_for_channel
|
||||
from danswer.danswerbot.slack.constants import DISLIKE_BLOCK_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import FOLLOWUP_BUTTON_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import FOLLOWUP_BUTTON_RESOLVED_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import GENERATE_ANSWER_BUTTON_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import IMMEDIATE_RESOLVED_BUTTON_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import LIKE_BLOCK_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import VIEW_DOC_FEEDBACK_ID
|
||||
from danswer.danswerbot.slack.handlers.handle_buttons import handle_doc_feedback_button
|
||||
from danswer.danswerbot.slack.handlers.handle_buttons import handle_followup_button
|
||||
from danswer.danswerbot.slack.handlers.handle_buttons import (
|
||||
handle_followup_resolved_button,
|
||||
)
|
||||
from danswer.danswerbot.slack.handlers.handle_buttons import (
|
||||
handle_generate_answer_button,
|
||||
)
|
||||
from danswer.danswerbot.slack.handlers.handle_buttons import handle_slack_feedback
|
||||
from danswer.danswerbot.slack.handlers.handle_message import handle_message
|
||||
from danswer.danswerbot.slack.handlers.handle_message import (
|
||||
remove_scheduled_feedback_reminder,
|
||||
)
|
||||
from danswer.danswerbot.slack.handlers.handle_message import schedule_feedback_reminder
|
||||
from danswer.danswerbot.slack.models import SlackMessageInfo
|
||||
from danswer.danswerbot.slack.tokens import fetch_tokens
|
||||
from danswer.danswerbot.slack.utils import check_message_limit
|
||||
from danswer.danswerbot.slack.utils import decompose_action_id
|
||||
from danswer.danswerbot.slack.utils import get_channel_name_from_id
|
||||
from danswer.danswerbot.slack.utils import get_danswer_bot_app_id
|
||||
from danswer.danswerbot.slack.utils import read_slack_thread
|
||||
from danswer.danswerbot.slack.utils import remove_danswer_bot_tag
|
||||
from danswer.danswerbot.slack.utils import rephrase_slack_message
|
||||
from danswer.danswerbot.slack.utils import respond_in_thread
|
||||
from danswer.danswerbot.slack.utils import TenantSocketModeClient
|
||||
from danswer.db.engine import get_all_tenant_ids
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.key_value_store.interface import KvKeyNotFoundError
|
||||
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
|
||||
from danswer.one_shot_answer.models import ThreadMessage
|
||||
from danswer.search.retrieval.search_runner import download_nltk_data
|
||||
from danswer.server.manage.models import SlackBotTokens
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
from shared_configs.configs import MODEL_SERVER_HOST
|
||||
from shared_configs.configs import MODEL_SERVER_PORT
|
||||
from shared_configs.configs import SLACK_CHANNEL_ID
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# In rare cases, some users have been experiencing a massive amount of trivial messages coming through
|
||||
# to the Slack Bot with trivial messages. Adding this to avoid exploding LLM costs while we track down
|
||||
# the cause.
|
||||
_SLACK_GREETINGS_TO_IGNORE = {
|
||||
"Welcome back!",
|
||||
"It's going to be a great day.",
|
||||
"Salutations!",
|
||||
"Greetings!",
|
||||
"Feeling great!",
|
||||
"Hi there",
|
||||
":wave:",
|
||||
}
|
||||
|
||||
# this is always (currently) the user id of Slack's official slackbot
|
||||
_OFFICIAL_SLACKBOT_USER_ID = "USLACKBOT"
|
||||
|
||||
|
||||
def prefilter_requests(req: SocketModeRequest, client: TenantSocketModeClient) -> bool:
|
||||
"""True to keep going, False to ignore this Slack request"""
|
||||
if req.type == "events_api":
|
||||
# Verify channel is valid
|
||||
event = cast(dict[str, Any], req.payload.get("event", {}))
|
||||
msg = cast(str | None, event.get("text"))
|
||||
channel = cast(str | None, event.get("channel"))
|
||||
channel_specific_logger = setup_logger(extra={SLACK_CHANNEL_ID: channel})
|
||||
|
||||
# This should never happen, but we can't continue without a channel since
|
||||
# we can't send a response without it
|
||||
if not channel:
|
||||
channel_specific_logger.warning("Found message without channel - skipping")
|
||||
return False
|
||||
|
||||
if not msg:
|
||||
channel_specific_logger.warning(
|
||||
"Cannot respond to empty message - skipping"
|
||||
)
|
||||
return False
|
||||
|
||||
if (
|
||||
req.payload.setdefault("event", {}).get("user", "")
|
||||
== _OFFICIAL_SLACKBOT_USER_ID
|
||||
):
|
||||
channel_specific_logger.info(
|
||||
"Ignoring messages from Slack's official Slackbot"
|
||||
)
|
||||
return False
|
||||
|
||||
if (
|
||||
msg in _SLACK_GREETINGS_TO_IGNORE
|
||||
or remove_danswer_bot_tag(msg, client=client.web_client)
|
||||
in _SLACK_GREETINGS_TO_IGNORE
|
||||
):
|
||||
channel_specific_logger.error(
|
||||
f"Ignoring weird Slack greeting message: '{msg}'"
|
||||
)
|
||||
channel_specific_logger.error(
|
||||
f"Weird Slack greeting message payload: '{req.payload}'"
|
||||
)
|
||||
return False
|
||||
|
||||
# Ensure that the message is a new message of expected type
|
||||
event_type = event.get("type")
|
||||
if event_type not in ["app_mention", "message"]:
|
||||
channel_specific_logger.info(
|
||||
f"Ignoring non-message event of type '{event_type}' for channel '{channel}'"
|
||||
)
|
||||
return False
|
||||
|
||||
bot_tag_id = get_danswer_bot_app_id(client.web_client)
|
||||
if event_type == "message":
|
||||
is_dm = event.get("channel_type") == "im"
|
||||
is_tagged = bot_tag_id and bot_tag_id in msg
|
||||
is_danswer_bot_msg = bot_tag_id and bot_tag_id in event.get("user", "")
|
||||
|
||||
# DanswerBot should never respond to itself
|
||||
if is_danswer_bot_msg:
|
||||
logger.info("Ignoring message from DanswerBot")
|
||||
return False
|
||||
|
||||
# DMs with the bot don't pick up the @DanswerBot so we have to keep the
|
||||
# caught events_api
|
||||
if is_tagged and not is_dm:
|
||||
# Let the tag flow handle this case, don't reply twice
|
||||
return False
|
||||
|
||||
if event.get("bot_profile"):
|
||||
channel_name, _ = get_channel_name_from_id(
|
||||
client=client.web_client, channel_id=channel
|
||||
)
|
||||
|
||||
with get_session_with_tenant(client.tenant_id) as db_session:
|
||||
slack_bot_config = get_slack_bot_config_for_channel(
|
||||
channel_name=channel_name, db_session=db_session
|
||||
)
|
||||
# If DanswerBot is not specifically tagged and the channel is not set to respond to bots, ignore the message
|
||||
if (not bot_tag_id or bot_tag_id not in msg) and (
|
||||
not slack_bot_config
|
||||
or not slack_bot_config.channel_config.get("respond_to_bots")
|
||||
):
|
||||
channel_specific_logger.info("Ignoring message from bot")
|
||||
return False
|
||||
|
||||
# Ignore things like channel_join, channel_leave, etc.
|
||||
# NOTE: "file_share" is just a message with a file attachment, so we
|
||||
# should not ignore it
|
||||
message_subtype = event.get("subtype")
|
||||
if message_subtype not in [None, "file_share"]:
|
||||
channel_specific_logger.info(
|
||||
f"Ignoring message with subtype '{message_subtype}' since is is a special message type"
|
||||
)
|
||||
return False
|
||||
|
||||
message_ts = event.get("ts")
|
||||
thread_ts = event.get("thread_ts")
|
||||
# Pick the root of the thread (if a thread exists)
|
||||
# Can respond in thread if it's an "im" directly to Danswer or @DanswerBot is tagged
|
||||
if (
|
||||
thread_ts
|
||||
and message_ts != thread_ts
|
||||
and event_type != "app_mention"
|
||||
and event.get("channel_type") != "im"
|
||||
):
|
||||
channel_specific_logger.debug(
|
||||
"Skipping message since it is not the root of a thread"
|
||||
)
|
||||
return False
|
||||
|
||||
msg = cast(str, event.get("text", ""))
|
||||
if not msg:
|
||||
channel_specific_logger.error("Unable to process empty message")
|
||||
return False
|
||||
|
||||
if req.type == "slash_commands":
|
||||
# Verify that there's an associated channel
|
||||
channel = req.payload.get("channel_id")
|
||||
channel_specific_logger = setup_logger(extra={SLACK_CHANNEL_ID: channel})
|
||||
|
||||
if not channel:
|
||||
channel_specific_logger.error(
|
||||
"Received DanswerBot command without channel - skipping"
|
||||
)
|
||||
return False
|
||||
|
||||
sender = req.payload.get("user_id")
|
||||
if not sender:
|
||||
channel_specific_logger.error(
|
||||
"Cannot respond to DanswerBot command without sender to respond to."
|
||||
)
|
||||
return False
|
||||
|
||||
if not check_message_limit():
|
||||
return False
|
||||
|
||||
logger.debug(f"Handling Slack request with Payload: '{req.payload}'")
|
||||
return True
|
||||
|
||||
|
||||
def process_feedback(req: SocketModeRequest, client: TenantSocketModeClient) -> None:
|
||||
if actions := req.payload.get("actions"):
|
||||
action = cast(dict[str, Any], actions[0])
|
||||
feedback_type = cast(str, action.get("action_id"))
|
||||
feedback_msg_reminder = cast(str, action.get("value"))
|
||||
feedback_id = cast(str, action.get("block_id"))
|
||||
channel_id = cast(str, req.payload["container"]["channel_id"])
|
||||
thread_ts = cast(str, req.payload["container"]["thread_ts"])
|
||||
else:
|
||||
logger.error("Unable to process feedback. Action not found")
|
||||
return
|
||||
|
||||
user_id = cast(str, req.payload["user"]["id"])
|
||||
|
||||
handle_slack_feedback(
|
||||
feedback_id=feedback_id,
|
||||
feedback_type=feedback_type,
|
||||
feedback_msg_reminder=feedback_msg_reminder,
|
||||
client=client.web_client,
|
||||
user_id_to_post_confirmation=user_id,
|
||||
channel_id_to_post_confirmation=channel_id,
|
||||
thread_ts_to_post_confirmation=thread_ts,
|
||||
tenant_id=client.tenant_id,
|
||||
)
|
||||
|
||||
query_event_id, _, _ = decompose_action_id(feedback_id)
|
||||
logger.notice(f"Successfully handled QA feedback for event: {query_event_id}")
|
||||
|
||||
|
||||
def build_request_details(
|
||||
req: SocketModeRequest, client: TenantSocketModeClient
|
||||
) -> SlackMessageInfo:
|
||||
if req.type == "events_api":
|
||||
event = cast(dict[str, Any], req.payload["event"])
|
||||
msg = cast(str, event["text"])
|
||||
channel = cast(str, event["channel"])
|
||||
tagged = event.get("type") == "app_mention"
|
||||
message_ts = event.get("ts")
|
||||
thread_ts = event.get("thread_ts")
|
||||
sender = event.get("user") or None
|
||||
expert_info = expert_info_from_slack_id(
|
||||
sender, client.web_client, user_cache={}
|
||||
)
|
||||
email = expert_info.email if expert_info else None
|
||||
|
||||
msg = remove_danswer_bot_tag(msg, client=client.web_client)
|
||||
|
||||
if DANSWER_BOT_REPHRASE_MESSAGE:
|
||||
logger.notice(f"Rephrasing Slack message. Original message: {msg}")
|
||||
try:
|
||||
msg = rephrase_slack_message(msg)
|
||||
logger.notice(f"Rephrased message: {msg}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error while trying to rephrase the Slack message: {e}")
|
||||
else:
|
||||
logger.notice(f"Received Slack message: {msg}")
|
||||
|
||||
if tagged:
|
||||
logger.debug("User tagged DanswerBot")
|
||||
|
||||
if thread_ts != message_ts and thread_ts is not None:
|
||||
thread_messages = read_slack_thread(
|
||||
channel=channel, thread=thread_ts, client=client.web_client
|
||||
)
|
||||
else:
|
||||
thread_messages = [
|
||||
ThreadMessage(message=msg, sender=None, role=MessageType.USER)
|
||||
]
|
||||
|
||||
return SlackMessageInfo(
|
||||
thread_messages=thread_messages,
|
||||
channel_to_respond=channel,
|
||||
msg_to_respond=cast(str, message_ts or thread_ts),
|
||||
thread_to_respond=cast(str, thread_ts or message_ts),
|
||||
sender=sender,
|
||||
email=email,
|
||||
bypass_filters=tagged,
|
||||
is_bot_msg=False,
|
||||
is_bot_dm=event.get("channel_type") == "im",
|
||||
)
|
||||
|
||||
elif req.type == "slash_commands":
|
||||
channel = req.payload["channel_id"]
|
||||
msg = req.payload["text"]
|
||||
sender = req.payload["user_id"]
|
||||
expert_info = expert_info_from_slack_id(
|
||||
sender, client.web_client, user_cache={}
|
||||
)
|
||||
email = expert_info.email if expert_info else None
|
||||
|
||||
single_msg = ThreadMessage(message=msg, sender=None, role=MessageType.USER)
|
||||
|
||||
return SlackMessageInfo(
|
||||
thread_messages=[single_msg],
|
||||
channel_to_respond=channel,
|
||||
msg_to_respond=None,
|
||||
thread_to_respond=None,
|
||||
sender=sender,
|
||||
email=email,
|
||||
bypass_filters=True,
|
||||
is_bot_msg=True,
|
||||
is_bot_dm=False,
|
||||
)
|
||||
|
||||
raise RuntimeError("Programming fault, this should never happen.")
|
||||
|
||||
|
||||
def apologize_for_fail(
|
||||
details: SlackMessageInfo,
|
||||
client: TenantSocketModeClient,
|
||||
) -> None:
|
||||
respond_in_thread(
|
||||
client=client.web_client,
|
||||
channel=details.channel_to_respond,
|
||||
thread_ts=details.msg_to_respond,
|
||||
text="Sorry, we weren't able to find anything relevant :cold_sweat:",
|
||||
)
|
||||
|
||||
|
||||
def process_message(
|
||||
req: SocketModeRequest,
|
||||
client: TenantSocketModeClient,
|
||||
respond_every_channel: bool = DANSWER_BOT_RESPOND_EVERY_CHANNEL,
|
||||
notify_no_answer: bool = NOTIFY_SLACKBOT_NO_ANSWER,
|
||||
) -> None:
|
||||
logger.debug(
|
||||
f"Received Slack request of type: '{req.type}' for tenant, {client.tenant_id}"
|
||||
)
|
||||
|
||||
# Throw out requests that can't or shouldn't be handled
|
||||
if not prefilter_requests(req, client):
|
||||
return
|
||||
|
||||
details = build_request_details(req, client)
|
||||
channel = details.channel_to_respond
|
||||
channel_name, is_dm = get_channel_name_from_id(
|
||||
client=client.web_client, channel_id=channel
|
||||
)
|
||||
|
||||
# Set the current tenant ID at the beginning for all DB calls within this thread
|
||||
if client.tenant_id:
|
||||
logger.info(f"Setting tenant ID to {client.tenant_id}")
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(client.tenant_id)
|
||||
try:
|
||||
with get_session_with_tenant(client.tenant_id) as db_session:
|
||||
slack_bot_config = get_slack_bot_config_for_channel(
|
||||
channel_name=channel_name, db_session=db_session
|
||||
)
|
||||
|
||||
# Be careful about this default, don't want to accidentally spam every channel
|
||||
# Users should be able to DM slack bot in their private channels though
|
||||
if (
|
||||
slack_bot_config is None
|
||||
and not respond_every_channel
|
||||
# Can't have configs for DMs so don't toss them out
|
||||
and not is_dm
|
||||
# If /DanswerBot (is_bot_msg) or @DanswerBot (bypass_filters)
|
||||
# always respond with the default configs
|
||||
and not (details.is_bot_msg or details.bypass_filters)
|
||||
):
|
||||
return
|
||||
|
||||
follow_up = bool(
|
||||
slack_bot_config
|
||||
and slack_bot_config.channel_config
|
||||
and slack_bot_config.channel_config.get("follow_up_tags") is not None
|
||||
)
|
||||
feedback_reminder_id = schedule_feedback_reminder(
|
||||
details=details, client=client.web_client, include_followup=follow_up
|
||||
)
|
||||
|
||||
failed = handle_message(
|
||||
message_info=details,
|
||||
slack_bot_config=slack_bot_config,
|
||||
client=client.web_client,
|
||||
feedback_reminder_id=feedback_reminder_id,
|
||||
tenant_id=client.tenant_id,
|
||||
)
|
||||
|
||||
if failed:
|
||||
if feedback_reminder_id:
|
||||
remove_scheduled_feedback_reminder(
|
||||
client=client.web_client,
|
||||
channel=details.sender,
|
||||
msg_id=feedback_reminder_id,
|
||||
)
|
||||
# Skipping answering due to pre-filtering is not considered a failure
|
||||
if notify_no_answer:
|
||||
apologize_for_fail(details, client)
|
||||
finally:
|
||||
if client.tenant_id:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
def acknowledge_message(req: SocketModeRequest, client: TenantSocketModeClient) -> None:
|
||||
response = SocketModeResponse(envelope_id=req.envelope_id)
|
||||
client.send_socket_mode_response(response)
|
||||
|
||||
|
||||
def action_routing(req: SocketModeRequest, client: TenantSocketModeClient) -> None:
|
||||
if actions := req.payload.get("actions"):
|
||||
action = cast(dict[str, Any], actions[0])
|
||||
|
||||
if action["action_id"] in [DISLIKE_BLOCK_ACTION_ID, LIKE_BLOCK_ACTION_ID]:
|
||||
# AI Answer feedback
|
||||
return process_feedback(req, client)
|
||||
elif action["action_id"] == FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID:
|
||||
# Activation of the "source feedback" button
|
||||
return handle_doc_feedback_button(req, client)
|
||||
elif action["action_id"] == FOLLOWUP_BUTTON_ACTION_ID:
|
||||
return handle_followup_button(req, client)
|
||||
elif action["action_id"] == IMMEDIATE_RESOLVED_BUTTON_ACTION_ID:
|
||||
return handle_followup_resolved_button(req, client, immediate=True)
|
||||
elif action["action_id"] == FOLLOWUP_BUTTON_RESOLVED_ACTION_ID:
|
||||
return handle_followup_resolved_button(req, client, immediate=False)
|
||||
elif action["action_id"] == GENERATE_ANSWER_BUTTON_ACTION_ID:
|
||||
return handle_generate_answer_button(req, client)
|
||||
|
||||
|
||||
def view_routing(req: SocketModeRequest, client: TenantSocketModeClient) -> None:
|
||||
if view := req.payload.get("view"):
|
||||
if view["callback_id"] == VIEW_DOC_FEEDBACK_ID:
|
||||
return process_feedback(req, client)
|
||||
|
||||
|
||||
def process_slack_event(client: TenantSocketModeClient, req: SocketModeRequest) -> None:
|
||||
# Always respond right away, if Slack doesn't receive these frequently enough
|
||||
# it will assume the Bot is DEAD!!! :(
|
||||
acknowledge_message(req, client)
|
||||
|
||||
try:
|
||||
if req.type == "interactive":
|
||||
if req.payload.get("type") == "block_actions":
|
||||
return action_routing(req, client)
|
||||
elif req.payload.get("type") == "view_submission":
|
||||
return view_routing(req, client)
|
||||
elif req.type == "events_api" or req.type == "slash_commands":
|
||||
return process_message(req, client)
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to process slack event. Error: {e}")
|
||||
logger.error(f"Slack request payload: {req.payload}")
|
||||
|
||||
|
||||
def _get_socket_client(
|
||||
slack_bot_tokens: SlackBotTokens, tenant_id: str | None
|
||||
) -> TenantSocketModeClient:
|
||||
# For more info on how to set this up, checkout the docs:
|
||||
# https://docs.danswer.dev/slack_bot_setup
|
||||
return TenantSocketModeClient(
|
||||
# This app-level token will be used only for establishing a connection
|
||||
app_token=slack_bot_tokens.app_token,
|
||||
web_client=WebClient(token=slack_bot_tokens.bot_token),
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
|
||||
def _initialize_socket_client(socket_client: TenantSocketModeClient) -> None:
|
||||
socket_client.socket_mode_request_listeners.append(process_slack_event) # type: ignore
|
||||
|
||||
# Establish a WebSocket connection to the Socket Mode servers
|
||||
logger.notice(f"Listening for messages from Slack {socket_client.tenant_id }...")
|
||||
socket_client.connect()
|
||||
|
||||
|
||||
# Follow the guide (https://docs.danswer.dev/slack_bot_setup) to set up
|
||||
# the slack bot in your workspace, and then add the bot to any channels you want to
|
||||
# try and answer questions for. Running this file will setup Danswer to listen to all
|
||||
# messages in those channels and attempt to answer them. As of now, it will only respond
|
||||
# to messages sent directly in the channel - it will not respond to messages sent within a
|
||||
# thread.
|
||||
#
|
||||
# NOTE: we are using Web Sockets so that you can run this from within a firewalled VPC
|
||||
# without issue.
|
||||
if __name__ == "__main__":
|
||||
slack_bot_tokens: dict[str | None, SlackBotTokens] = {}
|
||||
socket_clients: dict[str | None, TenantSocketModeClient] = {}
|
||||
|
||||
set_is_ee_based_on_env_variable()
|
||||
|
||||
logger.notice("Verifying query preprocessing (NLTK) data is downloaded")
|
||||
download_nltk_data()
|
||||
|
||||
while True:
|
||||
try:
|
||||
tenant_ids = get_all_tenant_ids() # Function to retrieve all tenant IDs
|
||||
|
||||
for tenant_id in tenant_ids:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
try:
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id or "public")
|
||||
latest_slack_bot_tokens = fetch_tokens()
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
if (
|
||||
tenant_id not in slack_bot_tokens
|
||||
or latest_slack_bot_tokens != slack_bot_tokens[tenant_id]
|
||||
):
|
||||
if tenant_id in slack_bot_tokens:
|
||||
logger.notice(
|
||||
f"Slack Bot tokens have changed for tenant {tenant_id} - reconnecting"
|
||||
)
|
||||
else:
|
||||
# Initial setup for this tenant
|
||||
search_settings = get_current_search_settings(
|
||||
db_session
|
||||
)
|
||||
embedding_model = EmbeddingModel.from_db_model(
|
||||
search_settings=search_settings,
|
||||
server_host=MODEL_SERVER_HOST,
|
||||
server_port=MODEL_SERVER_PORT,
|
||||
)
|
||||
warm_up_bi_encoder(embedding_model=embedding_model)
|
||||
|
||||
slack_bot_tokens[tenant_id] = latest_slack_bot_tokens
|
||||
|
||||
# potentially may cause a message to be dropped, but it is complicated
|
||||
# to avoid + (1) if the user is changing tokens, they are likely okay with some
|
||||
# "migration downtime" and (2) if a single message is lost it is okay
|
||||
# as this should be a very rare occurrence
|
||||
if tenant_id in socket_clients:
|
||||
socket_clients[tenant_id].close()
|
||||
|
||||
socket_client = _get_socket_client(
|
||||
latest_slack_bot_tokens, tenant_id
|
||||
)
|
||||
|
||||
# Initialize socket client for this tenant. Each tenant has its own
|
||||
# socket client, allowing for multiple concurrent connections (one
|
||||
# per tenant) with the tenant ID wrapped in the socket model client.
|
||||
# Each `connect` stores websocket connection in a separate thread.
|
||||
_initialize_socket_client(socket_client)
|
||||
|
||||
socket_clients[tenant_id] = socket_client
|
||||
|
||||
except KvKeyNotFoundError:
|
||||
logger.debug(f"Missing Slack Bot tokens for tenant {tenant_id}")
|
||||
if tenant_id in socket_clients:
|
||||
socket_clients[tenant_id].disconnect()
|
||||
del socket_clients[tenant_id]
|
||||
del slack_bot_tokens[tenant_id]
|
||||
|
||||
# Wait before checking for updates
|
||||
Event().wait(timeout=60)
|
||||
|
||||
except Exception:
|
||||
logger.exception("An error occurred outside of main event loop")
|
||||
time.sleep(60)
|
||||
@@ -1,28 +0,0 @@
|
||||
import os
|
||||
from typing import cast
|
||||
|
||||
from danswer.configs.constants import KV_SLACK_BOT_TOKENS_CONFIG_KEY
|
||||
from danswer.key_value_store.factory import get_kv_store
|
||||
from danswer.server.manage.models import SlackBotTokens
|
||||
|
||||
|
||||
def fetch_tokens() -> SlackBotTokens:
|
||||
# first check env variables
|
||||
app_token = os.environ.get("DANSWER_BOT_SLACK_APP_TOKEN")
|
||||
bot_token = os.environ.get("DANSWER_BOT_SLACK_BOT_TOKEN")
|
||||
if app_token and bot_token:
|
||||
return SlackBotTokens(app_token=app_token, bot_token=bot_token)
|
||||
|
||||
dynamic_config_store = get_kv_store()
|
||||
return SlackBotTokens(
|
||||
**cast(dict, dynamic_config_store.load(key=KV_SLACK_BOT_TOKENS_CONFIG_KEY))
|
||||
)
|
||||
|
||||
|
||||
def save_tokens(
|
||||
tokens: SlackBotTokens,
|
||||
) -> None:
|
||||
dynamic_config_store = get_kv_store()
|
||||
dynamic_config_store.store(
|
||||
key=KV_SLACK_BOT_TOKENS_CONFIG_KEY, val=dict(tokens), encrypt=True
|
||||
)
|
||||
@@ -1,202 +0,0 @@
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.db.models import InputPrompt
|
||||
from danswer.db.models import User
|
||||
from danswer.server.features.input_prompt.models import InputPromptSnapshot
|
||||
from danswer.server.manage.models import UserInfo
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def insert_input_prompt_if_not_exists(
|
||||
user: User | None,
|
||||
input_prompt_id: int | None,
|
||||
prompt: str,
|
||||
content: str,
|
||||
active: bool,
|
||||
is_public: bool,
|
||||
db_session: Session,
|
||||
commit: bool = True,
|
||||
) -> InputPrompt:
|
||||
if input_prompt_id is not None:
|
||||
input_prompt = (
|
||||
db_session.query(InputPrompt).filter_by(id=input_prompt_id).first()
|
||||
)
|
||||
else:
|
||||
query = db_session.query(InputPrompt).filter(InputPrompt.prompt == prompt)
|
||||
if user:
|
||||
query = query.filter(InputPrompt.user_id == user.id)
|
||||
else:
|
||||
query = query.filter(InputPrompt.user_id.is_(None))
|
||||
input_prompt = query.first()
|
||||
|
||||
if input_prompt is None:
|
||||
input_prompt = InputPrompt(
|
||||
id=input_prompt_id,
|
||||
prompt=prompt,
|
||||
content=content,
|
||||
active=active,
|
||||
is_public=is_public or user is None,
|
||||
user_id=user.id if user else None,
|
||||
)
|
||||
db_session.add(input_prompt)
|
||||
|
||||
if commit:
|
||||
db_session.commit()
|
||||
|
||||
return input_prompt
|
||||
|
||||
|
||||
def insert_input_prompt(
|
||||
prompt: str,
|
||||
content: str,
|
||||
is_public: bool,
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
) -> InputPrompt:
|
||||
input_prompt = InputPrompt(
|
||||
prompt=prompt,
|
||||
content=content,
|
||||
active=True,
|
||||
is_public=is_public or user is None,
|
||||
user_id=user.id if user is not None else None,
|
||||
)
|
||||
db_session.add(input_prompt)
|
||||
db_session.commit()
|
||||
|
||||
return input_prompt
|
||||
|
||||
|
||||
def update_input_prompt(
|
||||
user: User | None,
|
||||
input_prompt_id: int,
|
||||
prompt: str,
|
||||
content: str,
|
||||
active: bool,
|
||||
db_session: Session,
|
||||
) -> InputPrompt:
|
||||
input_prompt = db_session.scalar(
|
||||
select(InputPrompt).where(InputPrompt.id == input_prompt_id)
|
||||
)
|
||||
if input_prompt is None:
|
||||
raise ValueError(f"No input prompt with id {input_prompt_id}")
|
||||
|
||||
if not validate_user_prompt_authorization(user, input_prompt):
|
||||
raise HTTPException(status_code=401, detail="You don't own this prompt")
|
||||
|
||||
input_prompt.prompt = prompt
|
||||
input_prompt.content = content
|
||||
input_prompt.active = active
|
||||
|
||||
db_session.commit()
|
||||
return input_prompt
|
||||
|
||||
|
||||
def validate_user_prompt_authorization(
|
||||
user: User | None, input_prompt: InputPrompt
|
||||
) -> bool:
|
||||
prompt = InputPromptSnapshot.from_model(input_prompt=input_prompt)
|
||||
|
||||
if prompt.user_id is not None:
|
||||
if user is None:
|
||||
return False
|
||||
|
||||
user_details = UserInfo.from_model(user)
|
||||
if str(user_details.id) != str(prompt.user_id):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def remove_public_input_prompt(input_prompt_id: int, db_session: Session) -> None:
|
||||
input_prompt = db_session.scalar(
|
||||
select(InputPrompt).where(InputPrompt.id == input_prompt_id)
|
||||
)
|
||||
|
||||
if input_prompt is None:
|
||||
raise ValueError(f"No input prompt with id {input_prompt_id}")
|
||||
|
||||
if not input_prompt.is_public:
|
||||
raise HTTPException(status_code=400, detail="This prompt is not public")
|
||||
|
||||
db_session.delete(input_prompt)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def remove_input_prompt(
|
||||
user: User | None, input_prompt_id: int, db_session: Session
|
||||
) -> None:
|
||||
input_prompt = db_session.scalar(
|
||||
select(InputPrompt).where(InputPrompt.id == input_prompt_id)
|
||||
)
|
||||
if input_prompt is None:
|
||||
raise ValueError(f"No input prompt with id {input_prompt_id}")
|
||||
|
||||
if input_prompt.is_public:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Cannot delete public prompts with this method"
|
||||
)
|
||||
|
||||
if not validate_user_prompt_authorization(user, input_prompt):
|
||||
raise HTTPException(status_code=401, detail="You do not own this prompt")
|
||||
|
||||
db_session.delete(input_prompt)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def fetch_input_prompt_by_id(
|
||||
id: int, user_id: UUID | None, db_session: Session
|
||||
) -> InputPrompt:
|
||||
query = select(InputPrompt).where(InputPrompt.id == id)
|
||||
|
||||
if user_id:
|
||||
query = query.where(
|
||||
(InputPrompt.user_id == user_id) | (InputPrompt.user_id is None)
|
||||
)
|
||||
else:
|
||||
# If no user_id is provided, only fetch prompts without a user_id (aka public)
|
||||
query = query.where(InputPrompt.user_id == None) # noqa
|
||||
|
||||
result = db_session.scalar(query)
|
||||
|
||||
if result is None:
|
||||
raise HTTPException(422, "No input prompt found")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def fetch_public_input_prompts(
|
||||
db_session: Session,
|
||||
) -> list[InputPrompt]:
|
||||
query = select(InputPrompt).where(InputPrompt.is_public)
|
||||
return list(db_session.scalars(query).all())
|
||||
|
||||
|
||||
def fetch_input_prompts_by_user(
|
||||
db_session: Session,
|
||||
user_id: UUID | None,
|
||||
active: bool | None = None,
|
||||
include_public: bool = False,
|
||||
) -> list[InputPrompt]:
|
||||
query = select(InputPrompt)
|
||||
|
||||
if user_id is not None:
|
||||
if include_public:
|
||||
query = query.where(
|
||||
(InputPrompt.user_id == user_id) | InputPrompt.is_public
|
||||
)
|
||||
else:
|
||||
query = query.where(InputPrompt.user_id == user_id)
|
||||
|
||||
elif include_public:
|
||||
query = query.where(InputPrompt.is_public)
|
||||
|
||||
if active is not None:
|
||||
query = query.where(InputPrompt.active == active)
|
||||
|
||||
return list(db_session.scalars(query).all())
|
||||
@@ -1,99 +0,0 @@
|
||||
from collections.abc import Sequence
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi_users.password import PasswordHelper
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.schemas import UserRole
|
||||
from danswer.db.models import User
|
||||
|
||||
|
||||
def list_users(
|
||||
db_session: Session, email_filter_string: str = "", user: User | None = None
|
||||
) -> Sequence[User]:
|
||||
"""List all users. No pagination as of now, as the # of users
|
||||
is assumed to be relatively small (<< 1 million)"""
|
||||
stmt = select(User)
|
||||
|
||||
if email_filter_string:
|
||||
stmt = stmt.where(User.email.ilike(f"%{email_filter_string}%")) # type: ignore
|
||||
|
||||
return db_session.scalars(stmt).unique().all()
|
||||
|
||||
|
||||
def get_users_by_emails(
|
||||
db_session: Session, emails: list[str]
|
||||
) -> tuple[list[User], list[str]]:
|
||||
# Use distinct to avoid duplicates
|
||||
stmt = select(User).filter(User.email.in_(emails)) # type: ignore
|
||||
found_users = list(db_session.scalars(stmt).unique().all()) # Convert to list
|
||||
found_users_emails = [user.email for user in found_users]
|
||||
missing_user_emails = [email for email in emails if email not in found_users_emails]
|
||||
return found_users, missing_user_emails
|
||||
|
||||
|
||||
def get_user_by_email(email: str, db_session: Session) -> User | None:
|
||||
user = (
|
||||
db_session.query(User)
|
||||
.filter(func.lower(User.email) == func.lower(email))
|
||||
.first()
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
def fetch_user_by_id(db_session: Session, user_id: UUID) -> User | None:
|
||||
user = db_session.query(User).filter(User.id == user_id).first() # type: ignore
|
||||
|
||||
return user
|
||||
|
||||
|
||||
def _generate_non_web_user(email: str) -> User:
|
||||
fastapi_users_pw_helper = PasswordHelper()
|
||||
password = fastapi_users_pw_helper.generate()
|
||||
hashed_pass = fastapi_users_pw_helper.hash(password)
|
||||
return User(
|
||||
email=email,
|
||||
hashed_password=hashed_pass,
|
||||
has_web_login=False,
|
||||
role=UserRole.BASIC,
|
||||
)
|
||||
|
||||
|
||||
def add_non_web_user_if_not_exists(db_session: Session, email: str) -> User:
|
||||
user = get_user_by_email(email, db_session)
|
||||
if user is not None:
|
||||
return user
|
||||
|
||||
user = _generate_non_web_user(email=email)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
return user
|
||||
|
||||
|
||||
def add_non_web_user_if_not_exists__no_commit(db_session: Session, email: str) -> User:
|
||||
user = get_user_by_email(email, db_session)
|
||||
if user is not None:
|
||||
return user
|
||||
|
||||
user = _generate_non_web_user(email=email)
|
||||
db_session.add(user)
|
||||
db_session.flush() # generate id
|
||||
return user
|
||||
|
||||
|
||||
def batch_add_non_web_user_if_not_exists__no_commit(
|
||||
db_session: Session, emails: list[str]
|
||||
) -> list[User]:
|
||||
found_users, missing_user_emails = get_users_by_emails(db_session, emails)
|
||||
|
||||
new_users: list[User] = []
|
||||
for email in missing_user_emails:
|
||||
new_users.append(_generate_non_web_user(email=email))
|
||||
|
||||
db_session.add_all(new_users)
|
||||
db_session.flush() # generate ids
|
||||
|
||||
return found_users + new_users
|
||||
@@ -1,85 +0,0 @@
|
||||
from collections.abc import Callable
|
||||
from io import BytesIO
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
import requests
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.constants import FileOrigin
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.models import ChatMessage
|
||||
from danswer.file_store.file_store import get_default_file_store
|
||||
from danswer.file_store.models import FileDescriptor
|
||||
from danswer.file_store.models import InMemoryChatFile
|
||||
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
|
||||
def load_chat_file(
|
||||
file_descriptor: FileDescriptor, db_session: Session
|
||||
) -> InMemoryChatFile:
|
||||
file_io = get_default_file_store(db_session).read_file(
|
||||
file_descriptor["id"], mode="b"
|
||||
)
|
||||
return InMemoryChatFile(
|
||||
file_id=file_descriptor["id"],
|
||||
content=file_io.read(),
|
||||
file_type=file_descriptor["type"],
|
||||
filename=file_descriptor.get("name"),
|
||||
)
|
||||
|
||||
|
||||
def load_all_chat_files(
|
||||
chat_messages: list[ChatMessage],
|
||||
file_descriptors: list[FileDescriptor],
|
||||
db_session: Session,
|
||||
) -> list[InMemoryChatFile]:
|
||||
file_descriptors_for_history: list[FileDescriptor] = []
|
||||
for chat_message in chat_messages:
|
||||
if chat_message.files:
|
||||
file_descriptors_for_history.extend(chat_message.files)
|
||||
|
||||
files = cast(
|
||||
list[InMemoryChatFile],
|
||||
run_functions_tuples_in_parallel(
|
||||
[
|
||||
(load_chat_file, (file, db_session))
|
||||
for file in file_descriptors + file_descriptors_for_history
|
||||
]
|
||||
),
|
||||
)
|
||||
return files
|
||||
|
||||
|
||||
def save_file_from_url(url: str, tenant_id: str) -> str:
|
||||
"""NOTE: using multiple sessions here, since this is often called
|
||||
using multithreading. In practice, sharing a session has resulted in
|
||||
weird errors."""
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
response = requests.get(url)
|
||||
response.raise_for_status()
|
||||
|
||||
unique_id = str(uuid4())
|
||||
|
||||
file_io = BytesIO(response.content)
|
||||
file_store = get_default_file_store(db_session)
|
||||
file_store.save_file(
|
||||
file_name=unique_id,
|
||||
content=file_io,
|
||||
display_name="GeneratedImage",
|
||||
file_origin=FileOrigin.CHAT_IMAGE_GEN,
|
||||
file_type="image/png;base64",
|
||||
)
|
||||
return unique_id
|
||||
|
||||
|
||||
def save_files_from_urls(urls: list[str]) -> list[str]:
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
|
||||
funcs: list[tuple[Callable[..., Any], tuple[Any, ...]]] = [
|
||||
(save_file_from_url, (url, tenant_id)) for url in urls
|
||||
]
|
||||
# Must pass in tenant_id here, since this is called by multithreading
|
||||
return run_functions_tuples_in_parallel(funcs)
|
||||
@@ -1,41 +0,0 @@
|
||||
import abc
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.db.index_attempt import get_index_attempt
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class Heartbeat(abc.ABC):
|
||||
"""Useful for any long-running work that goes through a bunch of items
|
||||
and needs to occasionally give updates on progress.
|
||||
e.g. chunking, embedding, updating vespa, etc."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def heartbeat(self, metadata: Any = None) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class IndexingHeartbeat(Heartbeat):
|
||||
def __init__(self, index_attempt_id: int, db_session: Session, freq: int):
|
||||
self.cnt = 0
|
||||
|
||||
self.index_attempt_id = index_attempt_id
|
||||
self.db_session = db_session
|
||||
self.freq = freq
|
||||
|
||||
def heartbeat(self, metadata: Any = None) -> None:
|
||||
self.cnt += 1
|
||||
if self.cnt % self.freq == 0:
|
||||
index_attempt = get_index_attempt(
|
||||
db_session=self.db_session, index_attempt_id=self.index_attempt_id
|
||||
)
|
||||
if index_attempt:
|
||||
index_attempt.time_updated = func.now()
|
||||
self.db_session.commit()
|
||||
else:
|
||||
logger.error("Index attempt not found, this should not happen!")
|
||||
@@ -1,84 +0,0 @@
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterator
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
from pydantic.v1 import BaseModel as BaseModel__v1
|
||||
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import DanswerQuotes
|
||||
from danswer.chat.models import StreamStopInfo
|
||||
from danswer.chat.models import StreamStopReason
|
||||
from danswer.file_store.models import InMemoryChatFile
|
||||
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
|
||||
from danswer.tools.force import ForceUseTool
|
||||
from danswer.tools.models import ToolCallFinalResult
|
||||
from danswer.tools.models import ToolCallKickoff
|
||||
from danswer.tools.models import ToolResponse
|
||||
from danswer.tools.tool import Tool
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from danswer.llm.answering.stream_processing.answer_response_handler import (
|
||||
AnswerResponseHandler,
|
||||
)
|
||||
from danswer.llm.answering.tool.tool_response_handler import ToolResponseHandler
|
||||
|
||||
|
||||
ResponsePart = (
|
||||
DanswerAnswerPiece
|
||||
| CitationInfo
|
||||
| DanswerQuotes
|
||||
| ToolCallKickoff
|
||||
| ToolResponse
|
||||
| ToolCallFinalResult
|
||||
| StreamStopInfo
|
||||
)
|
||||
|
||||
|
||||
class LLMCall(BaseModel__v1):
|
||||
prompt_builder: AnswerPromptBuilder
|
||||
tools: list[Tool]
|
||||
force_use_tool: ForceUseTool
|
||||
files: list[InMemoryChatFile]
|
||||
tool_call_info: list[ToolCallKickoff | ToolResponse | ToolCallFinalResult]
|
||||
using_tool_calling_llm: bool
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class LLMResponseHandlerManager:
|
||||
def __init__(
|
||||
self,
|
||||
tool_handler: "ToolResponseHandler",
|
||||
answer_handler: "AnswerResponseHandler",
|
||||
is_cancelled: Callable[[], bool],
|
||||
):
|
||||
self.tool_handler = tool_handler
|
||||
self.answer_handler = answer_handler
|
||||
self.is_cancelled = is_cancelled
|
||||
|
||||
def handle_llm_response(
|
||||
self,
|
||||
stream: Iterator[BaseMessage],
|
||||
) -> Generator[ResponsePart, None, None]:
|
||||
all_messages: list[BaseMessage] = []
|
||||
for message in stream:
|
||||
if self.is_cancelled():
|
||||
yield StreamStopInfo(stop_reason=StreamStopReason.CANCELLED)
|
||||
return
|
||||
# tool handler doesn't do anything until the full message is received
|
||||
# NOTE: still need to run list() to get this to run
|
||||
list(self.tool_handler.handle_response_part(message, all_messages))
|
||||
yield from self.answer_handler.handle_response_part(message, all_messages)
|
||||
all_messages.append(message)
|
||||
|
||||
# potentially give back all info on the selected tool call + its result
|
||||
yield from self.tool_handler.handle_response_part(None, all_messages)
|
||||
yield from self.answer_handler.handle_response_part(None, all_messages)
|
||||
|
||||
def next_llm_call(self, llm_call: LLMCall) -> LLMCall | None:
|
||||
return self.tool_handler.next_llm_call(llm_call)
|
||||
@@ -1,163 +0,0 @@
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from langchain.schema.messages import AIMessage
|
||||
from langchain.schema.messages import BaseMessage
|
||||
from langchain.schema.messages import HumanMessage
|
||||
from langchain.schema.messages import SystemMessage
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from pydantic import Field
|
||||
from pydantic import model_validator
|
||||
|
||||
from danswer.chat.models import AnswerQuestionStreamReturn
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.file_store.models import InMemoryChatFile
|
||||
from danswer.llm.override_models import PromptOverride
|
||||
from danswer.llm.utils import build_content_with_imgs
|
||||
from danswer.tools.models import ToolCallFinalResult
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from danswer.db.models import ChatMessage
|
||||
from danswer.db.models import Prompt
|
||||
|
||||
|
||||
StreamProcessor = Callable[[Iterator[str]], AnswerQuestionStreamReturn]
|
||||
|
||||
|
||||
class PreviousMessage(BaseModel):
|
||||
"""Simplified version of `ChatMessage`"""
|
||||
|
||||
message: str
|
||||
token_count: int
|
||||
message_type: MessageType
|
||||
files: list[InMemoryChatFile]
|
||||
tool_call: ToolCallFinalResult | None
|
||||
|
||||
@classmethod
|
||||
def from_chat_message(
|
||||
cls, chat_message: "ChatMessage", available_files: list[InMemoryChatFile]
|
||||
) -> "PreviousMessage":
|
||||
message_file_ids = (
|
||||
[file["id"] for file in chat_message.files] if chat_message.files else []
|
||||
)
|
||||
return cls(
|
||||
message=chat_message.message,
|
||||
token_count=chat_message.token_count,
|
||||
message_type=chat_message.message_type,
|
||||
files=[
|
||||
file
|
||||
for file in available_files
|
||||
if str(file.file_id) in message_file_ids
|
||||
],
|
||||
tool_call=ToolCallFinalResult(
|
||||
tool_name=chat_message.tool_call.tool_name,
|
||||
tool_args=chat_message.tool_call.tool_arguments,
|
||||
tool_result=chat_message.tool_call.tool_result,
|
||||
)
|
||||
if chat_message.tool_call
|
||||
else None,
|
||||
)
|
||||
|
||||
def to_langchain_msg(self) -> BaseMessage:
|
||||
content = build_content_with_imgs(self.message, self.files)
|
||||
if self.message_type == MessageType.USER:
|
||||
return HumanMessage(content=content)
|
||||
elif self.message_type == MessageType.ASSISTANT:
|
||||
return AIMessage(content=content)
|
||||
else:
|
||||
return SystemMessage(content=content)
|
||||
|
||||
|
||||
class DocumentPruningConfig(BaseModel):
|
||||
max_chunks: int | None = None
|
||||
max_window_percentage: float | None = None
|
||||
max_tokens: int | None = None
|
||||
# different pruning behavior is expected when the
|
||||
# user manually selects documents they want to chat with
|
||||
# e.g. we don't want to truncate each document to be no more
|
||||
# than one chunk long
|
||||
is_manually_selected_docs: bool = False
|
||||
# If user specifies to include additional context Chunks for each match, then different pruning
|
||||
# is used. As many Sections as possible are included, and the last Section is truncated
|
||||
# If this is false, all of the Sections are truncated if they are longer than the expected Chunk size.
|
||||
# Sections are often expected to be longer than the maximum Chunk size but Chunks should not be.
|
||||
use_sections: bool = True
|
||||
# If using tools, then we need to consider the tool length
|
||||
tool_num_tokens: int = 0
|
||||
# If using a tool message to represent the docs, then we have to JSON serialize
|
||||
# the document content, which adds to the token count.
|
||||
using_tool_message: bool = False
|
||||
|
||||
|
||||
class ContextualPruningConfig(DocumentPruningConfig):
|
||||
num_chunk_multiple: int
|
||||
|
||||
@classmethod
|
||||
def from_doc_pruning_config(
|
||||
cls, num_chunk_multiple: int, doc_pruning_config: DocumentPruningConfig
|
||||
) -> "ContextualPruningConfig":
|
||||
return cls(num_chunk_multiple=num_chunk_multiple, **doc_pruning_config.dict())
|
||||
|
||||
|
||||
class CitationConfig(BaseModel):
|
||||
all_docs_useful: bool = False
|
||||
|
||||
|
||||
class QuotesConfig(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
class AnswerStyleConfig(BaseModel):
|
||||
citation_config: CitationConfig | None = None
|
||||
quotes_config: QuotesConfig | None = None
|
||||
document_pruning_config: DocumentPruningConfig = Field(
|
||||
default_factory=DocumentPruningConfig
|
||||
)
|
||||
# forces the LLM to return a structured response, see
|
||||
# https://platform.openai.com/docs/guides/structured-outputs/introduction
|
||||
# right now, only used by the simple chat API
|
||||
structured_response_format: dict | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_quotes_and_citation(self) -> "AnswerStyleConfig":
|
||||
if self.citation_config is None and self.quotes_config is None:
|
||||
raise ValueError(
|
||||
"One of `citation_config` or `quotes_config` must be provided"
|
||||
)
|
||||
|
||||
if self.citation_config is not None and self.quotes_config is not None:
|
||||
raise ValueError(
|
||||
"Only one of `citation_config` or `quotes_config` must be provided"
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
|
||||
class PromptConfig(BaseModel):
|
||||
"""Final representation of the Prompt configuration passed
|
||||
into the `Answer` object."""
|
||||
|
||||
system_prompt: str
|
||||
task_prompt: str
|
||||
datetime_aware: bool
|
||||
include_citations: bool
|
||||
|
||||
@classmethod
|
||||
def from_model(
|
||||
cls, model: "Prompt", prompt_override: PromptOverride | None = None
|
||||
) -> "PromptConfig":
|
||||
override_system_prompt = (
|
||||
prompt_override.system_prompt if prompt_override else None
|
||||
)
|
||||
override_task_prompt = prompt_override.task_prompt if prompt_override else None
|
||||
|
||||
return cls(
|
||||
system_prompt=override_system_prompt or model.system_prompt,
|
||||
task_prompt=override_task_prompt or model.task_prompt,
|
||||
datetime_aware=model.datetime_aware,
|
||||
include_citations=model.include_citations,
|
||||
)
|
||||
|
||||
model_config = ConfigDict(frozen=True)
|
||||
@@ -1,97 +0,0 @@
|
||||
from langchain.schema.messages import HumanMessage
|
||||
|
||||
from danswer.chat.models import LlmDoc
|
||||
from danswer.configs.chat_configs import LANGUAGE_HINT
|
||||
from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE
|
||||
from danswer.db.search_settings import get_multilingual_expansion
|
||||
from danswer.llm.answering.models import PromptConfig
|
||||
from danswer.llm.utils import message_to_prompt_and_imgs
|
||||
from danswer.prompts.direct_qa_prompts import CONTEXT_BLOCK
|
||||
from danswer.prompts.direct_qa_prompts import HISTORY_BLOCK
|
||||
from danswer.prompts.direct_qa_prompts import JSON_PROMPT
|
||||
from danswer.prompts.direct_qa_prompts import WEAK_LLM_PROMPT
|
||||
from danswer.prompts.prompt_utils import add_date_time_to_prompt
|
||||
from danswer.prompts.prompt_utils import build_complete_context_str
|
||||
from danswer.search.models import InferenceChunk
|
||||
|
||||
|
||||
def _build_weak_llm_quotes_prompt(
|
||||
question: str,
|
||||
context_docs: list[LlmDoc] | list[InferenceChunk],
|
||||
history_str: str,
|
||||
prompt: PromptConfig,
|
||||
) -> HumanMessage:
|
||||
"""Since Danswer supports a variety of LLMs, this less demanding prompt is provided
|
||||
as an option to use with weaker LLMs such as small version, low float precision, quantized,
|
||||
or distilled models. It only uses one context document and has very weak requirements of
|
||||
output format.
|
||||
"""
|
||||
context_block = ""
|
||||
if context_docs:
|
||||
context_block = CONTEXT_BLOCK.format(context_docs_str=context_docs[0].content)
|
||||
|
||||
prompt_str = WEAK_LLM_PROMPT.format(
|
||||
system_prompt=prompt.system_prompt,
|
||||
context_block=context_block,
|
||||
task_prompt=prompt.task_prompt,
|
||||
user_query=question,
|
||||
)
|
||||
|
||||
if prompt.datetime_aware:
|
||||
prompt_str = add_date_time_to_prompt(prompt_str=prompt_str)
|
||||
|
||||
return HumanMessage(content=prompt_str)
|
||||
|
||||
|
||||
def _build_strong_llm_quotes_prompt(
|
||||
question: str,
|
||||
context_docs: list[LlmDoc] | list[InferenceChunk],
|
||||
history_str: str,
|
||||
prompt: PromptConfig,
|
||||
) -> HumanMessage:
|
||||
use_language_hint = bool(get_multilingual_expansion())
|
||||
|
||||
context_block = ""
|
||||
if context_docs:
|
||||
context_docs_str = build_complete_context_str(context_docs)
|
||||
context_block = CONTEXT_BLOCK.format(context_docs_str=context_docs_str)
|
||||
|
||||
history_block = ""
|
||||
if history_str:
|
||||
history_block = HISTORY_BLOCK.format(history_str=history_str)
|
||||
|
||||
full_prompt = JSON_PROMPT.format(
|
||||
system_prompt=prompt.system_prompt,
|
||||
context_block=context_block,
|
||||
history_block=history_block,
|
||||
task_prompt=prompt.task_prompt,
|
||||
user_query=question,
|
||||
language_hint_or_none=LANGUAGE_HINT.strip() if use_language_hint else "",
|
||||
).strip()
|
||||
|
||||
if prompt.datetime_aware:
|
||||
full_prompt = add_date_time_to_prompt(prompt_str=full_prompt)
|
||||
|
||||
return HumanMessage(content=full_prompt)
|
||||
|
||||
|
||||
def build_quotes_user_message(
|
||||
message: HumanMessage,
|
||||
context_docs: list[LlmDoc] | list[InferenceChunk],
|
||||
history_str: str,
|
||||
prompt: PromptConfig,
|
||||
) -> HumanMessage:
|
||||
prompt_builder = (
|
||||
_build_weak_llm_quotes_prompt
|
||||
if QA_PROMPT_OVERRIDE == "weak"
|
||||
else _build_strong_llm_quotes_prompt
|
||||
)
|
||||
|
||||
query, _ = message_to_prompt_and_imgs(message)
|
||||
|
||||
return prompt_builder(
|
||||
question=query,
|
||||
context_docs=context_docs,
|
||||
history_str=history_str,
|
||||
prompt=prompt,
|
||||
)
|
||||
@@ -1,20 +0,0 @@
|
||||
from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT
|
||||
from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT_WITHOUT_CONTEXT
|
||||
|
||||
|
||||
def build_dummy_prompt(
|
||||
system_prompt: str, task_prompt: str, retrieval_disabled: bool
|
||||
) -> str:
|
||||
if retrieval_disabled:
|
||||
return PARAMATERIZED_PROMPT_WITHOUT_CONTEXT.format(
|
||||
user_query="<USER_QUERY>",
|
||||
system_prompt=system_prompt,
|
||||
task_prompt=task_prompt,
|
||||
).strip()
|
||||
|
||||
return PARAMATERIZED_PROMPT.format(
|
||||
context_docs_str="<CONTEXT_DOCS>",
|
||||
user_query="<USER_QUERY>",
|
||||
system_prompt=system_prompt,
|
||||
task_prompt=task_prompt,
|
||||
).strip()
|
||||
@@ -1,150 +0,0 @@
|
||||
import os
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from copy import copy
|
||||
|
||||
from transformers import logging as transformer_logging # type:ignore
|
||||
|
||||
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
|
||||
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
|
||||
from danswer.search.models import InferenceChunk
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
|
||||
logger = setup_logger()
|
||||
transformer_logging.set_verbosity_error()
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
|
||||
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
|
||||
|
||||
|
||||
class BaseTokenizer(ABC):
|
||||
@abstractmethod
|
||||
def encode(self, string: str) -> list[int]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def tokenize(self, string: str) -> list[str]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def decode(self, tokens: list[int]) -> str:
|
||||
pass
|
||||
|
||||
|
||||
class TiktokenTokenizer(BaseTokenizer):
|
||||
_instances: dict[str, "TiktokenTokenizer"] = {}
|
||||
|
||||
def __new__(cls, encoding_name: str = "cl100k_base") -> "TiktokenTokenizer":
|
||||
if encoding_name not in cls._instances:
|
||||
cls._instances[encoding_name] = super(TiktokenTokenizer, cls).__new__(cls)
|
||||
return cls._instances[encoding_name]
|
||||
|
||||
def __init__(self, encoding_name: str = "cl100k_base"):
|
||||
if not hasattr(self, "encoder"):
|
||||
import tiktoken
|
||||
|
||||
self.encoder = tiktoken.get_encoding(encoding_name)
|
||||
|
||||
def encode(self, string: str) -> list[int]:
|
||||
# this returns no special tokens
|
||||
return self.encoder.encode_ordinary(string)
|
||||
|
||||
def tokenize(self, string: str) -> list[str]:
|
||||
return [self.encoder.decode([token]) for token in self.encode(string)]
|
||||
|
||||
def decode(self, tokens: list[int]) -> str:
|
||||
return self.encoder.decode(tokens)
|
||||
|
||||
|
||||
class HuggingFaceTokenizer(BaseTokenizer):
|
||||
def __init__(self, model_name: str):
|
||||
from tokenizers import Tokenizer # type: ignore
|
||||
|
||||
self.encoder = Tokenizer.from_pretrained(model_name)
|
||||
|
||||
def encode(self, string: str) -> list[int]:
|
||||
# this returns no special tokens
|
||||
return self.encoder.encode(string, add_special_tokens=False).ids
|
||||
|
||||
def tokenize(self, string: str) -> list[str]:
|
||||
return self.encoder.encode(string, add_special_tokens=False).tokens
|
||||
|
||||
def decode(self, tokens: list[int]) -> str:
|
||||
return self.encoder.decode(tokens)
|
||||
|
||||
|
||||
_TOKENIZER_CACHE: dict[str, BaseTokenizer] = {}
|
||||
|
||||
|
||||
def _check_tokenizer_cache(tokenizer_name: str) -> BaseTokenizer:
|
||||
global _TOKENIZER_CACHE
|
||||
|
||||
if tokenizer_name not in _TOKENIZER_CACHE:
|
||||
if tokenizer_name == "openai":
|
||||
_TOKENIZER_CACHE[tokenizer_name] = TiktokenTokenizer("cl100k_base")
|
||||
return _TOKENIZER_CACHE[tokenizer_name]
|
||||
try:
|
||||
logger.debug(f"Initializing HuggingFaceTokenizer for: {tokenizer_name}")
|
||||
_TOKENIZER_CACHE[tokenizer_name] = HuggingFaceTokenizer(tokenizer_name)
|
||||
except Exception as primary_error:
|
||||
logger.error(
|
||||
f"Error initializing HuggingFaceTokenizer for {tokenizer_name}: {primary_error}"
|
||||
)
|
||||
logger.warning(
|
||||
f"Falling back to default embedding model: {DOCUMENT_ENCODER_MODEL}"
|
||||
)
|
||||
|
||||
try:
|
||||
# Cache this tokenizer name to the default so we don't have to try to load it again
|
||||
# and fail again
|
||||
_TOKENIZER_CACHE[tokenizer_name] = HuggingFaceTokenizer(
|
||||
DOCUMENT_ENCODER_MODEL
|
||||
)
|
||||
except Exception as fallback_error:
|
||||
logger.error(
|
||||
f"Error initializing fallback HuggingFaceTokenizer: {fallback_error}"
|
||||
)
|
||||
raise ValueError(
|
||||
f"Failed to initialize tokenizer for {tokenizer_name} and fallback model"
|
||||
) from fallback_error
|
||||
|
||||
return _TOKENIZER_CACHE[tokenizer_name]
|
||||
|
||||
|
||||
_DEFAULT_TOKENIZER: BaseTokenizer = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL)
|
||||
|
||||
|
||||
def get_tokenizer(
|
||||
model_name: str | None, provider_type: EmbeddingProvider | str | None
|
||||
) -> BaseTokenizer:
|
||||
# Currently all of the viable models use the same sentencepiece tokenizer
|
||||
# OpenAI uses a different one but currently it's not supported due to quality issues
|
||||
# the inconsistent chunking makes using the sentencepiece tokenizer default better for now
|
||||
# LLM tokenizers are specified by strings
|
||||
global _DEFAULT_TOKENIZER
|
||||
return _DEFAULT_TOKENIZER
|
||||
|
||||
|
||||
def tokenizer_trim_content(
|
||||
content: str, desired_length: int, tokenizer: BaseTokenizer
|
||||
) -> str:
|
||||
tokens = tokenizer.encode(content)
|
||||
if len(tokens) > desired_length:
|
||||
content = tokenizer.decode(tokens[:desired_length])
|
||||
return content
|
||||
|
||||
|
||||
def tokenizer_trim_chunks(
|
||||
chunks: list[InferenceChunk],
|
||||
tokenizer: BaseTokenizer,
|
||||
max_chunk_toks: int = DOC_EMBEDDING_CONTEXT_SIZE,
|
||||
) -> list[InferenceChunk]:
|
||||
new_chunks = copy(chunks)
|
||||
for ind, chunk in enumerate(new_chunks):
|
||||
new_content = tokenizer_trim_content(chunk.content, max_chunk_toks, tokenizer)
|
||||
if len(new_content) != len(chunk.content):
|
||||
new_chunk = copy(chunk)
|
||||
new_chunk.content = new_content
|
||||
new_chunks[ind] = new_chunk
|
||||
return new_chunks
|
||||
@@ -1,436 +0,0 @@
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.chat.chat_utils import reorganize_citations
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import DanswerContexts
|
||||
from danswer.chat.models import DanswerQuotes
|
||||
from danswer.chat.models import DocumentRelevance
|
||||
from danswer.chat.models import LLMRelevanceFilterResponse
|
||||
from danswer.chat.models import QADocsResponse
|
||||
from danswer.chat.models import RelevanceAnalysis
|
||||
from danswer.chat.models import StreamingError
|
||||
from danswer.configs.chat_configs import DISABLE_LLM_DOC_RELEVANCE
|
||||
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
|
||||
from danswer.configs.chat_configs import QA_TIMEOUT
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.db.chat import create_chat_session
|
||||
from danswer.db.chat import create_db_search_doc
|
||||
from danswer.db.chat import create_new_chat_message
|
||||
from danswer.db.chat import get_or_create_root_message
|
||||
from danswer.db.chat import translate_db_message_to_chat_message_detail
|
||||
from danswer.db.chat import translate_db_search_doc_to_server_search_doc
|
||||
from danswer.db.chat import update_search_docs_table_with_relevance
|
||||
from danswer.db.engine import get_session_context_manager
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.models import User
|
||||
from danswer.db.persona import get_prompt_by_id
|
||||
from danswer.llm.answering.answer import Answer
|
||||
from danswer.llm.answering.models import AnswerStyleConfig
|
||||
from danswer.llm.answering.models import CitationConfig
|
||||
from danswer.llm.answering.models import DocumentPruningConfig
|
||||
from danswer.llm.answering.models import PromptConfig
|
||||
from danswer.llm.answering.models import QuotesConfig
|
||||
from danswer.llm.factory import get_llms_for_persona
|
||||
from danswer.llm.factory import get_main_llm_from_tuple
|
||||
from danswer.natural_language_processing.utils import get_tokenizer
|
||||
from danswer.one_shot_answer.models import DirectQARequest
|
||||
from danswer.one_shot_answer.models import OneShotQAResponse
|
||||
from danswer.one_shot_answer.models import QueryRephrase
|
||||
from danswer.one_shot_answer.qa_utils import combine_message_thread
|
||||
from danswer.search.enums import LLMEvaluationType
|
||||
from danswer.search.models import RerankMetricsContainer
|
||||
from danswer.search.models import RetrievalMetricsContainer
|
||||
from danswer.search.utils import chunks_or_sections_to_search_docs
|
||||
from danswer.search.utils import dedupe_documents
|
||||
from danswer.secondary_llm_flows.answer_validation import get_answer_validity
|
||||
from danswer.secondary_llm_flows.query_expansion import thread_based_query_rephrase
|
||||
from danswer.server.query_and_chat.models import ChatMessageDetail
|
||||
from danswer.server.utils import get_json_line
|
||||
from danswer.tools.force import ForceUseTool
|
||||
from danswer.tools.models import ToolResponse
|
||||
from danswer.tools.tool_implementations.search.search_tool import SEARCH_DOC_CONTENT_ID
|
||||
from danswer.tools.tool_implementations.search.search_tool import (
|
||||
SEARCH_RESPONSE_SUMMARY_ID,
|
||||
)
|
||||
from danswer.tools.tool_implementations.search.search_tool import SearchResponseSummary
|
||||
from danswer.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from danswer.tools.tool_implementations.search.search_tool import (
|
||||
SECTION_RELEVANCE_LIST_ID,
|
||||
)
|
||||
from danswer.tools.tool_runner import ToolCallKickoff
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.timing import log_generator_function_time
|
||||
from ee.danswer.server.query_and_chat.utils import create_temporary_persona
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
AnswerObjectIterator = Iterator[
|
||||
QueryRephrase
|
||||
| QADocsResponse
|
||||
| LLMRelevanceFilterResponse
|
||||
| DanswerAnswerPiece
|
||||
| DanswerQuotes
|
||||
| DanswerContexts
|
||||
| StreamingError
|
||||
| ChatMessageDetail
|
||||
| CitationInfo
|
||||
| ToolCallKickoff
|
||||
| DocumentRelevance
|
||||
]
|
||||
|
||||
|
||||
def stream_answer_objects(
|
||||
query_req: DirectQARequest,
|
||||
user: User | None,
|
||||
# These need to be passed in because in Web UI one shot flow,
|
||||
# we can have much more document as there is no history.
|
||||
# For Slack flow, we need to save more tokens for the thread context
|
||||
max_document_tokens: int | None,
|
||||
max_history_tokens: int | None,
|
||||
db_session: Session,
|
||||
# Needed to translate persona num_chunks to tokens to the LLM
|
||||
default_num_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
|
||||
timeout: int = QA_TIMEOUT,
|
||||
bypass_acl: bool = False,
|
||||
use_citations: bool = False,
|
||||
danswerbot_flow: bool = False,
|
||||
retrieval_metrics_callback: (
|
||||
Callable[[RetrievalMetricsContainer], None] | None
|
||||
) = None,
|
||||
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
|
||||
) -> AnswerObjectIterator:
|
||||
"""Streams in order:
|
||||
1. [always] Retrieved documents, stops flow if nothing is found
|
||||
2. [conditional] LLM selected chunk indices if LLM chunk filtering is turned on
|
||||
3. [always] A set of streamed DanswerAnswerPiece and DanswerQuotes at the end
|
||||
or an error anywhere along the line if something fails
|
||||
4. [always] Details on the final AI response message that is created
|
||||
"""
|
||||
user_id = user.id if user is not None else None
|
||||
query_msg = query_req.messages[-1]
|
||||
history = query_req.messages[:-1]
|
||||
|
||||
chat_session = create_chat_session(
|
||||
db_session=db_session,
|
||||
description="", # One shot queries don't need naming as it's never displayed
|
||||
user_id=user_id,
|
||||
persona_id=query_req.persona_id,
|
||||
one_shot=True,
|
||||
danswerbot_flow=danswerbot_flow,
|
||||
)
|
||||
|
||||
temporary_persona: Persona | None = None
|
||||
if query_req.persona_config is not None:
|
||||
new_persona = create_temporary_persona(
|
||||
db_session=db_session, persona_config=query_req.persona_config, user=user
|
||||
)
|
||||
temporary_persona = new_persona
|
||||
|
||||
persona = temporary_persona if temporary_persona else chat_session.persona
|
||||
|
||||
try:
|
||||
llm, fast_llm = get_llms_for_persona(persona=persona)
|
||||
except ValueError as e:
|
||||
logger.error(
|
||||
f"Failed to initialize LLMs for persona '{persona.name}': {str(e)}"
|
||||
)
|
||||
if "No LLM provider" in str(e):
|
||||
raise ValueError(
|
||||
"Please configure a Generative AI model to use this feature."
|
||||
) from e
|
||||
raise ValueError(
|
||||
"Failed to initialize the AI model. Please check your configuration and try again."
|
||||
) from e
|
||||
|
||||
llm_tokenizer = get_tokenizer(
|
||||
model_name=llm.config.model_name,
|
||||
provider_type=llm.config.model_provider,
|
||||
)
|
||||
|
||||
# Create a chat session which will just store the root message, the query, and the AI response
|
||||
root_message = get_or_create_root_message(
|
||||
chat_session_id=chat_session.id, db_session=db_session
|
||||
)
|
||||
|
||||
history_str = combine_message_thread(
|
||||
messages=history,
|
||||
max_tokens=max_history_tokens,
|
||||
llm_tokenizer=llm_tokenizer,
|
||||
)
|
||||
|
||||
rephrased_query = query_req.query_override or thread_based_query_rephrase(
|
||||
user_query=query_msg.message,
|
||||
history_str=history_str,
|
||||
)
|
||||
|
||||
# Given back ahead of the documents for latency reasons
|
||||
# In chat flow it's given back along with the documents
|
||||
yield QueryRephrase(rephrased_query=rephrased_query)
|
||||
|
||||
prompt = None
|
||||
if query_req.prompt_id is not None:
|
||||
# NOTE: let the user access any prompt as long as the Persona is shared
|
||||
# with them
|
||||
prompt = get_prompt_by_id(
|
||||
prompt_id=query_req.prompt_id, user=None, db_session=db_session
|
||||
)
|
||||
if prompt is None:
|
||||
if not persona.prompts:
|
||||
raise RuntimeError(
|
||||
"Persona does not have any prompts - this should never happen"
|
||||
)
|
||||
prompt = persona.prompts[0]
|
||||
|
||||
# Create the first User query message
|
||||
new_user_message = create_new_chat_message(
|
||||
chat_session_id=chat_session.id,
|
||||
parent_message=root_message,
|
||||
prompt_id=query_req.prompt_id,
|
||||
message=query_msg.message,
|
||||
token_count=len(llm_tokenizer.encode(query_msg.message)),
|
||||
message_type=MessageType.USER,
|
||||
db_session=db_session,
|
||||
commit=True,
|
||||
)
|
||||
|
||||
prompt_config = PromptConfig.from_model(prompt)
|
||||
document_pruning_config = DocumentPruningConfig(
|
||||
max_chunks=int(
|
||||
persona.num_chunks if persona.num_chunks is not None else default_num_chunks
|
||||
),
|
||||
max_tokens=max_document_tokens,
|
||||
)
|
||||
|
||||
answer_config = AnswerStyleConfig(
|
||||
citation_config=CitationConfig() if use_citations else None,
|
||||
quotes_config=QuotesConfig() if not use_citations else None,
|
||||
document_pruning_config=document_pruning_config,
|
||||
)
|
||||
|
||||
search_tool = SearchTool(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
evaluation_type=(
|
||||
LLMEvaluationType.SKIP
|
||||
if DISABLE_LLM_DOC_RELEVANCE
|
||||
else query_req.evaluation_type
|
||||
),
|
||||
persona=persona,
|
||||
retrieval_options=query_req.retrieval_options,
|
||||
prompt_config=prompt_config,
|
||||
llm=llm,
|
||||
fast_llm=fast_llm,
|
||||
pruning_config=document_pruning_config,
|
||||
answer_style_config=answer_config,
|
||||
bypass_acl=bypass_acl,
|
||||
chunks_above=query_req.chunks_above,
|
||||
chunks_below=query_req.chunks_below,
|
||||
full_doc=query_req.full_doc,
|
||||
)
|
||||
|
||||
answer = Answer(
|
||||
question=query_msg.message,
|
||||
answer_style_config=answer_config,
|
||||
prompt_config=PromptConfig.from_model(prompt),
|
||||
llm=get_main_llm_from_tuple(get_llms_for_persona(persona=persona)),
|
||||
single_message_history=history_str,
|
||||
tools=[search_tool] if search_tool else [],
|
||||
force_use_tool=(
|
||||
ForceUseTool(
|
||||
tool_name=search_tool.name,
|
||||
args={"query": rephrased_query},
|
||||
force_use=True,
|
||||
)
|
||||
),
|
||||
# for now, don't use tool calling for this flow, as we haven't
|
||||
# tested quotes with tool calling too much yet
|
||||
skip_explicit_tool_calling=True,
|
||||
return_contexts=query_req.return_contexts,
|
||||
skip_gen_ai_answer_generation=query_req.skip_gen_ai_answer_generation,
|
||||
)
|
||||
# won't be any FileChatDisplay responses since that tool is never passed in
|
||||
for packet in cast(AnswerObjectIterator, answer.processed_streamed_output):
|
||||
# for one-shot flow, don't currently do anything with these
|
||||
if isinstance(packet, ToolResponse):
|
||||
# (likely fine that it comes after the initial creation of the search docs)
|
||||
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
|
||||
search_response_summary = cast(SearchResponseSummary, packet.response)
|
||||
|
||||
top_docs = chunks_or_sections_to_search_docs(
|
||||
search_response_summary.top_sections
|
||||
)
|
||||
|
||||
# Deduping happens at the last step to avoid harming quality by dropping content early on
|
||||
deduped_docs = top_docs
|
||||
if query_req.retrieval_options.dedupe_docs:
|
||||
deduped_docs, dropped_inds = dedupe_documents(top_docs)
|
||||
|
||||
reference_db_search_docs = [
|
||||
create_db_search_doc(server_search_doc=doc, db_session=db_session)
|
||||
for doc in deduped_docs
|
||||
]
|
||||
|
||||
response_docs = [
|
||||
translate_db_search_doc_to_server_search_doc(db_search_doc)
|
||||
for db_search_doc in reference_db_search_docs
|
||||
]
|
||||
|
||||
initial_response = QADocsResponse(
|
||||
rephrased_query=rephrased_query,
|
||||
top_documents=response_docs,
|
||||
predicted_flow=search_response_summary.predicted_flow,
|
||||
predicted_search=search_response_summary.predicted_search,
|
||||
applied_source_filters=search_response_summary.final_filters.source_type,
|
||||
applied_time_cutoff=search_response_summary.final_filters.time_cutoff,
|
||||
recency_bias_multiplier=search_response_summary.recency_bias_multiplier,
|
||||
)
|
||||
|
||||
yield initial_response
|
||||
|
||||
elif packet.id == SEARCH_DOC_CONTENT_ID:
|
||||
yield packet.response
|
||||
|
||||
elif packet.id == SECTION_RELEVANCE_LIST_ID:
|
||||
document_based_response = {}
|
||||
|
||||
if packet.response is not None:
|
||||
for evaluation in packet.response:
|
||||
document_based_response[
|
||||
evaluation.document_id
|
||||
] = RelevanceAnalysis(
|
||||
relevant=evaluation.relevant, content=evaluation.content
|
||||
)
|
||||
|
||||
evaluation_response = DocumentRelevance(
|
||||
relevance_summaries=document_based_response
|
||||
)
|
||||
if reference_db_search_docs is not None:
|
||||
update_search_docs_table_with_relevance(
|
||||
db_session=db_session,
|
||||
reference_db_search_docs=reference_db_search_docs,
|
||||
relevance_summary=evaluation_response,
|
||||
)
|
||||
yield evaluation_response
|
||||
|
||||
else:
|
||||
yield packet
|
||||
|
||||
# Saving Gen AI answer and responding with message info
|
||||
gen_ai_response_message = create_new_chat_message(
|
||||
chat_session_id=chat_session.id,
|
||||
parent_message=new_user_message,
|
||||
prompt_id=query_req.prompt_id,
|
||||
message=answer.llm_answer,
|
||||
token_count=len(llm_tokenizer.encode(answer.llm_answer)),
|
||||
message_type=MessageType.ASSISTANT,
|
||||
error=None,
|
||||
reference_docs=reference_db_search_docs,
|
||||
db_session=db_session,
|
||||
commit=True,
|
||||
)
|
||||
|
||||
msg_detail_response = translate_db_message_to_chat_message_detail(
|
||||
gen_ai_response_message
|
||||
)
|
||||
yield msg_detail_response
|
||||
|
||||
|
||||
@log_generator_function_time()
|
||||
def stream_search_answer(
|
||||
query_req: DirectQARequest,
|
||||
user: User | None,
|
||||
max_document_tokens: int | None,
|
||||
max_history_tokens: int | None,
|
||||
) -> Iterator[str]:
|
||||
with get_session_context_manager() as session:
|
||||
objects = stream_answer_objects(
|
||||
query_req=query_req,
|
||||
user=user,
|
||||
max_document_tokens=max_document_tokens,
|
||||
max_history_tokens=max_history_tokens,
|
||||
db_session=session,
|
||||
)
|
||||
for obj in objects:
|
||||
yield get_json_line(obj.model_dump())
|
||||
|
||||
|
||||
def get_search_answer(
|
||||
query_req: DirectQARequest,
|
||||
user: User | None,
|
||||
max_document_tokens: int | None,
|
||||
max_history_tokens: int | None,
|
||||
db_session: Session,
|
||||
answer_generation_timeout: int = QA_TIMEOUT,
|
||||
enable_reflexion: bool = False,
|
||||
bypass_acl: bool = False,
|
||||
use_citations: bool = False,
|
||||
danswerbot_flow: bool = False,
|
||||
retrieval_metrics_callback: (
|
||||
Callable[[RetrievalMetricsContainer], None] | None
|
||||
) = None,
|
||||
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
|
||||
) -> OneShotQAResponse:
|
||||
"""Collects the streamed one shot answer responses into a single object"""
|
||||
qa_response = OneShotQAResponse()
|
||||
|
||||
results = stream_answer_objects(
|
||||
query_req=query_req,
|
||||
user=user,
|
||||
max_document_tokens=max_document_tokens,
|
||||
max_history_tokens=max_history_tokens,
|
||||
db_session=db_session,
|
||||
bypass_acl=bypass_acl,
|
||||
use_citations=use_citations,
|
||||
danswerbot_flow=danswerbot_flow,
|
||||
timeout=answer_generation_timeout,
|
||||
retrieval_metrics_callback=retrieval_metrics_callback,
|
||||
rerank_metrics_callback=rerank_metrics_callback,
|
||||
)
|
||||
|
||||
answer = ""
|
||||
for packet in results:
|
||||
if isinstance(packet, QueryRephrase):
|
||||
qa_response.rephrase = packet.rephrased_query
|
||||
if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece:
|
||||
answer += packet.answer_piece
|
||||
elif isinstance(packet, QADocsResponse):
|
||||
qa_response.docs = packet
|
||||
elif isinstance(packet, LLMRelevanceFilterResponse):
|
||||
qa_response.llm_selected_doc_indices = packet.llm_selected_doc_indices
|
||||
elif isinstance(packet, DanswerQuotes):
|
||||
qa_response.quotes = packet
|
||||
elif isinstance(packet, CitationInfo):
|
||||
if qa_response.citations:
|
||||
qa_response.citations.append(packet)
|
||||
else:
|
||||
qa_response.citations = [packet]
|
||||
elif isinstance(packet, DanswerContexts):
|
||||
qa_response.contexts = packet
|
||||
elif isinstance(packet, StreamingError):
|
||||
qa_response.error_msg = packet.error
|
||||
elif isinstance(packet, ChatMessageDetail):
|
||||
qa_response.chat_message_id = packet.message_id
|
||||
|
||||
if answer:
|
||||
qa_response.answer = answer
|
||||
|
||||
if enable_reflexion:
|
||||
# Because follow up messages are explicitly tagged, we don't need to verify the answer
|
||||
if len(query_req.messages) == 1:
|
||||
first_query = query_req.messages[0].message
|
||||
qa_response.answer_valid = get_answer_validity(first_query, answer)
|
||||
else:
|
||||
qa_response.answer_valid = True
|
||||
|
||||
if use_citations and qa_response.answer and qa_response.citations:
|
||||
# Reorganize citation nums to be in the same order as the answer
|
||||
qa_response.answer, qa_response.citations = reorganize_citations(
|
||||
qa_response.answer, qa_response.citations
|
||||
)
|
||||
|
||||
return qa_response
|
||||
@@ -1,118 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from pydantic import model_validator
|
||||
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import DanswerContexts
|
||||
from danswer.chat.models import DanswerQuotes
|
||||
from danswer.chat.models import QADocsResponse
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.search.enums import LLMEvaluationType
|
||||
from danswer.search.enums import RecencyBiasSetting
|
||||
from danswer.search.enums import SearchType
|
||||
from danswer.search.models import ChunkContext
|
||||
from danswer.search.models import RerankingDetails
|
||||
from danswer.search.models import RetrievalDetails
|
||||
|
||||
|
||||
class QueryRephrase(BaseModel):
|
||||
rephrased_query: str
|
||||
|
||||
|
||||
class ThreadMessage(BaseModel):
|
||||
message: str
|
||||
sender: str | None = None
|
||||
role: MessageType = MessageType.USER
|
||||
|
||||
|
||||
class PromptConfig(BaseModel):
|
||||
name: str
|
||||
description: str = ""
|
||||
system_prompt: str
|
||||
task_prompt: str = ""
|
||||
include_citations: bool = True
|
||||
datetime_aware: bool = True
|
||||
|
||||
|
||||
class DocumentSetConfig(BaseModel):
|
||||
id: int
|
||||
|
||||
|
||||
class ToolConfig(BaseModel):
|
||||
id: int
|
||||
|
||||
|
||||
class PersonaConfig(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
search_type: SearchType = SearchType.SEMANTIC
|
||||
num_chunks: float | None = None
|
||||
llm_relevance_filter: bool = False
|
||||
llm_filter_extraction: bool = False
|
||||
recency_bias: RecencyBiasSetting = RecencyBiasSetting.AUTO
|
||||
llm_model_provider_override: str | None = None
|
||||
llm_model_version_override: str | None = None
|
||||
|
||||
prompts: list[PromptConfig] = Field(default_factory=list)
|
||||
prompt_ids: list[int] = Field(default_factory=list)
|
||||
|
||||
document_set_ids: list[int] = Field(default_factory=list)
|
||||
tools: list[ToolConfig] = Field(default_factory=list)
|
||||
tool_ids: list[int] = Field(default_factory=list)
|
||||
custom_tools_openapi: list[dict[str, Any]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class DirectQARequest(ChunkContext):
|
||||
persona_config: PersonaConfig | None = None
|
||||
persona_id: int | None = None
|
||||
|
||||
messages: list[ThreadMessage]
|
||||
prompt_id: int | None = None
|
||||
multilingual_query_expansion: list[str] | None = None
|
||||
retrieval_options: RetrievalDetails = Field(default_factory=RetrievalDetails)
|
||||
rerank_settings: RerankingDetails | None = None
|
||||
evaluation_type: LLMEvaluationType = LLMEvaluationType.UNSPECIFIED
|
||||
|
||||
chain_of_thought: bool = False
|
||||
return_contexts: bool = False
|
||||
|
||||
# allows the caller to specify the exact search query they want to use
|
||||
# can be used if the message sent to the LLM / query should not be the same
|
||||
# will also disable Thread-based Rewording if specified
|
||||
query_override: str | None = None
|
||||
|
||||
# If True, skips generative an AI response to the search query
|
||||
skip_gen_ai_answer_generation: bool = False
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_persona_fields(self) -> "DirectQARequest":
|
||||
if (self.persona_config is None) == (self.persona_id is None):
|
||||
raise ValueError("Exactly one of persona_config or persona_id must be set")
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_chain_of_thought_and_prompt_id(self) -> "DirectQARequest":
|
||||
if self.chain_of_thought and self.prompt_id is not None:
|
||||
raise ValueError(
|
||||
"If chain_of_thought is True, prompt_id must be None"
|
||||
"The chain of thought prompt is only for question "
|
||||
"answering and does not accept customizing."
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
|
||||
class OneShotQAResponse(BaseModel):
|
||||
# This is built piece by piece, any of these can be None as the flow could break
|
||||
answer: str | None = None
|
||||
rephrase: str | None = None
|
||||
quotes: DanswerQuotes | None = None
|
||||
citations: list[CitationInfo] | None = None
|
||||
docs: QADocsResponse | None = None
|
||||
llm_selected_doc_indices: list[int] | None = None
|
||||
error_msg: str | None = None
|
||||
answer_valid: bool = True # Reflexion result, default True if Reflexion not run
|
||||
chat_message_id: int | None = None
|
||||
contexts: DanswerContexts | None = None
|
||||
@@ -1,53 +0,0 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.natural_language_processing.utils import BaseTokenizer
|
||||
from danswer.one_shot_answer.models import ThreadMessage
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def simulate_streaming_response(model_out: str) -> Generator[str, None, None]:
|
||||
"""Mock streaming by generating the passed in model output, character by character"""
|
||||
for token in model_out:
|
||||
yield token
|
||||
|
||||
|
||||
def combine_message_thread(
|
||||
messages: list[ThreadMessage],
|
||||
max_tokens: int | None,
|
||||
llm_tokenizer: BaseTokenizer,
|
||||
) -> str:
|
||||
"""Used to create a single combined message context from threads"""
|
||||
if not messages:
|
||||
return ""
|
||||
|
||||
message_strs: list[str] = []
|
||||
total_token_count = 0
|
||||
|
||||
for message in reversed(messages):
|
||||
if message.role == MessageType.USER:
|
||||
role_str = message.role.value.upper()
|
||||
if message.sender:
|
||||
role_str += " " + message.sender
|
||||
else:
|
||||
# Since other messages might have the user identifying information
|
||||
# better to use Unknown for symmetry
|
||||
role_str += " Unknown"
|
||||
else:
|
||||
role_str = message.role.value.upper()
|
||||
|
||||
msg_str = f"{role_str}:\n{message.message}"
|
||||
message_token_count = len(llm_tokenizer.encode(msg_str))
|
||||
|
||||
if (
|
||||
max_tokens is not None
|
||||
and total_token_count + message_token_count > max_tokens
|
||||
):
|
||||
break
|
||||
|
||||
message_strs.insert(0, msg_str)
|
||||
total_token_count += message_token_count
|
||||
|
||||
return "\n\n".join(message_strs)
|
||||
@@ -1,134 +0,0 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.users import current_admin_user
|
||||
from danswer.auth.users import current_user
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.input_prompt import fetch_input_prompt_by_id
|
||||
from danswer.db.input_prompt import fetch_input_prompts_by_user
|
||||
from danswer.db.input_prompt import fetch_public_input_prompts
|
||||
from danswer.db.input_prompt import insert_input_prompt
|
||||
from danswer.db.input_prompt import remove_input_prompt
|
||||
from danswer.db.input_prompt import remove_public_input_prompt
|
||||
from danswer.db.input_prompt import update_input_prompt
|
||||
from danswer.db.models import User
|
||||
from danswer.server.features.input_prompt.models import CreateInputPromptRequest
|
||||
from danswer.server.features.input_prompt.models import InputPromptSnapshot
|
||||
from danswer.server.features.input_prompt.models import UpdateInputPromptRequest
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
basic_router = APIRouter(prefix="/input_prompt")
|
||||
admin_router = APIRouter(prefix="/admin/input_prompt")
|
||||
|
||||
|
||||
@basic_router.get("")
|
||||
def list_input_prompts(
|
||||
user: User | None = Depends(current_user),
|
||||
include_public: bool = False,
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[InputPromptSnapshot]:
|
||||
user_prompts = fetch_input_prompts_by_user(
|
||||
user_id=user.id if user is not None else None,
|
||||
db_session=db_session,
|
||||
include_public=include_public,
|
||||
)
|
||||
return [InputPromptSnapshot.from_model(prompt) for prompt in user_prompts]
|
||||
|
||||
|
||||
@basic_router.get("/{input_prompt_id}")
|
||||
def get_input_prompt(
|
||||
input_prompt_id: int,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> InputPromptSnapshot:
|
||||
input_prompt = fetch_input_prompt_by_id(
|
||||
id=input_prompt_id,
|
||||
user_id=user.id if user is not None else None,
|
||||
db_session=db_session,
|
||||
)
|
||||
return InputPromptSnapshot.from_model(input_prompt=input_prompt)
|
||||
|
||||
|
||||
@basic_router.post("")
|
||||
def create_input_prompt(
|
||||
create_input_prompt_request: CreateInputPromptRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> InputPromptSnapshot:
|
||||
input_prompt = insert_input_prompt(
|
||||
prompt=create_input_prompt_request.prompt,
|
||||
content=create_input_prompt_request.content,
|
||||
is_public=create_input_prompt_request.is_public,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
)
|
||||
return InputPromptSnapshot.from_model(input_prompt)
|
||||
|
||||
|
||||
@basic_router.patch("/{input_prompt_id}")
|
||||
def patch_input_prompt(
|
||||
input_prompt_id: int,
|
||||
update_input_prompt_request: UpdateInputPromptRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> InputPromptSnapshot:
|
||||
try:
|
||||
updated_input_prompt = update_input_prompt(
|
||||
user=user,
|
||||
input_prompt_id=input_prompt_id,
|
||||
prompt=update_input_prompt_request.prompt,
|
||||
content=update_input_prompt_request.content,
|
||||
active=update_input_prompt_request.active,
|
||||
db_session=db_session,
|
||||
)
|
||||
except ValueError as e:
|
||||
error_msg = "Error occurred while updated input prompt"
|
||||
logger.warn(f"{error_msg}. Stack trace: {e}")
|
||||
raise HTTPException(status_code=404, detail=error_msg)
|
||||
|
||||
return InputPromptSnapshot.from_model(updated_input_prompt)
|
||||
|
||||
|
||||
@basic_router.delete("/{input_prompt_id}")
|
||||
def delete_input_prompt(
|
||||
input_prompt_id: int,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
try:
|
||||
remove_input_prompt(user, input_prompt_id, db_session)
|
||||
|
||||
except ValueError as e:
|
||||
error_msg = "Error occurred while deleting input prompt"
|
||||
logger.warn(f"{error_msg}. Stack trace: {e}")
|
||||
raise HTTPException(status_code=404, detail=error_msg)
|
||||
|
||||
|
||||
@admin_router.delete("/{input_prompt_id}")
|
||||
def delete_public_input_prompt(
|
||||
input_prompt_id: int,
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
try:
|
||||
remove_public_input_prompt(input_prompt_id, db_session)
|
||||
|
||||
except ValueError as e:
|
||||
error_msg = "Error occurred while deleting input prompt"
|
||||
logger.warn(f"{error_msg}. Stack trace: {e}")
|
||||
raise HTTPException(status_code=404, detail=error_msg)
|
||||
|
||||
|
||||
@admin_router.get("")
|
||||
def list_public_input_prompts(
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[InputPromptSnapshot]:
|
||||
user_prompts = fetch_public_input_prompts(
|
||||
db_session=db_session,
|
||||
)
|
||||
return [InputPromptSnapshot.from_model(prompt) for prompt in user_prompts]
|
||||
@@ -1,47 +0,0 @@
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from danswer.db.models import InputPrompt
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class CreateInputPromptRequest(BaseModel):
|
||||
prompt: str
|
||||
content: str
|
||||
is_public: bool
|
||||
|
||||
|
||||
class UpdateInputPromptRequest(BaseModel):
|
||||
prompt: str
|
||||
content: str
|
||||
active: bool
|
||||
|
||||
|
||||
class InputPromptResponse(BaseModel):
|
||||
id: int
|
||||
prompt: str
|
||||
content: str
|
||||
active: bool
|
||||
|
||||
|
||||
class InputPromptSnapshot(BaseModel):
|
||||
id: int
|
||||
prompt: str
|
||||
content: str
|
||||
active: bool
|
||||
user_id: UUID | None
|
||||
is_public: bool
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, input_prompt: InputPrompt) -> "InputPromptSnapshot":
|
||||
return InputPromptSnapshot(
|
||||
id=input_prompt.id,
|
||||
prompt=input_prompt.prompt,
|
||||
content=input_prompt.content,
|
||||
active=input_prompt.active,
|
||||
user_id=input_prompt.user_id,
|
||||
is_public=input_prompt.is_public,
|
||||
)
|
||||
@@ -1,216 +0,0 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.users import current_admin_user
|
||||
from danswer.danswerbot.slack.config import validate_channel_names
|
||||
from danswer.danswerbot.slack.tokens import fetch_tokens
|
||||
from danswer.danswerbot.slack.tokens import save_tokens
|
||||
from danswer.db.constants import SLACK_BOT_PERSONA_PREFIX
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.models import ChannelConfig
|
||||
from danswer.db.models import User
|
||||
from danswer.db.persona import get_persona_by_id
|
||||
from danswer.db.slack_bot_config import create_slack_bot_persona
|
||||
from danswer.db.slack_bot_config import fetch_slack_bot_config
|
||||
from danswer.db.slack_bot_config import fetch_slack_bot_configs
|
||||
from danswer.db.slack_bot_config import insert_slack_bot_config
|
||||
from danswer.db.slack_bot_config import remove_slack_bot_config
|
||||
from danswer.db.slack_bot_config import update_slack_bot_config
|
||||
from danswer.key_value_store.interface import KvKeyNotFoundError
|
||||
from danswer.server.manage.models import SlackBotConfig
|
||||
from danswer.server.manage.models import SlackBotConfigCreationRequest
|
||||
from danswer.server.manage.models import SlackBotTokens
|
||||
|
||||
|
||||
router = APIRouter(prefix="/manage")
|
||||
|
||||
|
||||
def _form_channel_config(
|
||||
slack_bot_config_creation_request: SlackBotConfigCreationRequest,
|
||||
current_slack_bot_config_id: int | None,
|
||||
db_session: Session,
|
||||
) -> ChannelConfig:
|
||||
raw_channel_names = slack_bot_config_creation_request.channel_names
|
||||
respond_tag_only = slack_bot_config_creation_request.respond_tag_only
|
||||
respond_member_group_list = (
|
||||
slack_bot_config_creation_request.respond_member_group_list
|
||||
)
|
||||
answer_filters = slack_bot_config_creation_request.answer_filters
|
||||
follow_up_tags = slack_bot_config_creation_request.follow_up_tags
|
||||
|
||||
if not raw_channel_names:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Must provide at least one channel name",
|
||||
)
|
||||
|
||||
try:
|
||||
cleaned_channel_names = validate_channel_names(
|
||||
channel_names=raw_channel_names,
|
||||
current_slack_bot_config_id=current_slack_bot_config_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=str(e),
|
||||
)
|
||||
|
||||
if respond_tag_only and respond_member_group_list:
|
||||
raise ValueError(
|
||||
"Cannot set DanswerBot to only respond to tags only and "
|
||||
"also respond to a predetermined set of users."
|
||||
)
|
||||
|
||||
channel_config: ChannelConfig = {
|
||||
"channel_names": cleaned_channel_names,
|
||||
}
|
||||
if respond_tag_only is not None:
|
||||
channel_config["respond_tag_only"] = respond_tag_only
|
||||
if respond_member_group_list:
|
||||
channel_config["respond_member_group_list"] = respond_member_group_list
|
||||
if answer_filters:
|
||||
channel_config["answer_filters"] = answer_filters
|
||||
if follow_up_tags is not None:
|
||||
channel_config["follow_up_tags"] = follow_up_tags
|
||||
|
||||
channel_config[
|
||||
"respond_to_bots"
|
||||
] = slack_bot_config_creation_request.respond_to_bots
|
||||
|
||||
return channel_config
|
||||
|
||||
|
||||
@router.post("/admin/slack-bot/config")
|
||||
def create_slack_bot_config(
|
||||
slack_bot_config_creation_request: SlackBotConfigCreationRequest,
|
||||
db_session: Session = Depends(get_session),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> SlackBotConfig:
|
||||
channel_config = _form_channel_config(
|
||||
slack_bot_config_creation_request, None, db_session
|
||||
)
|
||||
|
||||
persona_id = None
|
||||
if slack_bot_config_creation_request.persona_id is not None:
|
||||
persona_id = slack_bot_config_creation_request.persona_id
|
||||
elif slack_bot_config_creation_request.document_sets:
|
||||
persona_id = create_slack_bot_persona(
|
||||
db_session=db_session,
|
||||
channel_names=channel_config["channel_names"],
|
||||
document_set_ids=slack_bot_config_creation_request.document_sets,
|
||||
existing_persona_id=None,
|
||||
).id
|
||||
|
||||
slack_bot_config_model = insert_slack_bot_config(
|
||||
persona_id=persona_id,
|
||||
channel_config=channel_config,
|
||||
response_type=slack_bot_config_creation_request.response_type,
|
||||
# XXX this is going away soon
|
||||
standard_answer_category_ids=slack_bot_config_creation_request.standard_answer_categories,
|
||||
db_session=db_session,
|
||||
enable_auto_filters=slack_bot_config_creation_request.enable_auto_filters,
|
||||
)
|
||||
return SlackBotConfig.from_model(slack_bot_config_model)
|
||||
|
||||
|
||||
@router.patch("/admin/slack-bot/config/{slack_bot_config_id}")
|
||||
def patch_slack_bot_config(
|
||||
slack_bot_config_id: int,
|
||||
slack_bot_config_creation_request: SlackBotConfigCreationRequest,
|
||||
db_session: Session = Depends(get_session),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> SlackBotConfig:
|
||||
channel_config = _form_channel_config(
|
||||
slack_bot_config_creation_request, slack_bot_config_id, db_session
|
||||
)
|
||||
|
||||
persona_id = None
|
||||
if slack_bot_config_creation_request.persona_id is not None:
|
||||
persona_id = slack_bot_config_creation_request.persona_id
|
||||
elif slack_bot_config_creation_request.document_sets:
|
||||
existing_slack_bot_config = fetch_slack_bot_config(
|
||||
db_session=db_session, slack_bot_config_id=slack_bot_config_id
|
||||
)
|
||||
if existing_slack_bot_config is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Slack bot config not found",
|
||||
)
|
||||
|
||||
existing_persona_id = existing_slack_bot_config.persona_id
|
||||
if existing_persona_id is not None:
|
||||
persona = get_persona_by_id(
|
||||
persona_id=existing_persona_id,
|
||||
user=None,
|
||||
db_session=db_session,
|
||||
is_for_edit=False,
|
||||
)
|
||||
|
||||
if not persona.name.startswith(SLACK_BOT_PERSONA_PREFIX):
|
||||
# Don't update actual non-slackbot specific personas
|
||||
# Since this one specified document sets, we have to create a new persona
|
||||
# for this DanswerBot config
|
||||
existing_persona_id = None
|
||||
else:
|
||||
existing_persona_id = existing_slack_bot_config.persona_id
|
||||
|
||||
persona_id = create_slack_bot_persona(
|
||||
db_session=db_session,
|
||||
channel_names=channel_config["channel_names"],
|
||||
document_set_ids=slack_bot_config_creation_request.document_sets,
|
||||
existing_persona_id=existing_persona_id,
|
||||
enable_auto_filters=slack_bot_config_creation_request.enable_auto_filters,
|
||||
).id
|
||||
|
||||
slack_bot_config_model = update_slack_bot_config(
|
||||
slack_bot_config_id=slack_bot_config_id,
|
||||
persona_id=persona_id,
|
||||
channel_config=channel_config,
|
||||
response_type=slack_bot_config_creation_request.response_type,
|
||||
standard_answer_category_ids=slack_bot_config_creation_request.standard_answer_categories,
|
||||
db_session=db_session,
|
||||
enable_auto_filters=slack_bot_config_creation_request.enable_auto_filters,
|
||||
)
|
||||
return SlackBotConfig.from_model(slack_bot_config_model)
|
||||
|
||||
|
||||
@router.delete("/admin/slack-bot/config/{slack_bot_config_id}")
|
||||
def delete_slack_bot_config(
|
||||
slack_bot_config_id: int,
|
||||
db_session: Session = Depends(get_session),
|
||||
user: User | None = Depends(current_admin_user),
|
||||
) -> None:
|
||||
remove_slack_bot_config(
|
||||
slack_bot_config_id=slack_bot_config_id, user=user, db_session=db_session
|
||||
)
|
||||
|
||||
|
||||
@router.get("/admin/slack-bot/config")
|
||||
def list_slack_bot_configs(
|
||||
db_session: Session = Depends(get_session),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> list[SlackBotConfig]:
|
||||
slack_bot_config_models = fetch_slack_bot_configs(db_session=db_session)
|
||||
return [
|
||||
SlackBotConfig.from_model(slack_bot_config_model)
|
||||
for slack_bot_config_model in slack_bot_config_models
|
||||
]
|
||||
|
||||
|
||||
@router.put("/admin/slack-bot/tokens")
|
||||
def put_tokens(
|
||||
tokens: SlackBotTokens,
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> None:
|
||||
save_tokens(tokens=tokens)
|
||||
|
||||
|
||||
@router.get("/admin/slack-bot/tokens")
|
||||
def get_tokens(_: User | None = Depends(current_admin_user)) -> SlackBotTokens:
|
||||
try:
|
||||
return fetch_tokens()
|
||||
except KvKeyNotFoundError:
|
||||
raise HTTPException(status_code=404, detail="No tokens found")
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user