forked from github/onyx
Compare commits
460 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
aa7c811a9a | ||
|
|
3c2fb21c11 | ||
|
|
1b55e617ad | ||
|
|
1c4f7fe7ef | ||
|
|
4629df06ef | ||
|
|
7d11f5ffb8 | ||
|
|
591e9831e7 | ||
|
|
01bd1a84c4 | ||
|
|
236fa947ee | ||
|
|
6b5c20dd54 | ||
|
|
d5168deac8 | ||
|
|
37110df2de | ||
|
|
517c27c5ed | ||
|
|
81f53ff3d8 | ||
|
|
1a1c91a7d9 | ||
|
|
cd8d8def1e | ||
|
|
5a056f1c0c | ||
|
|
0fb3fb8a1f | ||
|
|
35fe86e931 | ||
|
|
4d6b3c8f08 | ||
|
|
2362c2bdcc | ||
|
|
62000c1e46 | ||
|
|
c903d92fcc | ||
|
|
988e9aa682 | ||
|
|
6768c24723 | ||
|
|
b3b88f05d3 | ||
|
|
e54ce779fd | ||
|
|
4c9709ae4a | ||
|
|
c435bf3854 | ||
|
|
bb2b517124 | ||
|
|
dc2f4297b5 | ||
|
|
0060a1dd58 | ||
|
|
29e74c0877 | ||
|
|
779c2829bf | ||
|
|
6a2b7514fe | ||
|
|
8b9e6a91a4 | ||
|
|
b076c3d1ea | ||
|
|
d75ca0542a | ||
|
|
ce12dd4a5a | ||
|
|
0a9b854667 | ||
|
|
159453f8d7 | ||
|
|
2138c0b69d | ||
|
|
4b45164496 | ||
|
|
c0c9c67534 | ||
|
|
a4053501d0 | ||
|
|
60a16fa46d | ||
|
|
0ce992e22e | ||
|
|
35105f951b | ||
|
|
f1a5460739 | ||
|
|
824677ca75 | ||
|
|
cf4ede2130 | ||
|
|
81c33cc325 | ||
|
|
ec93ad9e6d | ||
|
|
d0fa02c8dc | ||
|
|
d6d83e79f1 | ||
|
|
e94fd8b022 | ||
|
|
92628357df | ||
|
|
50086526e2 | ||
|
|
7174ea3908 | ||
|
|
d07647c597 | ||
|
|
3a6712e3a0 | ||
|
|
bcc40224fa | ||
|
|
5d26290c5d | ||
|
|
9d1aa7401e | ||
|
|
c2b34f623c | ||
|
|
692fdb4597 | ||
|
|
2c38033ef5 | ||
|
|
777521a437 | ||
|
|
0e793e972b | ||
|
|
a2a171999a | ||
|
|
5504c9f289 | ||
|
|
5edc464c9a | ||
|
|
1670d923aa | ||
|
|
1981a02473 | ||
|
|
4dc8eab014 | ||
|
|
3a8d89afd3 | ||
|
|
fa879f7d7f | ||
|
|
f5be0cc2c0 | ||
|
|
621967d2b6 | ||
|
|
44905d36e5 | ||
|
|
53add2c801 | ||
|
|
d17426749d | ||
|
|
d099b931d8 | ||
|
|
4cd9122ba5 | ||
|
|
22fb7c3352 | ||
|
|
4ff3bee605 | ||
|
|
7029bdb291 | ||
|
|
cf4c3c57ed | ||
|
|
1b6eb0a52f | ||
|
|
503a709e37 | ||
|
|
0fd36c3120 | ||
|
|
8d12c7c202 | ||
|
|
a4d5ac816e | ||
|
|
2a139fd529 | ||
|
|
54347e100f | ||
|
|
936e69bc2b | ||
|
|
0056cdcf44 | ||
|
|
1791edec03 | ||
|
|
6201c1b585 | ||
|
|
c8f34e3103 | ||
|
|
77b0d76f53 | ||
|
|
733626f277 | ||
|
|
1da79c8627 | ||
|
|
4e3d57b1b9 | ||
|
|
e473ad0412 | ||
|
|
7efd3ba42f | ||
|
|
879e873310 | ||
|
|
adc747e66c | ||
|
|
a29c1ff05c | ||
|
|
49415e4615 | ||
|
|
885e698d5d | ||
|
|
30983657ec | ||
|
|
6b6b3daab7 | ||
|
|
20441df4a4 | ||
|
|
d7141df5fc | ||
|
|
615bb7b095 | ||
|
|
e759718c3e | ||
|
|
06d8d0e53c | ||
|
|
ae9b556876 | ||
|
|
f883611e94 | ||
|
|
13c536c033 | ||
|
|
2e6be57880 | ||
|
|
b352d83b8c | ||
|
|
aa67768c79 | ||
|
|
6004e540f3 | ||
|
|
64d2cea396 | ||
|
|
b5947a1c74 | ||
|
|
cdf260b277 | ||
|
|
73483b5e09 | ||
|
|
a6a444f365 | ||
|
|
449a403c73 | ||
|
|
4aebf824d2 | ||
|
|
26946198de | ||
|
|
e5035b8992 | ||
|
|
2e9af3086a | ||
|
|
dab3ba8a41 | ||
|
|
1e84b0daa4 | ||
|
|
f4c8abdf21 | ||
|
|
ccc5bb1e67 | ||
|
|
c3cf9134bb | ||
|
|
0370b9b38d | ||
|
|
95bf1c13ad | ||
|
|
00c1f93b12 | ||
|
|
a122510cee | ||
|
|
dca4f7a72b | ||
|
|
535dc265c5 | ||
|
|
56882367ba | ||
|
|
d9fbd7ffe2 | ||
|
|
8b7d01fb3b | ||
|
|
016a087b10 | ||
|
|
241b886976 | ||
|
|
ff014e4f5a | ||
|
|
0318507911 | ||
|
|
6650f01dc6 | ||
|
|
962e3f726a | ||
|
|
25a73b9921 | ||
|
|
dc0b3672ac | ||
|
|
c4ad03a65d | ||
|
|
c6f354fd03 | ||
|
|
2f001c23b7 | ||
|
|
4d950aa60d | ||
|
|
56406a0b53 | ||
|
|
eb31c08461 | ||
|
|
26f94c9890 | ||
|
|
a9570e01e2 | ||
|
|
402d83e167 | ||
|
|
10dcd49fc8 | ||
|
|
0fdad0e777 | ||
|
|
fab767d794 | ||
|
|
7dd70ca4c0 | ||
|
|
370760eeee | ||
|
|
24a62cb33d | ||
|
|
9e4a4ddf39 | ||
|
|
c281859509 | ||
|
|
2180a40bd3 | ||
|
|
997f9c3191 | ||
|
|
677c32ea79 | ||
|
|
edfc849652 | ||
|
|
9d296b623b | ||
|
|
5957b888a5 | ||
|
|
c7a91b1819 | ||
|
|
a099f8e296 | ||
|
|
16c8969028 | ||
|
|
65fde8f1b3 | ||
|
|
229db47e5d | ||
|
|
2e3397feb0 | ||
|
|
d5658ce477 | ||
|
|
ddf3f99da4 | ||
|
|
56785e6065 | ||
|
|
26e808d2a1 | ||
|
|
e3ac373f05 | ||
|
|
9e9a578921 | ||
|
|
f7172612e1 | ||
|
|
5aa2de7a40 | ||
|
|
e0b87d9d4e | ||
|
|
5607fdcddd | ||
|
|
651de071f7 | ||
|
|
5629ca7d96 | ||
|
|
bc403d97f2 | ||
|
|
292c78b193 | ||
|
|
ac35719038 | ||
|
|
02095e9281 | ||
|
|
8954a04602 | ||
|
|
8020db9e9a | ||
|
|
17c2f06338 | ||
|
|
9cff294a71 | ||
|
|
e983aaeca7 | ||
|
|
7ea774f35b | ||
|
|
d1846823ba | ||
|
|
fda89ac810 | ||
|
|
006fd4c438 | ||
|
|
9b7069a043 | ||
|
|
c64c25b2e1 | ||
|
|
c2727a3f19 | ||
|
|
37daf4f3e4 | ||
|
|
fcb7f6fcc0 | ||
|
|
429016d4a2 | ||
|
|
c83a450ec4 | ||
|
|
187b94a7d8 | ||
|
|
30225fd4c5 | ||
|
|
a4f053fa5b | ||
|
|
eab4fe83a0 | ||
|
|
78d1ae0379 | ||
|
|
87beb1f4d1 | ||
|
|
05c2b7d34e | ||
|
|
39d09a162a | ||
|
|
d291fea020 | ||
|
|
2665bff78e | ||
|
|
65d38ac8c3 | ||
|
|
8391d89bea | ||
|
|
ac2ed31726 | ||
|
|
47f947b045 | ||
|
|
63b051b342 | ||
|
|
a5729e2fa6 | ||
|
|
3cec854c5c | ||
|
|
26c6651a03 | ||
|
|
13001ede98 | ||
|
|
fda377a2fa | ||
|
|
bdfb894507 | ||
|
|
35c3511daa | ||
|
|
c1e19d0d93 | ||
|
|
e78aefb408 | ||
|
|
aa2e859b46 | ||
|
|
c0c8ae6c08 | ||
|
|
1225c663eb | ||
|
|
e052d607d5 | ||
|
|
8e5e11a554 | ||
|
|
57f0323f52 | ||
|
|
6e9f31d1e9 | ||
|
|
eeb844e35e | ||
|
|
d6a84ab413 | ||
|
|
68160d49dd | ||
|
|
0cc3d65839 | ||
|
|
df37387146 | ||
|
|
f72825cd46 | ||
|
|
6fb07d20cc | ||
|
|
b258ec1bed | ||
|
|
4fd55b8928 | ||
|
|
b3ea53fa46 | ||
|
|
fa0d19cc8c | ||
|
|
d5916e420c | ||
|
|
39b912befd | ||
|
|
37c5f24d91 | ||
|
|
ae72cd56f8 | ||
|
|
be5ef77896 | ||
|
|
0ed8f14015 | ||
|
|
a03e443541 | ||
|
|
4935459798 | ||
|
|
efb52873dd | ||
|
|
442f7595cc | ||
|
|
81cbcbb403 | ||
|
|
0a0e672b35 | ||
|
|
69644b266e | ||
|
|
5a4820c55f | ||
|
|
a5d69bb392 | ||
|
|
23ee45c033 | ||
|
|
31bfd015ae | ||
|
|
0125d8a0f6 | ||
|
|
4f64444f0f | ||
|
|
abf9cc3248 | ||
|
|
f5bf2e6374 | ||
|
|
24b3b1fa9e | ||
|
|
7433dddac3 | ||
|
|
fe938b6fc6 | ||
|
|
2db029672b | ||
|
|
602f9c4a0a | ||
|
|
551705ad62 | ||
|
|
d9581ce0ae | ||
|
|
e27800d501 | ||
|
|
927dffecb5 | ||
|
|
68b23b6339 | ||
|
|
174f54473e | ||
|
|
329824ab22 | ||
|
|
b0f76b97ef | ||
|
|
80eedebe86 | ||
|
|
e8786e1a20 | ||
|
|
44e3dcb19f | ||
|
|
e8f778ccb5 | ||
|
|
d9adee168b | ||
|
|
73b653d324 | ||
|
|
9cd0c197e7 | ||
|
|
0b07d615b1 | ||
|
|
5c9c70dffb | ||
|
|
61c9343a7e | ||
|
|
53353f9b62 | ||
|
|
fbf7c642a3 | ||
|
|
d9e5795b36 | ||
|
|
acb60f67e1 | ||
|
|
4990aacc0d | ||
|
|
947d4d0a2e | ||
|
|
d7a90aeb2b | ||
|
|
ee0d092dcc | ||
|
|
c6663d83d5 | ||
|
|
0618b59de6 | ||
|
|
517a539d7e | ||
|
|
a1da4dfac6 | ||
|
|
e968e1d14b | ||
|
|
8215a7859a | ||
|
|
37bba3dbe9 | ||
|
|
52c0d6e68b | ||
|
|
08909b40b0 | ||
|
|
64ebaf2dda | ||
|
|
815c30c9d0 | ||
|
|
57ecab0098 | ||
|
|
26b491fb0c | ||
|
|
bfa338e142 | ||
|
|
e744c6b75a | ||
|
|
7d6a41243c | ||
|
|
25814d7a23 | ||
|
|
59b16ac320 | ||
|
|
11d96b2807 | ||
|
|
fe117513b0 | ||
|
|
fcce2b5a60 | ||
|
|
ad6ea1679a | ||
|
|
fad311282b | ||
|
|
ca0f186b0e | ||
|
|
c9edc2711c | ||
|
|
dcbb7b85d9 | ||
|
|
2df9f4d7fc | ||
|
|
7bc34ce182 | ||
|
|
76275b29d4 | ||
|
|
604e511c09 | ||
|
|
379e71160a | ||
|
|
a8b7155b5e | ||
|
|
9a51745fc9 | ||
|
|
fbb05e630d | ||
|
|
ef2b445201 | ||
|
|
17bd68be4c | ||
|
|
890eb7901e | ||
|
|
0a6c2afb8a | ||
|
|
7ffba2aa60 | ||
|
|
816ec5e3ca | ||
|
|
3554e29b8d | ||
|
|
88eaae62d9 | ||
|
|
a014cb7792 | ||
|
|
89807c8c05 | ||
|
|
8403b94722 | ||
|
|
4fa96788f6 | ||
|
|
e279918f95 | ||
|
|
8e3258981e | ||
|
|
b14b220d89 | ||
|
|
cc5d27bff7 | ||
|
|
5ddc9b34ab | ||
|
|
a7099a1917 | ||
|
|
47ab273353 | ||
|
|
7be3730038 | ||
|
|
f6982b03b6 | ||
|
|
76f1f17710 | ||
|
|
bc1de6562d | ||
|
|
93d4eef61d | ||
|
|
4ffbdbb8b0 | ||
|
|
764aab3e53 | ||
|
|
7c34744655 | ||
|
|
2037e11495 | ||
|
|
6a449f1fb1 | ||
|
|
d9076a6ff6 | ||
|
|
1bd76f528f | ||
|
|
a5d2759fbc | ||
|
|
022f59e5b2 | ||
|
|
5bf998219e | ||
|
|
5da81a3d0d | ||
|
|
e73739547a | ||
|
|
e519dfc849 | ||
|
|
bf5844578c | ||
|
|
37e9ccf864 | ||
|
|
bb9a18b22c | ||
|
|
b5982c10c3 | ||
|
|
a7ddb22e50 | ||
|
|
595f61ea3a | ||
|
|
d2f7dff464 | ||
|
|
ae0dbfadc6 | ||
|
|
38d516cc7a | ||
|
|
7f029a0304 | ||
|
|
2c867b5143 | ||
|
|
af510cc965 | ||
|
|
f0337d2eba | ||
|
|
17e00b186e | ||
|
|
dbf59d2acc | ||
|
|
41964031bf | ||
|
|
a7578c9707 | ||
|
|
51490b5cd9 | ||
|
|
e6866c92cf | ||
|
|
8c61e6997b | ||
|
|
90828008e1 | ||
|
|
12442c1c06 | ||
|
|
876c6fdaa6 | ||
|
|
3e05c4fa67 | ||
|
|
90fbe1ab48 | ||
|
|
31d5fc6d31 | ||
|
|
fa460f4da1 | ||
|
|
e7cc0f235c | ||
|
|
091c2c8a80 | ||
|
|
3142e2eed2 | ||
|
|
5deb12523e | ||
|
|
744c95e1e1 | ||
|
|
0d505ffea1 | ||
|
|
30cdc5c9de | ||
|
|
dff7a4ba1e | ||
|
|
d95da554ea | ||
|
|
fb1fbbee5c | ||
|
|
f045bbed70 | ||
|
|
ca74884bd7 | ||
|
|
9b185f469f | ||
|
|
e8d3190770 | ||
|
|
a6e6be4037 | ||
|
|
7d3f8b7c8c | ||
|
|
c658ffd0b6 | ||
|
|
9425ccd043 | ||
|
|
829b50571d | ||
|
|
d09c320538 | ||
|
|
478fb4f999 | ||
|
|
0fd51409ad | ||
|
|
21aa233170 | ||
|
|
30efe3df88 | ||
|
|
09ba0a49b3 | ||
|
|
beb54eaa5d | ||
|
|
0632e92144 | ||
|
|
9c89ae78ba | ||
|
|
a85e73edbe | ||
|
|
5a63b689eb | ||
|
|
aee573cd76 | ||
|
|
ec7697fcfe | ||
|
|
b801937299 | ||
|
|
499dfb59da | ||
|
|
7cc54eed0f | ||
|
|
d04716c99d | ||
|
|
732f5efb12 | ||
|
|
29a0a45518 | ||
|
|
59bac1ca8f | ||
|
|
c2721c7889 | ||
|
|
ab65b19c4c | ||
|
|
c666f35cd0 | ||
|
|
dbe33959c0 | ||
|
|
829d04c904 | ||
|
|
351475de28 | ||
|
|
a808c733b8 | ||
|
|
2d06008f6f | ||
|
|
aa9071e441 | ||
|
|
22f2398269 | ||
|
|
1abce83626 |
@@ -1,4 +1,4 @@
|
||||
name: Build and Push Backend Images on Tagging
|
||||
name: Build and Push Backend Image on Tag
|
||||
|
||||
on:
|
||||
push:
|
||||
@@ -32,3 +32,11 @@ jobs:
|
||||
tags: |
|
||||
danswer/danswer-backend:${{ github.ref_name }}
|
||||
danswer/danswer-backend:latest
|
||||
build-args: |
|
||||
DANSWER_VERSION=${{ github.ref_name }}
|
||||
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@master
|
||||
with:
|
||||
image-ref: docker.io/danswer/danswer-backend:${{ github.ref_name }}
|
||||
severity: 'CRITICAL,HIGH'
|
||||
|
||||
42
.github/workflows/docker-build-push-model-server-container-on-tag.yml
vendored
Normal file
42
.github/workflows/docker-build-push-model-server-container-on-tag.yml
vendored
Normal file
@@ -0,0 +1,42 @@
|
||||
name: Build and Push Model Server Image on Tag
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- '*'
|
||||
|
||||
jobs:
|
||||
build-and-push:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v1
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v1
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Model Server Image Docker Build and Push
|
||||
uses: docker/build-push-action@v2
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile.model_server
|
||||
platforms: linux/amd64,linux/arm64
|
||||
push: true
|
||||
tags: |
|
||||
danswer/danswer-model-server:${{ github.ref_name }}
|
||||
danswer/danswer-model-server:latest
|
||||
build-args: |
|
||||
DANSWER_VERSION=${{ github.ref_name }}
|
||||
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@master
|
||||
with:
|
||||
image-ref: docker.io/danswer/danswer-model-server:${{ github.ref_name }}
|
||||
severity: 'CRITICAL,HIGH'
|
||||
@@ -1,4 +1,4 @@
|
||||
name: Build and Push Web Images on Tagging
|
||||
name: Build and Push Web Image on Tag
|
||||
|
||||
on:
|
||||
push:
|
||||
@@ -32,3 +32,11 @@ jobs:
|
||||
tags: |
|
||||
danswer/danswer-web-server:${{ github.ref_name }}
|
||||
danswer/danswer-web-server:latest
|
||||
build-args: |
|
||||
DANSWER_VERSION=${{ github.ref_name }}
|
||||
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@master
|
||||
with:
|
||||
image-ref: docker.io/danswer/danswer-web-server:${{ github.ref_name }}
|
||||
severity: 'CRITICAL,HIGH'
|
||||
|
||||
35
.github/workflows/pr-python-tests.yml
vendored
Normal file
35
.github/workflows/pr-python-tests.yml
vendored
Normal file
@@ -0,0 +1,35 @@
|
||||
name: Python Unit Tests
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
|
||||
jobs:
|
||||
backend-check:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
env:
|
||||
PYTHONPATH: ./backend
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.11'
|
||||
cache: 'pip'
|
||||
cache-dependency-path: |
|
||||
backend/requirements/default.txt
|
||||
backend/requirements/dev.txt
|
||||
|
||||
- name: Install Dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -r backend/requirements/default.txt
|
||||
pip install -r backend/requirements/dev.txt
|
||||
|
||||
- name: Run Tests
|
||||
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
|
||||
run: py.test -o junit_family=xunit2 -xv --ff backend/tests/unit
|
||||
21
.github/workflows/pr-quality-checks.yml
vendored
Normal file
21
.github/workflows/pr-quality-checks.yml
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
name: Quality Checks PR
|
||||
concurrency:
|
||||
group: Quality-Checks-PR-${{ github.head_ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
on:
|
||||
pull_request: null
|
||||
|
||||
jobs:
|
||||
quality-checks:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.11'
|
||||
- uses: pre-commit/action@v3.0.0
|
||||
with:
|
||||
extra_args: --from-ref ${{ github.event.pull_request.base.sha }} --to-ref ${{ github.event.pull_request.head.sha }}
|
||||
5
.gitignore
vendored
5
.gitignore
vendored
@@ -1,2 +1,7 @@
|
||||
.env
|
||||
.DS_store
|
||||
.venv
|
||||
.mypy_cache
|
||||
.idea
|
||||
/deployment/data/nginx/app.conf
|
||||
.vscode/launch.json
|
||||
|
||||
@@ -28,6 +28,13 @@ repos:
|
||||
rev: v0.0.286
|
||||
hooks:
|
||||
- id: ruff
|
||||
- repo: https://github.com/pre-commit/mirrors-prettier
|
||||
rev: v3.1.0
|
||||
hooks:
|
||||
- id: prettier
|
||||
types_or: [html, css, javascript, ts, tsx]
|
||||
additional_dependencies:
|
||||
- prettier
|
||||
|
||||
# We would like to have a mypy pre-commit hook, but due to the fact that
|
||||
# pre-commit runs in it's own isolated environment, we would need to install
|
||||
@@ -1,3 +1,10 @@
|
||||
/*
|
||||
|
||||
Copy this file into '.vscode/launch.json' or merge its
|
||||
contents into your existing configurations.
|
||||
|
||||
*/
|
||||
|
||||
{
|
||||
// Use IntelliSense to learn about possible attributes.
|
||||
// Hover to view descriptions of existing attributes.
|
||||
@@ -5,7 +12,7 @@
|
||||
"version": "0.2.0",
|
||||
"configurations": [
|
||||
{
|
||||
"name": "Danswer backend",
|
||||
"name": "API Server",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"module": "uvicorn",
|
||||
@@ -14,7 +21,7 @@
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"DISABLE_AUTH": "True",
|
||||
"TYPESENSE_API_KEY": "typesense_api_key",
|
||||
"DYNAMIC_CONFIG_DIR_PATH": "./dynamic_config_storage",
|
||||
"DYNAMIC_CONFIG_DIR_PATH": "./dynamic_config_storage"
|
||||
},
|
||||
"args": [
|
||||
"danswer.main:app",
|
||||
@@ -24,7 +31,7 @@
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Danswer indexer",
|
||||
"name": "Indexer",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "danswer/background/update.py",
|
||||
@@ -33,33 +40,43 @@
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONPATH": ".",
|
||||
"TYPESENSE_API_KEY": "typesense_api_key",
|
||||
"DYNAMIC_CONFIG_DIR_PATH": "./dynamic_config_storage",
|
||||
},
|
||||
"DYNAMIC_CONFIG_DIR_PATH": "./dynamic_config_storage"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Danswer temp files deletion",
|
||||
"name": "Temp File Deletion",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "danswer/background/file_deletion.py",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONPATH": "${workspaceFolder}/backend",
|
||||
},
|
||||
"PYTHONPATH": "${workspaceFolder}/backend"
|
||||
}
|
||||
},
|
||||
// For the listner to access the Slack API,
|
||||
// DANSWER_BOT_SLACK_APP_TOKEN & DANSWER_BOT_SLACK_BOT_TOKEN need to be set in .env file located in the root of the project
|
||||
{
|
||||
"name": "Danswer slack bot listener",
|
||||
"name": "Slack Bot Listener",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "danswer/listeners/slack_listener.py",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONPATH": ".",
|
||||
},
|
||||
"LOG_LEVEL": "DEBUG"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Web Server",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"cwd": "${workspaceRoot}/web",
|
||||
"runtimeExecutable": "npm",
|
||||
"runtimeArgs": [
|
||||
"run", "dev"
|
||||
],
|
||||
"console": "integratedTerminal"
|
||||
}
|
||||
]
|
||||
}
|
||||
108
CONTRIBUTING.md
108
CONTRIBUTING.md
@@ -1,3 +1,5 @@
|
||||
<!-- DANSWER_METADATA={"link": "https://github.com/danswer-ai/danswer/blob/main/CONTRIBUTING.md"} -->
|
||||
|
||||
# Contributing to Danswer
|
||||
Hey there! We are so excited that you're interested in Danswer.
|
||||
|
||||
@@ -6,7 +8,7 @@ As an open source project in a rapidly changing space, we welcome all contributi
|
||||
|
||||
## 💃 Guidelines
|
||||
### Contribution Opportunities
|
||||
The [GitHub issues](https://github.com/danswer-ai/danswer/issues) page is a great place to start for contribution ideas.
|
||||
The [GitHub Issues](https://github.com/danswer-ai/danswer/issues) page is a great place to start for contribution ideas.
|
||||
|
||||
Issues that have been explicitly approved by the maintainers (aligned with the direction of the project)
|
||||
will be marked with the `approved by maintainers` label.
|
||||
@@ -19,7 +21,9 @@ If you have a new/different contribution in mind, we'd love to hear about it!
|
||||
Your input is vital to making sure that Danswer moves in the right direction.
|
||||
Before starting on implementation, please raise a GitHub issue.
|
||||
|
||||
And always feel free to message us (Chris Weaver / Yuhong Sun) on Slack / Discord directly about anything at all.
|
||||
And always feel free to message us (Chris Weaver / Yuhong Sun) on
|
||||
[Slack](https://join.slack.com/t/danswer/shared_invite/zt-2afut44lv-Rw3kSWu6_OmdAXRpCv80DQ) /
|
||||
[Discord](https://discord.gg/TDJ59cGV2X) directly about anything at all.
|
||||
|
||||
|
||||
### Contributing Code
|
||||
@@ -36,7 +40,7 @@ Our goal is to make contributing as easy as possible. If you run into any issues
|
||||
That way we can help future contributors and users can avoid the same issue.
|
||||
|
||||
We also have support channels and generally interesting discussions on our
|
||||
[Slack](https://join.slack.com/t/danswer/shared_invite/zt-1u3h3ke3b-VGh1idW19R8oiNRiKBYv2w)
|
||||
[Slack](https://join.slack.com/t/danswer/shared_invite/zt-2afut44lv-Rw3kSWu6_OmdAXRpCv80DQ)
|
||||
and
|
||||
[Discord](https://discord.gg/TDJ59cGV2X).
|
||||
|
||||
@@ -44,8 +48,8 @@ We would love to see you there!
|
||||
|
||||
|
||||
## Get Started 🚀
|
||||
Danswer being a fully functional app, relies on several external pieces of software, specifically:
|
||||
- Postgres (Relational DB)
|
||||
Danswer being a fully functional app, relies on some external pieces of software, specifically:
|
||||
- [Postgres](https://www.postgresql.org/) (Relational DB)
|
||||
- [Vespa](https://vespa.ai/) (Vector DB/Search Engine)
|
||||
|
||||
This guide provides instructions to set up the Danswer specific services outside of Docker because it's easier for
|
||||
@@ -54,11 +58,10 @@ development purposes but also feel free to just use the containers and update wi
|
||||
|
||||
|
||||
### Local Set Up
|
||||
We've tested primarily with Python versions >= 3.11 but the code should work with Python >= 3.9.
|
||||
It is recommended to use Python version 3.11
|
||||
|
||||
This guide skips a few optional features for simplicity, reach out if you need any of these:
|
||||
- User Authentication feature
|
||||
- File Connector background job
|
||||
If using a lower version, modifications will have to be made to the code.
|
||||
If using a higher version, the version of Tensorflow we use may not be available for your platform.
|
||||
|
||||
|
||||
#### Installing Requirements
|
||||
@@ -86,25 +89,23 @@ Once the above is done, navigate to `danswer/web` run:
|
||||
npm i
|
||||
```
|
||||
|
||||
Install Playwright (required by the Web Connector), with the python venv active, run:
|
||||
Install Playwright (required by the Web Connector)
|
||||
|
||||
> Note: If you have just done the pip install, open a new terminal and source the python virtual-env again.
|
||||
This will update the path to include playwright
|
||||
|
||||
Then install Playwright by running:
|
||||
```bash
|
||||
playwright install
|
||||
```
|
||||
|
||||
|
||||
#### Dependent Docker Containers
|
||||
First navigate to `danswer/deployment/docker_compose`, then start up the containers with:
|
||||
|
||||
Postgres:
|
||||
First navigate to `danswer/deployment/docker_compose`, then start up Vespa and Postgres with:
|
||||
```bash
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack up -d relational_db
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack up -d index relational_db
|
||||
```
|
||||
|
||||
Vespa:
|
||||
```bash
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack up -d index
|
||||
```
|
||||
|
||||
(index refers to Vespa and relational_db refers to Postgres)
|
||||
|
||||
#### Running Danswer
|
||||
|
||||
@@ -115,71 +116,50 @@ mkdir dynamic_config_storage
|
||||
|
||||
To start the frontend, navigate to `danswer/web` and run:
|
||||
```bash
|
||||
DISABLE_AUTH=true npm run dev
|
||||
```
|
||||
_for Windows, run:_
|
||||
```bash
|
||||
(SET "DISABLE_AUTH=true" && npm run dev)
|
||||
npm run dev
|
||||
```
|
||||
|
||||
Package the Vespa schema. This will only need to be done when the Vespa schema is updated locally.
|
||||
|
||||
The first time running Danswer, you will need to run the DB migrations for Postgres.
|
||||
Navigate to `danswer/backend` and with the venv active, run:
|
||||
```bash
|
||||
alembic upgrade head
|
||||
```
|
||||
|
||||
Additionally, we have to package the Vespa schema deployment:
|
||||
Nagivate to `danswer/backend/danswer/datastores/vespa/app_config` and run:
|
||||
Navigate to `danswer/backend/danswer/document_index/vespa/app_config` and run:
|
||||
```bash
|
||||
zip -r ../vespa-app.zip .
|
||||
```
|
||||
- Note: If you don't have the `zip` utility, you will need to install it prior to running the above
|
||||
|
||||
The first time running Danswer, you will also need to run the DB migrations for Postgres.
|
||||
After the first time, this is no longer required unless the DB models change.
|
||||
|
||||
Navigate to `danswer/backend` and with the venv active, run:
|
||||
```bash
|
||||
alembic upgrade head
|
||||
```
|
||||
|
||||
Next, start the task queue which orchestrates the background jobs.
|
||||
Jobs that take more time are run async from the API server.
|
||||
|
||||
Still in `danswer/backend`, run:
|
||||
```bash
|
||||
python ./scripts/dev_run_background_jobs.py
|
||||
```
|
||||
|
||||
To run the backend API server, navigate back to `danswer/backend` and run:
|
||||
```bash
|
||||
DISABLE_AUTH=True \
|
||||
AUTH_TYPE=disabled \
|
||||
DYNAMIC_CONFIG_DIR_PATH=./dynamic_config_storage \
|
||||
VESPA_DEPLOYMENT_ZIP=./danswer/datastores/vespa/vespa-app.zip \
|
||||
VESPA_DEPLOYMENT_ZIP=./danswer/document_index/vespa/vespa-app.zip \
|
||||
uvicorn danswer.main:app --reload --port 8080
|
||||
```
|
||||
_For Windows (for compatibility with both PowerShell and Command Prompt):_
|
||||
```bash
|
||||
powershell -Command "
|
||||
$env:DISABLE_AUTH='True'
|
||||
$env:AUTH_TYPE='disabled'
|
||||
$env:DYNAMIC_CONFIG_DIR_PATH='./dynamic_config_storage'
|
||||
$env:VESPA_DEPLOYMENT_ZIP='./danswer/datastores/vespa/vespa-app.zip'
|
||||
$env:VESPA_DEPLOYMENT_ZIP='./danswer/document_index/vespa/vespa-app.zip'
|
||||
uvicorn danswer.main:app --reload --port 8080
|
||||
"
|
||||
```
|
||||
|
||||
To run the background job to check for connector updates and index documents, navigate to `danswer/backend` and run:
|
||||
```bash
|
||||
PYTHONPATH=. DYNAMIC_CONFIG_DIR_PATH=./dynamic_config_storage python danswer/background/update.py
|
||||
```
|
||||
_For Windows:_
|
||||
```bash
|
||||
powershell -Command " $env:PYTHONPATH='.'; $env:DYNAMIC_CONFIG_DIR_PATH='./dynamic_config_storage'; python danswer/background/update.py "
|
||||
```
|
||||
|
||||
To run the background job to check for periodically check for document set updates, navigate to `danswer/backend` and run:
|
||||
```bash
|
||||
PYTHONPATH=. DYNAMIC_CONFIG_DIR_PATH=./dynamic_config_storage python danswer/background/document_set_sync_script.py
|
||||
```
|
||||
_For Windows:_
|
||||
```bash
|
||||
powershell -Command " $env:PYTHONPATH='.'; $env:DYNAMIC_CONFIG_DIR_PATH='./dynamic_config_storage'; python danswer/background/document_set_sync_script.py "
|
||||
```
|
||||
|
||||
To run Celery, which handles deletion of connectors + syncing of document sets, navigate to `danswer/backend` and run:
|
||||
```bash
|
||||
PYTHONPATH=. DYNAMIC_CONFIG_DIR_PATH=./dynamic_config_storage celery -A danswer.background.celery worker --loglevel=info --concurrency=1
|
||||
```
|
||||
_For Windows:_
|
||||
```bash
|
||||
powershell -Command " $env:PYTHONPATH='.'; $env:DYNAMIC_CONFIG_DIR_PATH='./dynamic_config_storage'; celery -A danswer.background.celery worker --loglevel=info --concurrency=1 "
|
||||
```
|
||||
|
||||
Note: if you need finer logging, add the additional environment variable `LOG_LEVEL=DEBUG` to the relevant services.
|
||||
|
||||
### Formatting and Linting
|
||||
|
||||
77
README.md
77
README.md
@@ -1,15 +1,17 @@
|
||||
<!-- DANSWER_METADATA={"link": "https://github.com/danswer-ai/danswer/blob/main/README.md"} -->
|
||||
|
||||
<h2 align="center">
|
||||
<a href="https://www.danswer.ai/"> <img width="50%" src="https://github.com/danswer-owners/danswer/blob/1fabd9372d66cd54238847197c33f091a724803b/DanswerWithName.png?raw=true)" /></a>
|
||||
</h2>
|
||||
|
||||
<p align="center">
|
||||
<p align="center">OpenSource Enterprise Question-Answering</p>
|
||||
<p align="center">Open Source Unified Search and Gen-AI Chat with your Docs.</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://docs.danswer.dev/" target="_blank">
|
||||
<img src="https://img.shields.io/badge/docs-view-blue" alt="Documentation">
|
||||
</a>
|
||||
<a href="https://join.slack.com/t/danswer/shared_invite/zt-1u5ycen3o-6SJbWfivLWP5LPyp_jftuw" target="_blank">
|
||||
<a href="https://join.slack.com/t/danswer/shared_invite/zt-2afut44lv-Rw3kSWu6_OmdAXRpCv80DQ" target="_blank">
|
||||
<img src="https://img.shields.io/badge/slack-join-blue.svg?logo=slack" alt="Slack">
|
||||
</a>
|
||||
<a href="https://discord.gg/TDJ59cGV2X" target="_blank">
|
||||
@@ -20,62 +22,85 @@
|
||||
</a>
|
||||
</p>
|
||||
|
||||
<strong>[Danswer](https://www.danswer.ai/)</strong> allows you to ask natural language questions against internal documents and get back reliable answers backed by quotes and references from the source material so that you can always trust what you get back. You can connect to a number of common tools such as Slack, GitHub, Confluence, amongst others.
|
||||
<strong>[Danswer](https://www.danswer.ai/)</strong> lets you ask questions in natural language questions and get back
|
||||
answers based on your team specific documents. Think ChatGPT if it had access to your team's unique
|
||||
knowledge. Connects to all common workplace tools such as Slack, Google Drive, Confluence, etc.
|
||||
|
||||
Teams have used Danswer to:
|
||||
- Speedup customer support and escalation turnaround time.
|
||||
- Improve Engineering efficiency by making documentation and code changelogs easy to find.
|
||||
- Let sales team get fuller context and faster in preparation for calls.
|
||||
- Track customer requests and priorities for Product teams.
|
||||
- Help teams self-serve IT, Onboarding, HR, etc.
|
||||
|
||||
<h3>Usage</h3>
|
||||
|
||||
Danswer provides a fully-featured web UI:
|
||||
Danswer Web App:
|
||||
|
||||
https://github.com/danswer-ai/danswer/assets/32520769/563be14c-9304-47b5-bf0a-9049c2b6f410
|
||||
|
||||
|
||||
https://github.com/danswer-ai/danswer/assets/25087905/619607a1-4ad2-41a0-9728-351752acc26e
|
||||
|
||||
|
||||
Or, if you prefer, you can plug Danswer into your existing Slack workflows (more integrations to come 😁):
|
||||
|
||||
Or, plug Danswer into your existing Slack workflows (more integrations to come 😁):
|
||||
|
||||
https://github.com/danswer-ai/danswer/assets/25087905/3e19739b-d178-4371-9a38-011430bdec1b
|
||||
|
||||
|
||||
For more details on the admin controls, check out our <strong><a href="https://www.youtube.com/watch?v=geNzY1nbCnU">Full Video Demo</a></strong>!
|
||||
For more details on the Admin UI to manage connectors and users, check out our
|
||||
<strong><a href="https://www.youtube.com/watch?v=geNzY1nbCnU">Full Video Demo</a></strong>!
|
||||
|
||||
<h3>Deployment</h3>
|
||||
## Deployment
|
||||
|
||||
Danswer can easily be tested locally or deployed on a virtual machine with a single `docker compose` command. Checkout our [docs](https://docs.danswer.dev/quickstart) to learn more.
|
||||
Danswer can easily be run locally (even on a laptop) or deployed on a virtual machine with a single
|
||||
`docker compose` command. Checkout our [docs](https://docs.danswer.dev/quickstart) to learn more.
|
||||
|
||||
We also have built-in support for deployment on Kubernetes. Files for that can be found [here](https://github.com/danswer-ai/danswer/tree/main/deployment/kubernetes).
|
||||
|
||||
## 💃 Features
|
||||
* Direct QA powered by Generative AI models with answers backed by quotes and source links.
|
||||
* Intelligent Document Retrieval (Semantic Search/Reranking) using the latest LLMs.
|
||||
* An AI Helper backed by a custom Deep Learning model to interpret user intent.
|
||||
* User authentication and document level access management.
|
||||
* Support for an LLM of your choice (GPT-4, Llama2, Orca, etc.)
|
||||
* Management Dashboard to manage connectors and set up features such as live update fetching.
|
||||
* One line Docker Compose (or Kubernetes) deployment to host Danswer anywhere.
|
||||
|
||||
## 🔌 Connectors
|
||||
## 💃 Main Features
|
||||
* Document Search + AI Answers for natural language queries.
|
||||
* Connectors to all common workplace tools like Google Drive, Confluence, Slack, etc.
|
||||
* Chat support (think ChatGPT but it has access to your private knowledge sources).
|
||||
* Create custom AI Assistants with different prompts and backing knowledge sets.
|
||||
* Slack integration to get answers and search results directly in Slack.
|
||||
|
||||
Danswer currently syncs documents (every 10 minutes) from:
|
||||
|
||||
## Other Noteable Benefits of Danswer
|
||||
* Best in class Hybrid Search across all sources (BM-25 + prefix aware embedding models).
|
||||
* User Authentication with document level access management.
|
||||
* Admin Dashboard to configure connectors, document-sets, access, etc.
|
||||
* Custom deep learning models + learn from user feedback.
|
||||
* Connect Danswer with LLM of your choice for a fully airgapped solution.
|
||||
* Easy deployment and ability to host Danswer anywhere of your choosing.
|
||||
|
||||
|
||||
## 🔌 Connectors
|
||||
Efficiently pulls the latest changes from:
|
||||
* Slack
|
||||
* GitHub
|
||||
* Google Drive
|
||||
* Confluence
|
||||
* Jira
|
||||
* Zendesk
|
||||
* Gmail
|
||||
* Notion
|
||||
* Gong
|
||||
* Slab
|
||||
* Linear
|
||||
* Productboard
|
||||
* Guru
|
||||
* Zulip
|
||||
* Bookstack
|
||||
* Document360
|
||||
* Sharepoint
|
||||
* Hubspot
|
||||
* Local Files
|
||||
* Websites
|
||||
* With more to come...
|
||||
* And more ...
|
||||
|
||||
## 🚧 Roadmap
|
||||
* Chat/Conversation support.
|
||||
* Organizational understanding.
|
||||
* Ability to locate and suggest experts.
|
||||
* Ability to locate and suggest experts from your team.
|
||||
* Code Search
|
||||
* Structured Query Languages (SQL, Excel formulas, etc.)
|
||||
|
||||
## 💡 Contributing
|
||||
Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md) for more details.
|
||||
|
||||
17
backend/.dockerignore
Normal file
17
backend/.dockerignore
Normal file
@@ -0,0 +1,17 @@
|
||||
**/__pycache__
|
||||
venv/
|
||||
env/
|
||||
*.egg-info
|
||||
.cache
|
||||
.git/
|
||||
.svn/
|
||||
.vscode/
|
||||
.idea/
|
||||
*.log
|
||||
log/
|
||||
.env
|
||||
secrets.yaml
|
||||
build/
|
||||
dist/
|
||||
.coverage
|
||||
htmlcov/
|
||||
5
backend/.gitignore
vendored
5
backend/.gitignore
vendored
@@ -1,10 +1,11 @@
|
||||
__pycache__/
|
||||
.mypy_cache
|
||||
.idea/
|
||||
site_crawls/
|
||||
.ipynb_checkpoints/
|
||||
api_keys.py
|
||||
*ipynb
|
||||
qdrant-data/
|
||||
typesense-data/
|
||||
.env
|
||||
vespa-app.zip
|
||||
dynamic_config_storage/
|
||||
celerybeat-schedule
|
||||
|
||||
@@ -1,61 +1,49 @@
|
||||
FROM python:3.11.4-slim-bookworm
|
||||
FROM python:3.11.7-slim-bookworm
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y git cmake pkg-config libprotobuf-c-dev protobuf-compiler \
|
||||
libprotobuf-dev libgoogle-perftools-dev libpq-dev build-essential cron curl \
|
||||
supervisor zip \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
# Default DANSWER_VERSION, typically overriden during builds by GitHub Actions.
|
||||
ARG DANSWER_VERSION=0.3-dev
|
||||
ENV DANSWER_VERSION=${DANSWER_VERSION}
|
||||
RUN echo "DANSWER_VERSION: ${DANSWER_VERSION}"
|
||||
|
||||
COPY ./requirements/default.txt /tmp/requirements.txt
|
||||
RUN pip install --no-cache-dir --upgrade -r /tmp/requirements.txt
|
||||
# Install system dependencies
|
||||
# cmake needed for psycopg (postgres)
|
||||
# libpq-dev needed for psycopg (postgres)
|
||||
# curl included just for users' convenience
|
||||
# zip for Vespa step futher down
|
||||
# ca-certificates for HTTPS
|
||||
RUN apt-get update && \
|
||||
apt-get install -y cmake curl zip ca-certificates && \
|
||||
rm -rf /var/lib/apt/lists/* && \
|
||||
apt-get clean
|
||||
|
||||
# Install Python dependencies
|
||||
# Remove py which is pulled in by retry, py is not needed and is a CVE
|
||||
RUN pip uninstall -y py
|
||||
|
||||
RUN playwright install chromium
|
||||
RUN playwright install-deps chromium
|
||||
|
||||
# install nodejs and replace nodejs packaged with playwright (18.17.0) with the one installed below
|
||||
# based on the instructions found here:
|
||||
# https://nodejs.org/en/download/package-manager#debian-and-ubuntu-based-linux-distributions
|
||||
# this is temporarily needed until playwright updates their packaged node version to
|
||||
# 20.5.1+
|
||||
RUN apt-get update
|
||||
RUN apt-get install -y ca-certificates curl gnupg
|
||||
RUN mkdir -p /etc/apt/keyrings
|
||||
RUN curl -fsSL https://deb.nodesource.com/gpgkey/nodesource-repo.gpg.key | gpg --dearmor -o /etc/apt/keyrings/nodesource.gpg
|
||||
RUN echo "deb [signed-by=/etc/apt/keyrings/nodesource.gpg] https://deb.nodesource.com/node_20.x nodistro main" | tee /etc/apt/sources.list.d/nodesource.list
|
||||
RUN apt-get update
|
||||
RUN apt-get install nodejs -y
|
||||
# replace nodejs packaged with playwright (18.17.0) with the one installed above
|
||||
RUN cp /usr/bin/node /usr/local/lib/python3.11/site-packages/playwright/driver/node
|
||||
# remove nodejs (except for the binary we moved into playwright)
|
||||
RUN apt-get remove -y nodejs
|
||||
COPY ./requirements/default.txt /tmp/requirements.txt
|
||||
RUN pip install --no-cache-dir --upgrade -r /tmp/requirements.txt && \
|
||||
pip uninstall -y py && \
|
||||
playwright install chromium && playwright install-deps chromium && \
|
||||
ln -s /usr/local/bin/supervisord /usr/bin/supervisord
|
||||
|
||||
# Cleanup for CVEs and size reduction
|
||||
RUN apt-get remove -y linux-libc-dev \
|
||||
&& apt-get autoremove -y \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Remove tornado test key to placate vulnerability scanners
|
||||
# More details can be found here:
|
||||
# https://github.com/tornadoweb/tornado/issues/3107
|
||||
RUN rm /usr/local/lib/python3.11/site-packages/tornado/test/test.key
|
||||
# xserver-common and xvfb included by playwright installation but not needed after
|
||||
# perl-base is part of the base Python Debian image but not needed for Danswer functionality
|
||||
# perl-base could only be removed with --allow-remove-essential
|
||||
RUN apt-get remove -y --allow-remove-essential perl-base xserver-common xvfb cmake libldap-2.5-0 libldap-2.5-0 && \
|
||||
apt-get autoremove -y && \
|
||||
rm -rf /var/lib/apt/lists/* && \
|
||||
rm /usr/local/lib/python3.11/site-packages/tornado/test/test.key
|
||||
|
||||
# Set up application files
|
||||
WORKDIR /app
|
||||
COPY ./danswer /app/danswer
|
||||
COPY ./shared_models /app/shared_models
|
||||
COPY ./alembic /app/alembic
|
||||
COPY ./alembic.ini /app/alembic.ini
|
||||
COPY supervisord.conf /etc/supervisor/conf.d/supervisord.conf
|
||||
|
||||
WORKDIR /app/danswer/datastores/vespa/app_config
|
||||
RUN zip -r /app/danswer/vespa-app.zip .
|
||||
WORKDIR /app
|
||||
|
||||
# TODO: remove this once all users have migrated
|
||||
COPY ./scripts/migrate_vespa_to_acl.py /app/migrate_vespa_to_acl.py
|
||||
COPY supervisord.conf /usr/etc/supervisord.conf
|
||||
|
||||
ENV PYTHONPATH /app
|
||||
|
||||
# By default this container does nothing, it is used by api server and background which specify their own CMD
|
||||
# Default command which does nothing
|
||||
# This container is used by api server and background which specify their own CMD
|
||||
CMD ["tail", "-f", "/dev/null"]
|
||||
|
||||
39
backend/Dockerfile.model_server
Normal file
39
backend/Dockerfile.model_server
Normal file
@@ -0,0 +1,39 @@
|
||||
FROM python:3.11.7-slim-bookworm
|
||||
|
||||
# Default DANSWER_VERSION, typically overriden during builds by GitHub Actions.
|
||||
ARG DANSWER_VERSION=0.3-dev
|
||||
ENV DANSWER_VERSION=${DANSWER_VERSION}
|
||||
RUN echo "DANSWER_VERSION: ${DANSWER_VERSION}"
|
||||
|
||||
COPY ./requirements/model_server.txt /tmp/requirements.txt
|
||||
RUN pip install --no-cache-dir --upgrade -r /tmp/requirements.txt
|
||||
|
||||
RUN apt-get remove -y --allow-remove-essential perl-base && \
|
||||
apt-get autoremove -y
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Needed for model configs and defaults
|
||||
COPY ./danswer/configs /app/danswer/configs
|
||||
COPY ./danswer/dynamic_configs /app/danswer/dynamic_configs
|
||||
|
||||
# Utils used by model server
|
||||
COPY ./danswer/utils/logger.py /app/danswer/utils/logger.py
|
||||
COPY ./danswer/utils/timing.py /app/danswer/utils/timing.py
|
||||
COPY ./danswer/utils/telemetry.py /app/danswer/utils/telemetry.py
|
||||
|
||||
# Place to fetch version information
|
||||
COPY ./danswer/__init__.py /app/danswer/__init__.py
|
||||
|
||||
# Shared implementations for running NLP models locally
|
||||
COPY ./danswer/search/search_nlp_models.py /app/danswer/search/search_nlp_models.py
|
||||
|
||||
# Request/Response models
|
||||
COPY ./shared_models /app/shared_models
|
||||
|
||||
# Model Server main code
|
||||
COPY ./model_server /app/model_server
|
||||
|
||||
ENV PYTHONPATH /app
|
||||
|
||||
CMD ["uvicorn", "model_server.main:app", "--host", "0.0.0.0", "--port", "9000"]
|
||||
@@ -1,4 +1,8 @@
|
||||
Generic single-database configuration with an async dbapi.
|
||||
<!-- DANSWER_METADATA={"link": "https://github.com/danswer-ai/danswer/blob/main/backend/alembic/README.md"} -->
|
||||
|
||||
# Alembic DB Migrations
|
||||
These files are for creating/updating the tables in the Relational DB (Postgres).
|
||||
Danswer migrations use a generic single-database configuration with an async dbapi.
|
||||
|
||||
## To generate new migrations:
|
||||
run from danswer/backend:
|
||||
@@ -7,7 +11,6 @@ run from danswer/backend:
|
||||
More info can be found here: https://alembic.sqlalchemy.org/en/latest/autogenerate.html
|
||||
|
||||
## Running migrations
|
||||
|
||||
To run all un-applied migrations:
|
||||
`alembic upgrade head`
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ from danswer.db.models import Base
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.engine import Connection
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from celery.backends.database.session import ResultModelBase # type: ignore
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
@@ -21,7 +22,7 @@ if config.config_file_name is not None:
|
||||
# for 'autogenerate' support
|
||||
# from myapp import mymodel
|
||||
# target_metadata = mymodel.Base.metadata
|
||||
target_metadata = Base.metadata
|
||||
target_metadata = [Base.metadata, ResultModelBase.metadata]
|
||||
|
||||
# other values from the config, defined by the needs of env.py,
|
||||
# can be acquired:
|
||||
@@ -44,7 +45,7 @@ def run_migrations_offline() -> None:
|
||||
url = build_connection_string()
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
)
|
||||
@@ -54,7 +55,7 @@ def run_migrations_offline() -> None:
|
||||
|
||||
|
||||
def do_run_migrations(connection: Connection) -> None:
|
||||
context.configure(connection=connection, target_metadata=target_metadata)
|
||||
context.configure(connection=connection, target_metadata=target_metadata) # type: ignore
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
"""Introduce Danswer APIs
|
||||
|
||||
Revision ID: 15326fcec57e
|
||||
Revises: 77d07dffae64
|
||||
Create Date: 2023-11-11 20:51:24.228999
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
from danswer.configs.constants import DocumentSource
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "15326fcec57e"
|
||||
down_revision = "77d07dffae64"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.alter_column("credential", "is_admin", new_column_name="admin_public")
|
||||
op.add_column(
|
||||
"document",
|
||||
sa.Column("from_ingestion_api", sa.Boolean(), nullable=True),
|
||||
)
|
||||
op.alter_column(
|
||||
"connector",
|
||||
"source",
|
||||
type_=sa.String(length=50),
|
||||
existing_type=sa.Enum(DocumentSource, native_enum=False),
|
||||
existing_nullable=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("document", "from_ingestion_api")
|
||||
op.alter_column("credential", "admin_public", new_column_name="is_admin")
|
||||
@@ -0,0 +1,37 @@
|
||||
"""Persona Datetime Aware
|
||||
|
||||
Revision ID: 30c1d5744104
|
||||
Revises: 7f99be1cb9f5
|
||||
Create Date: 2023-10-16 23:21:01.283424
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "30c1d5744104"
|
||||
down_revision = "7f99be1cb9f5"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column("persona", sa.Column("datetime_aware", sa.Boolean(), nullable=True))
|
||||
op.execute("UPDATE persona SET datetime_aware = TRUE")
|
||||
op.alter_column("persona", "datetime_aware", nullable=False)
|
||||
op.create_index(
|
||||
"_default_persona_name_idx",
|
||||
"persona",
|
||||
["name"],
|
||||
unique=True,
|
||||
postgresql_where=sa.text("default_persona = true"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index(
|
||||
"_default_persona_name_idx",
|
||||
table_name="persona",
|
||||
postgresql_where=sa.text("default_persona = true"),
|
||||
)
|
||||
op.drop_column("persona", "datetime_aware")
|
||||
@@ -0,0 +1,49 @@
|
||||
"""Move is_public to cc_pair
|
||||
|
||||
Revision ID: 3b25685ff73c
|
||||
Revises: e0a68a81d434
|
||||
Create Date: 2023-10-05 18:47:09.582849
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "3b25685ff73c"
|
||||
down_revision = "e0a68a81d434"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"connector_credential_pair",
|
||||
sa.Column("is_public", sa.Boolean(), nullable=True),
|
||||
)
|
||||
# fill in is_public for existing rows
|
||||
op.execute(
|
||||
"UPDATE connector_credential_pair SET is_public = true WHERE is_public IS NULL"
|
||||
)
|
||||
op.alter_column("connector_credential_pair", "is_public", nullable=False)
|
||||
|
||||
op.add_column(
|
||||
"credential",
|
||||
sa.Column("is_admin", sa.Boolean(), nullable=True),
|
||||
)
|
||||
op.execute("UPDATE credential SET is_admin = true WHERE is_admin IS NULL")
|
||||
op.alter_column("credential", "is_admin", nullable=False)
|
||||
|
||||
op.drop_column("credential", "public_doc")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.add_column(
|
||||
"credential",
|
||||
sa.Column("public_doc", sa.Boolean(), nullable=True),
|
||||
)
|
||||
# setting public_doc to false for all existing rows to be safe
|
||||
# NOTE: this is likely not the correct state of the world but it's the best we can do
|
||||
op.execute("UPDATE credential SET public_doc = false WHERE public_doc IS NULL")
|
||||
op.alter_column("credential", "public_doc", nullable=False)
|
||||
op.drop_column("connector_credential_pair", "is_public")
|
||||
op.drop_column("credential", "is_admin")
|
||||
@@ -35,6 +35,7 @@ def upgrade() -> None:
|
||||
"SUCCESS",
|
||||
"FAILED",
|
||||
name="indexingstatus",
|
||||
native_enum=False,
|
||||
),
|
||||
nullable=False,
|
||||
),
|
||||
|
||||
31
backend/alembic/versions/46625e4745d4_remove_native_enum.py
Normal file
31
backend/alembic/versions/46625e4745d4_remove_native_enum.py
Normal file
@@ -0,0 +1,31 @@
|
||||
"""Remove Native Enum
|
||||
|
||||
Revision ID: 46625e4745d4
|
||||
Revises: 9d97fecfab7f
|
||||
Create Date: 2023-10-27 11:38:33.803145
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
from sqlalchemy import String
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "46625e4745d4"
|
||||
down_revision = "9d97fecfab7f"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# At this point, we directly changed some previous migrations,
|
||||
# https://github.com/danswer-ai/danswer/pull/637
|
||||
# Due to using Postgres native Enums, it caused some complications for first time users.
|
||||
# To remove those complications, all Enums are only handled application side moving forward.
|
||||
# This migration exists to ensure that existing users don't run into upgrade issues.
|
||||
op.alter_column("index_attempt", "status", type_=String)
|
||||
op.alter_column("connector_credential_pair", "last_attempt_status", type_=String)
|
||||
op.execute("DROP TYPE IF EXISTS indexingstatus")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# We don't want Native Enums, do nothing
|
||||
pass
|
||||
@@ -59,6 +59,7 @@ def upgrade() -> None:
|
||||
"SUCCESS",
|
||||
"FAILED",
|
||||
name="indexingstatus",
|
||||
native_enum=False,
|
||||
),
|
||||
nullable=False,
|
||||
),
|
||||
@@ -70,4 +71,3 @@ def upgrade() -> None:
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("index_attempt")
|
||||
sa.Enum(name="indexingstatus").drop(op.get_bind(), checkfirst=False)
|
||||
|
||||
@@ -0,0 +1,28 @@
|
||||
"""Add additional retrieval controls to Persona
|
||||
|
||||
Revision ID: 50b683a8295c
|
||||
Revises: 7da0ae5ad583
|
||||
Create Date: 2023-11-27 17:23:29.668422
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "50b683a8295c"
|
||||
down_revision = "7da0ae5ad583"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column("persona", sa.Column("num_chunks", sa.Integer(), nullable=True))
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column("apply_llm_relevance_filter", sa.Boolean(), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("persona", "apply_llm_relevance_filter")
|
||||
op.drop_column("persona", "num_chunks")
|
||||
@@ -0,0 +1,32 @@
|
||||
"""CC-Pair Name not Unique
|
||||
|
||||
Revision ID: 76b60d407dfb
|
||||
Revises: b156fa702355
|
||||
Create Date: 2023-12-22 21:42:10.018804
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "76b60d407dfb"
|
||||
down_revision = "b156fa702355"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute("DELETE FROM connector_credential_pair WHERE name IS NULL")
|
||||
op.drop_constraint(
|
||||
"connector_credential_pair__name__key",
|
||||
"connector_credential_pair",
|
||||
type_="unique",
|
||||
)
|
||||
op.alter_column(
|
||||
"connector_credential_pair", "name", existing_type=sa.String(), nullable=False
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# This wasn't really required by the code either, no good reason to make it unique again
|
||||
pass
|
||||
@@ -0,0 +1,35 @@
|
||||
"""forcibly remove more enum types from postgres
|
||||
|
||||
Revision ID: 77d07dffae64
|
||||
Revises: d61e513bef0a
|
||||
Create Date: 2023-11-01 12:33:01.999617
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
from sqlalchemy import String
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "77d07dffae64"
|
||||
down_revision = "d61e513bef0a"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# In a PR:
|
||||
# https://github.com/danswer-ai/danswer/pull/397/files#diff-f05fb341f6373790b91852579631b64ca7645797a190837156a282b67e5b19c2
|
||||
# we directly changed some previous migrations. This caused some users to have native enums
|
||||
# while others wouldn't. This has caused some issues when adding new fields to these enums.
|
||||
# This migration manually changes the enum types to ensure that nobody uses native enums.
|
||||
op.alter_column("query_event", "selected_search_flow", type_=String)
|
||||
op.alter_column("query_event", "feedback", type_=String)
|
||||
op.alter_column("document_retrieval_feedback", "feedback", type_=String)
|
||||
op.execute("DROP TYPE IF EXISTS searchtype")
|
||||
op.execute("DROP TYPE IF EXISTS qafeedbacktype")
|
||||
op.execute("DROP TYPE IF EXISTS searchfeedbacktype")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# We don't want Native Enums, do nothing
|
||||
pass
|
||||
48
backend/alembic/versions/78dbe7e38469_task_tracking.py
Normal file
48
backend/alembic/versions/78dbe7e38469_task_tracking.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""Task Tracking
|
||||
|
||||
Revision ID: 78dbe7e38469
|
||||
Revises: 7ccea01261f6
|
||||
Create Date: 2023-10-15 23:40:50.593262
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "78dbe7e38469"
|
||||
down_revision = "7ccea01261f6"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"task_queue_jobs",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("task_id", sa.String(), nullable=False),
|
||||
sa.Column("task_name", sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
"status",
|
||||
sa.Enum(
|
||||
"PENDING",
|
||||
"STARTED",
|
||||
"SUCCESS",
|
||||
"FAILURE",
|
||||
name="taskstatus",
|
||||
native_enum=False,
|
||||
),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("start_time", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column(
|
||||
"register_time",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("task_queue_jobs")
|
||||
48
backend/alembic/versions/79acd316403a_add_api_key_table.py
Normal file
48
backend/alembic/versions/79acd316403a_add_api_key_table.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""Add api_key table
|
||||
|
||||
Revision ID: 79acd316403a
|
||||
Revises: 904e5138fffb
|
||||
Create Date: 2024-01-11 17:56:37.934381
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import fastapi_users_db_sqlalchemy
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "79acd316403a"
|
||||
down_revision = "904e5138fffb"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"api_key",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("hashed_api_key", sa.String(), nullable=False),
|
||||
sa.Column("api_key_display", sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
"user_id",
|
||||
fastapi_users_db_sqlalchemy.generics.GUID(),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"owner_id",
|
||||
fastapi_users_db_sqlalchemy.generics.GUID(),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("api_key_display"),
|
||||
sa.UniqueConstraint("hashed_api_key"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("api_key")
|
||||
@@ -0,0 +1,31 @@
|
||||
"""Store Chat Retrieval Docs
|
||||
|
||||
Revision ID: 7ccea01261f6
|
||||
Revises: a570b80a5f20
|
||||
Create Date: 2023-10-15 10:39:23.317453
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "7ccea01261f6"
|
||||
down_revision = "a570b80a5f20"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"chat_message",
|
||||
sa.Column(
|
||||
"reference_docs",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("chat_message", "reference_docs")
|
||||
@@ -0,0 +1,23 @@
|
||||
"""Add description to persona
|
||||
|
||||
Revision ID: 7da0ae5ad583
|
||||
Revises: e86866a9c78a
|
||||
Create Date: 2023-11-27 00:16:19.959414
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "7da0ae5ad583"
|
||||
down_revision = "e86866a9c78a"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column("persona", sa.Column("description", sa.String(), nullable=True))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("persona", "description")
|
||||
26
backend/alembic/versions/7f726bad5367_slack_followup.py
Normal file
26
backend/alembic/versions/7f726bad5367_slack_followup.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""Slack Followup
|
||||
|
||||
Revision ID: 7f726bad5367
|
||||
Revises: 79acd316403a
|
||||
Create Date: 2024-01-15 00:19:55.991224
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "7f726bad5367"
|
||||
down_revision = "79acd316403a"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"chat_feedback",
|
||||
sa.Column("required_followup", sa.Boolean(), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("chat_feedback", "required_followup")
|
||||
@@ -0,0 +1,35 @@
|
||||
"""Add index for getting documents just by connector id / credential id
|
||||
|
||||
Revision ID: 7f99be1cb9f5
|
||||
Revises: 78dbe7e38469
|
||||
Create Date: 2023-10-15 22:48:15.487762
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "7f99be1cb9f5"
|
||||
down_revision = "78dbe7e38469"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_index(
|
||||
op.f(
|
||||
"ix_document_by_connector_credential_pair_pkey__connector_id__credential_id"
|
||||
),
|
||||
"document_by_connector_credential_pair",
|
||||
["connector_id", "credential_id"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index(
|
||||
op.f(
|
||||
"ix_document_by_connector_credential_pair_pkey__connector_id__credential_id"
|
||||
),
|
||||
table_name="document_by_connector_credential_pair",
|
||||
)
|
||||
@@ -0,0 +1,36 @@
|
||||
"""Add chat session to query_event
|
||||
|
||||
Revision ID: 80696cf850ae
|
||||
Revises: 15326fcec57e
|
||||
Create Date: 2023-11-26 02:38:35.008070
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "80696cf850ae"
|
||||
down_revision = "15326fcec57e"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"query_event",
|
||||
sa.Column("chat_session_id", sa.Integer(), nullable=True),
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"fk_query_event_chat_session_id",
|
||||
"query_event",
|
||||
"chat_session",
|
||||
["chat_session_id"],
|
||||
["id"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_constraint(
|
||||
"fk_query_event_chat_session_id", "query_event", type_="foreignkey"
|
||||
)
|
||||
op.drop_column("query_event", "chat_session_id")
|
||||
@@ -0,0 +1,34 @@
|
||||
"""Add is_visible to Persona
|
||||
|
||||
Revision ID: 891cd83c87a8
|
||||
Revises: 76b60d407dfb
|
||||
Create Date: 2023-12-21 11:55:54.132279
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "891cd83c87a8"
|
||||
down_revision = "76b60d407dfb"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column("is_visible", sa.Boolean(), nullable=True),
|
||||
)
|
||||
op.execute("UPDATE persona SET is_visible = true")
|
||||
op.alter_column("persona", "is_visible", nullable=False)
|
||||
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column("display_priority", sa.Integer(), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("persona", "is_visible")
|
||||
op.drop_column("persona", "display_priority")
|
||||
@@ -0,0 +1,25 @@
|
||||
"""Add full exception stack trace
|
||||
|
||||
Revision ID: 8987770549c0
|
||||
Revises: ec3ec2eabf7b
|
||||
Create Date: 2024-02-10 19:31:28.339135
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "8987770549c0"
|
||||
down_revision = "ec3ec2eabf7b"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"index_attempt", sa.Column("full_exception_trace", sa.Text(), nullable=True)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("index_attempt", "full_exception_trace")
|
||||
32
backend/alembic/versions/904451035c9b_store_tool_details.py
Normal file
32
backend/alembic/versions/904451035c9b_store_tool_details.py
Normal file
@@ -0,0 +1,32 @@
|
||||
"""Store Tool Details
|
||||
|
||||
Revision ID: 904451035c9b
|
||||
Revises: 3b25685ff73c
|
||||
Create Date: 2023-10-05 12:29:26.620000
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "904451035c9b"
|
||||
down_revision = "3b25685ff73c"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column("tools", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
)
|
||||
op.drop_column("persona", "tools_text")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column("tools_text", sa.TEXT(), autoincrement=False, nullable=True),
|
||||
)
|
||||
op.drop_column("persona", "tools")
|
||||
61
backend/alembic/versions/904e5138fffb_tags.py
Normal file
61
backend/alembic/versions/904e5138fffb_tags.py
Normal file
@@ -0,0 +1,61 @@
|
||||
"""Tags
|
||||
|
||||
Revision ID: 904e5138fffb
|
||||
Revises: 891cd83c87a8
|
||||
Create Date: 2024-01-01 10:44:43.733974
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "904e5138fffb"
|
||||
down_revision = "891cd83c87a8"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"tag",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("tag_key", sa.String(), nullable=False),
|
||||
sa.Column("tag_value", sa.String(), nullable=False),
|
||||
sa.Column("source", sa.String(), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint(
|
||||
"tag_key", "tag_value", "source", name="_tag_key_value_source_uc"
|
||||
),
|
||||
)
|
||||
op.create_table(
|
||||
"document__tag",
|
||||
sa.Column("document_id", sa.String(), nullable=False),
|
||||
sa.Column("tag_id", sa.Integer(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["document_id"],
|
||||
["document.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["tag_id"],
|
||||
["tag.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("document_id", "tag_id"),
|
||||
)
|
||||
|
||||
op.add_column(
|
||||
"search_doc",
|
||||
sa.Column(
|
||||
"doc_metadata",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
op.execute("UPDATE search_doc SET doc_metadata = '{}' WHERE doc_metadata IS NULL")
|
||||
op.alter_column("search_doc", "doc_metadata", nullable=False)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("document__tag")
|
||||
op.drop_table("tag")
|
||||
op.drop_column("search_doc", "doc_metadata")
|
||||
@@ -0,0 +1,31 @@
|
||||
"""Added retrieved docs to query event
|
||||
|
||||
Revision ID: 9d97fecfab7f
|
||||
Revises: ffc707a226b4
|
||||
Create Date: 2023-10-20 12:22:31.930449
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "9d97fecfab7f"
|
||||
down_revision = "ffc707a226b4"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"query_event",
|
||||
sa.Column(
|
||||
"retrieved_document_ids",
|
||||
postgresql.ARRAY(sa.String()),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("query_event", "retrieved_document_ids")
|
||||
67
backend/alembic/versions/a570b80a5f20_usergroup_tables.py
Normal file
67
backend/alembic/versions/a570b80a5f20_usergroup_tables.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""UserGroup tables
|
||||
|
||||
Revision ID: a570b80a5f20
|
||||
Revises: 904451035c9b
|
||||
Create Date: 2023-10-02 12:27:10.265725
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import fastapi_users_db_sqlalchemy
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "a570b80a5f20"
|
||||
down_revision = "904451035c9b"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"user_group",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("name", sa.String(), nullable=False),
|
||||
sa.Column("is_up_to_date", sa.Boolean(), nullable=False),
|
||||
sa.Column("is_up_for_deletion", sa.Boolean(), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("name"),
|
||||
)
|
||||
op.create_table(
|
||||
"user__user_group",
|
||||
sa.Column("user_group_id", sa.Integer(), nullable=False),
|
||||
sa.Column(
|
||||
"user_id",
|
||||
fastapi_users_db_sqlalchemy.generics.GUID(),
|
||||
nullable=False,
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_group_id"],
|
||||
["user_group.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["user.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("user_group_id", "user_id"),
|
||||
)
|
||||
op.create_table(
|
||||
"user_group__connector_credential_pair",
|
||||
sa.Column("user_group_id", sa.Integer(), nullable=False),
|
||||
sa.Column("cc_pair_id", sa.Integer(), nullable=False),
|
||||
sa.Column("is_current", sa.Boolean(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["cc_pair_id"],
|
||||
["connector_credential_pair.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_group_id"],
|
||||
["user_group.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("user_group_id", "cc_pair_id", "is_current"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("user_group__connector_credential_pair")
|
||||
op.drop_table("user__user_group")
|
||||
op.drop_table("user_group")
|
||||
47
backend/alembic/versions/ae62505e3acc_add_saml_accounts.py
Normal file
47
backend/alembic/versions/ae62505e3acc_add_saml_accounts.py
Normal file
@@ -0,0 +1,47 @@
|
||||
"""Add SAML Accounts
|
||||
|
||||
Revision ID: ae62505e3acc
|
||||
Revises: 7da543f5672f
|
||||
Create Date: 2023-09-26 16:19:30.933183
|
||||
|
||||
"""
|
||||
import fastapi_users_db_sqlalchemy
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "ae62505e3acc"
|
||||
down_revision = "7da543f5672f"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"saml",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column(
|
||||
"user_id",
|
||||
fastapi_users_db_sqlalchemy.generics.GUID(),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("encrypted_cookie", sa.Text(), nullable=False),
|
||||
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["user.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("encrypted_cookie"),
|
||||
sa.UniqueConstraint("user_id"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("saml")
|
||||
520
backend/alembic/versions/b156fa702355_chat_reworked.py
Normal file
520
backend/alembic/versions/b156fa702355_chat_reworked.py
Normal file
@@ -0,0 +1,520 @@
|
||||
"""Chat Reworked
|
||||
|
||||
Revision ID: b156fa702355
|
||||
Revises: baf71f781b9e
|
||||
Create Date: 2023-12-12 00:57:41.823371
|
||||
|
||||
"""
|
||||
import fastapi_users_db_sqlalchemy
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from sqlalchemy.dialects.postgresql import ENUM
|
||||
from danswer.configs.constants import DocumentSource
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "b156fa702355"
|
||||
down_revision = "baf71f781b9e"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
searchtype_enum = ENUM(
|
||||
"KEYWORD", "SEMANTIC", "HYBRID", name="searchtype", create_type=True
|
||||
)
|
||||
recencybiassetting_enum = ENUM(
|
||||
"FAVOR_RECENT",
|
||||
"BASE_DECAY",
|
||||
"NO_DECAY",
|
||||
"AUTO",
|
||||
name="recencybiassetting",
|
||||
create_type=True,
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
bind = op.get_bind()
|
||||
searchtype_enum.create(bind)
|
||||
recencybiassetting_enum.create(bind)
|
||||
|
||||
# This is irrecoverable, whatever
|
||||
op.execute("DELETE FROM chat_feedback")
|
||||
op.execute("DELETE FROM document_retrieval_feedback")
|
||||
|
||||
op.create_table(
|
||||
"search_doc",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("document_id", sa.String(), nullable=False),
|
||||
sa.Column("chunk_ind", sa.Integer(), nullable=False),
|
||||
sa.Column("semantic_id", sa.String(), nullable=False),
|
||||
sa.Column("link", sa.String(), nullable=True),
|
||||
sa.Column("blurb", sa.String(), nullable=False),
|
||||
sa.Column("boost", sa.Integer(), nullable=False),
|
||||
sa.Column(
|
||||
"source_type",
|
||||
sa.Enum(DocumentSource, native=False),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("hidden", sa.Boolean(), nullable=False),
|
||||
sa.Column("score", sa.Float(), nullable=False),
|
||||
sa.Column("match_highlights", postgresql.ARRAY(sa.String()), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("primary_owners", postgresql.ARRAY(sa.String()), nullable=True),
|
||||
sa.Column("secondary_owners", postgresql.ARRAY(sa.String()), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_table(
|
||||
"prompt",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column(
|
||||
"user_id",
|
||||
fastapi_users_db_sqlalchemy.generics.GUID(),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("name", sa.String(), nullable=False),
|
||||
sa.Column("description", sa.String(), nullable=False),
|
||||
sa.Column("system_prompt", sa.Text(), nullable=False),
|
||||
sa.Column("task_prompt", sa.Text(), nullable=False),
|
||||
sa.Column("include_citations", sa.Boolean(), nullable=False),
|
||||
sa.Column("datetime_aware", sa.Boolean(), nullable=False),
|
||||
sa.Column("default_prompt", sa.Boolean(), nullable=False),
|
||||
sa.Column("deleted", sa.Boolean(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["user.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_table(
|
||||
"persona__prompt",
|
||||
sa.Column("persona_id", sa.Integer(), nullable=False),
|
||||
sa.Column("prompt_id", sa.Integer(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["persona_id"],
|
||||
["persona.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["prompt_id"],
|
||||
["prompt.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("persona_id", "prompt_id"),
|
||||
)
|
||||
|
||||
# Changes to persona first so chat_sessions can have the right persona
|
||||
# The empty persona will be overwritten on server startup
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column(
|
||||
"user_id",
|
||||
fastapi_users_db_sqlalchemy.generics.GUID(),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column(
|
||||
"search_type",
|
||||
searchtype_enum,
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
op.execute("UPDATE persona SET search_type = 'HYBRID'")
|
||||
op.alter_column("persona", "search_type", nullable=False)
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column("llm_relevance_filter", sa.Boolean(), nullable=True),
|
||||
)
|
||||
op.execute("UPDATE persona SET llm_relevance_filter = TRUE")
|
||||
op.alter_column("persona", "llm_relevance_filter", nullable=False)
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column("llm_filter_extraction", sa.Boolean(), nullable=True),
|
||||
)
|
||||
op.execute("UPDATE persona SET llm_filter_extraction = TRUE")
|
||||
op.alter_column("persona", "llm_filter_extraction", nullable=False)
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column(
|
||||
"recency_bias",
|
||||
recencybiassetting_enum,
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
op.execute("UPDATE persona SET recency_bias = 'BASE_DECAY'")
|
||||
op.alter_column("persona", "recency_bias", nullable=False)
|
||||
op.alter_column("persona", "description", existing_type=sa.VARCHAR(), nullable=True)
|
||||
op.execute("UPDATE persona SET description = ''")
|
||||
op.alter_column("persona", "description", nullable=False)
|
||||
op.create_foreign_key("persona__user_fk", "persona", "user", ["user_id"], ["id"])
|
||||
op.drop_column("persona", "datetime_aware")
|
||||
op.drop_column("persona", "tools")
|
||||
op.drop_column("persona", "hint_text")
|
||||
op.drop_column("persona", "apply_llm_relevance_filter")
|
||||
op.drop_column("persona", "retrieval_enabled")
|
||||
op.drop_column("persona", "system_text")
|
||||
|
||||
# Need to create a persona row so fk can work
|
||||
result = bind.execute(sa.text("SELECT 1 FROM persona WHERE id = 0"))
|
||||
exists = result.fetchone()
|
||||
if not exists:
|
||||
op.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO persona (
|
||||
id, user_id, name, description, search_type, num_chunks,
|
||||
llm_relevance_filter, llm_filter_extraction, recency_bias,
|
||||
llm_model_version_override, default_persona, deleted
|
||||
) VALUES (
|
||||
0, NULL, '', '', 'HYBRID', NULL,
|
||||
TRUE, TRUE, 'BASE_DECAY', NULL, TRUE, FALSE
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
delete_statement = sa.text(
|
||||
"""
|
||||
DELETE FROM persona
|
||||
WHERE name = 'Danswer' AND default_persona = TRUE AND id != 0
|
||||
"""
|
||||
)
|
||||
|
||||
bind.execute(delete_statement)
|
||||
|
||||
op.add_column(
|
||||
"chat_feedback",
|
||||
sa.Column("chat_message_id", sa.Integer(), nullable=False),
|
||||
)
|
||||
op.drop_constraint(
|
||||
"chat_feedback_chat_message_chat_session_id_chat_message_me_fkey",
|
||||
"chat_feedback",
|
||||
type_="foreignkey",
|
||||
)
|
||||
op.drop_column("chat_feedback", "chat_message_edit_number")
|
||||
op.drop_column("chat_feedback", "chat_message_chat_session_id")
|
||||
op.drop_column("chat_feedback", "chat_message_message_number")
|
||||
op.add_column(
|
||||
"chat_message",
|
||||
sa.Column(
|
||||
"id",
|
||||
sa.Integer(),
|
||||
primary_key=True,
|
||||
autoincrement=True,
|
||||
nullable=False,
|
||||
unique=True,
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"chat_message",
|
||||
sa.Column("parent_message", sa.Integer(), nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
"chat_message",
|
||||
sa.Column("latest_child_message", sa.Integer(), nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
"chat_message", sa.Column("rephrased_query", sa.Text(), nullable=True)
|
||||
)
|
||||
op.add_column("chat_message", sa.Column("prompt_id", sa.Integer(), nullable=True))
|
||||
op.add_column(
|
||||
"chat_message",
|
||||
sa.Column("citations", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
)
|
||||
op.add_column("chat_message", sa.Column("error", sa.Text(), nullable=True))
|
||||
op.drop_constraint("fk_chat_message_persona_id", "chat_message", type_="foreignkey")
|
||||
op.create_foreign_key(
|
||||
"chat_message__prompt_fk", "chat_message", "prompt", ["prompt_id"], ["id"]
|
||||
)
|
||||
op.drop_column("chat_message", "parent_edit_number")
|
||||
op.drop_column("chat_message", "persona_id")
|
||||
op.drop_column("chat_message", "reference_docs")
|
||||
op.drop_column("chat_message", "edit_number")
|
||||
op.drop_column("chat_message", "latest")
|
||||
op.drop_column("chat_message", "message_number")
|
||||
op.add_column("chat_session", sa.Column("one_shot", sa.Boolean(), nullable=True))
|
||||
op.execute("UPDATE chat_session SET one_shot = TRUE")
|
||||
op.alter_column("chat_session", "one_shot", nullable=False)
|
||||
op.alter_column(
|
||||
"chat_session",
|
||||
"persona_id",
|
||||
existing_type=sa.INTEGER(),
|
||||
nullable=True,
|
||||
)
|
||||
op.execute("UPDATE chat_session SET persona_id = 0")
|
||||
op.alter_column("chat_session", "persona_id", nullable=False)
|
||||
op.add_column(
|
||||
"document_retrieval_feedback",
|
||||
sa.Column("chat_message_id", sa.Integer(), nullable=False),
|
||||
)
|
||||
op.drop_constraint(
|
||||
"document_retrieval_feedback_qa_event_id_fkey",
|
||||
"document_retrieval_feedback",
|
||||
type_="foreignkey",
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"document_retrieval_feedback__chat_message_fk",
|
||||
"document_retrieval_feedback",
|
||||
"chat_message",
|
||||
["chat_message_id"],
|
||||
["id"],
|
||||
)
|
||||
op.drop_column("document_retrieval_feedback", "qa_event_id")
|
||||
|
||||
# Relation table must be created after the other tables are correct
|
||||
op.create_table(
|
||||
"chat_message__search_doc",
|
||||
sa.Column("chat_message_id", sa.Integer(), nullable=False),
|
||||
sa.Column("search_doc_id", sa.Integer(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["chat_message_id"],
|
||||
["chat_message.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["search_doc_id"],
|
||||
["search_doc.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("chat_message_id", "search_doc_id"),
|
||||
)
|
||||
|
||||
# Needs to be created after chat_message id field is added
|
||||
op.create_foreign_key(
|
||||
"chat_feedback__chat_message_fk",
|
||||
"chat_feedback",
|
||||
"chat_message",
|
||||
["chat_message_id"],
|
||||
["id"],
|
||||
)
|
||||
|
||||
op.drop_table("query_event")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_constraint(
|
||||
"chat_feedback__chat_message_fk", "chat_feedback", type_="foreignkey"
|
||||
)
|
||||
op.drop_constraint(
|
||||
"document_retrieval_feedback__chat_message_fk",
|
||||
"document_retrieval_feedback",
|
||||
type_="foreignkey",
|
||||
)
|
||||
op.drop_constraint("persona__user_fk", "persona", type_="foreignkey")
|
||||
op.drop_constraint("chat_message__prompt_fk", "chat_message", type_="foreignkey")
|
||||
op.drop_constraint(
|
||||
"chat_message__search_doc_chat_message_id_fkey",
|
||||
"chat_message__search_doc",
|
||||
type_="foreignkey",
|
||||
)
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column("system_text", sa.TEXT(), autoincrement=False, nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column(
|
||||
"retrieval_enabled",
|
||||
sa.BOOLEAN(),
|
||||
autoincrement=False,
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
op.execute("UPDATE persona SET retrieval_enabled = TRUE")
|
||||
op.alter_column("persona", "retrieval_enabled", nullable=False)
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column(
|
||||
"apply_llm_relevance_filter",
|
||||
sa.BOOLEAN(),
|
||||
autoincrement=False,
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column("hint_text", sa.TEXT(), autoincrement=False, nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column(
|
||||
"tools",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
autoincrement=False,
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column("datetime_aware", sa.BOOLEAN(), autoincrement=False, nullable=True),
|
||||
)
|
||||
op.execute("UPDATE persona SET datetime_aware = TRUE")
|
||||
op.alter_column("persona", "datetime_aware", nullable=False)
|
||||
op.alter_column("persona", "description", existing_type=sa.VARCHAR(), nullable=True)
|
||||
op.drop_column("persona", "recency_bias")
|
||||
op.drop_column("persona", "llm_filter_extraction")
|
||||
op.drop_column("persona", "llm_relevance_filter")
|
||||
op.drop_column("persona", "search_type")
|
||||
op.drop_column("persona", "user_id")
|
||||
op.add_column(
|
||||
"document_retrieval_feedback",
|
||||
sa.Column("qa_event_id", sa.INTEGER(), autoincrement=False, nullable=False),
|
||||
)
|
||||
op.drop_column("document_retrieval_feedback", "chat_message_id")
|
||||
op.alter_column(
|
||||
"chat_session", "persona_id", existing_type=sa.INTEGER(), nullable=True
|
||||
)
|
||||
op.drop_column("chat_session", "one_shot")
|
||||
op.add_column(
|
||||
"chat_message",
|
||||
sa.Column(
|
||||
"message_number",
|
||||
sa.INTEGER(),
|
||||
autoincrement=False,
|
||||
nullable=False,
|
||||
primary_key=True,
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"chat_message",
|
||||
sa.Column("latest", sa.BOOLEAN(), autoincrement=False, nullable=False),
|
||||
)
|
||||
op.add_column(
|
||||
"chat_message",
|
||||
sa.Column(
|
||||
"edit_number",
|
||||
sa.INTEGER(),
|
||||
autoincrement=False,
|
||||
nullable=False,
|
||||
primary_key=True,
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"chat_message",
|
||||
sa.Column(
|
||||
"reference_docs",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
autoincrement=False,
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"chat_message",
|
||||
sa.Column("persona_id", sa.INTEGER(), autoincrement=False, nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
"chat_message",
|
||||
sa.Column(
|
||||
"parent_edit_number",
|
||||
sa.INTEGER(),
|
||||
autoincrement=False,
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"fk_chat_message_persona_id",
|
||||
"chat_message",
|
||||
"persona",
|
||||
["persona_id"],
|
||||
["id"],
|
||||
)
|
||||
op.drop_column("chat_message", "error")
|
||||
op.drop_column("chat_message", "citations")
|
||||
op.drop_column("chat_message", "prompt_id")
|
||||
op.drop_column("chat_message", "rephrased_query")
|
||||
op.drop_column("chat_message", "latest_child_message")
|
||||
op.drop_column("chat_message", "parent_message")
|
||||
op.drop_column("chat_message", "id")
|
||||
op.add_column(
|
||||
"chat_feedback",
|
||||
sa.Column(
|
||||
"chat_message_message_number",
|
||||
sa.INTEGER(),
|
||||
autoincrement=False,
|
||||
nullable=False,
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"chat_feedback",
|
||||
sa.Column(
|
||||
"chat_message_chat_session_id",
|
||||
sa.INTEGER(),
|
||||
autoincrement=False,
|
||||
nullable=False,
|
||||
primary_key=True,
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"chat_feedback",
|
||||
sa.Column(
|
||||
"chat_message_edit_number",
|
||||
sa.INTEGER(),
|
||||
autoincrement=False,
|
||||
nullable=False,
|
||||
),
|
||||
)
|
||||
op.drop_column("chat_feedback", "chat_message_id")
|
||||
op.create_table(
|
||||
"query_event",
|
||||
sa.Column("id", sa.INTEGER(), autoincrement=True, nullable=False),
|
||||
sa.Column("query", sa.VARCHAR(), autoincrement=False, nullable=False),
|
||||
sa.Column(
|
||||
"selected_search_flow",
|
||||
sa.VARCHAR(),
|
||||
autoincrement=False,
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("llm_answer", sa.VARCHAR(), autoincrement=False, nullable=True),
|
||||
sa.Column("feedback", sa.VARCHAR(), autoincrement=False, nullable=True),
|
||||
sa.Column("user_id", sa.UUID(), autoincrement=False, nullable=True),
|
||||
sa.Column(
|
||||
"time_created",
|
||||
postgresql.TIMESTAMP(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
autoincrement=False,
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"retrieved_document_ids",
|
||||
postgresql.ARRAY(sa.VARCHAR()),
|
||||
autoincrement=False,
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("chat_session_id", sa.INTEGER(), autoincrement=False, nullable=True),
|
||||
sa.ForeignKeyConstraint(
|
||||
["chat_session_id"],
|
||||
["chat_session.id"],
|
||||
name="fk_query_event_chat_session_id",
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"], ["user.id"], name="query_event_user_id_fkey"
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id", name="query_event_pkey"),
|
||||
)
|
||||
op.drop_table("chat_message__search_doc")
|
||||
op.drop_table("persona__prompt")
|
||||
op.drop_table("prompt")
|
||||
op.drop_table("search_doc")
|
||||
op.create_unique_constraint(
|
||||
"uq_chat_message_combination",
|
||||
"chat_message",
|
||||
["chat_session_id", "message_number", "edit_number"],
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"chat_feedback_chat_message_chat_session_id_chat_message_me_fkey",
|
||||
"chat_feedback",
|
||||
"chat_message",
|
||||
[
|
||||
"chat_message_chat_session_id",
|
||||
"chat_message_message_number",
|
||||
"chat_message_edit_number",
|
||||
],
|
||||
["chat_session_id", "message_number", "edit_number"],
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"document_retrieval_feedback_qa_event_id_fkey",
|
||||
"document_retrieval_feedback",
|
||||
"query_event",
|
||||
["qa_event_id"],
|
||||
["id"],
|
||||
)
|
||||
|
||||
op.execute("DROP TYPE IF EXISTS searchtype")
|
||||
op.execute("DROP TYPE IF EXISTS recencybiassetting")
|
||||
op.execute("DROP TYPE IF EXISTS documentsource")
|
||||
@@ -0,0 +1,26 @@
|
||||
"""Add llm_model_version_override to Persona
|
||||
|
||||
Revision ID: baf71f781b9e
|
||||
Revises: 50b683a8295c
|
||||
Create Date: 2023-12-06 21:56:50.286158
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "baf71f781b9e"
|
||||
down_revision = "50b683a8295c"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column("llm_model_version_override", sa.String(), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("persona", "llm_model_version_override")
|
||||
@@ -18,7 +18,6 @@ depends_on = None
|
||||
|
||||
def upgrade() -> None:
|
||||
op.drop_table("deletion_attempt")
|
||||
sa.Enum(name="deletionstatus").drop(op.get_bind(), checkfirst=False)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
"""Add Total Docs for Index Attempt
|
||||
|
||||
Revision ID: d61e513bef0a
|
||||
Revises: 46625e4745d4
|
||||
Create Date: 2023-10-27 23:02:43.369964
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "d61e513bef0a"
|
||||
down_revision = "46625e4745d4"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"index_attempt",
|
||||
sa.Column("new_docs_indexed", sa.Integer(), nullable=True),
|
||||
)
|
||||
op.alter_column(
|
||||
"index_attempt", "num_docs_indexed", new_column_name="total_docs_indexed"
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.alter_column(
|
||||
"index_attempt", "total_docs_indexed", new_column_name="num_docs_indexed"
|
||||
)
|
||||
op.drop_column("index_attempt", "new_docs_indexed")
|
||||
112
backend/alembic/versions/dbaa756c2ccf_embedding_models.py
Normal file
112
backend/alembic/versions/dbaa756c2ccf_embedding_models.py
Normal file
@@ -0,0 +1,112 @@
|
||||
"""Embedding Models
|
||||
|
||||
Revision ID: dbaa756c2ccf
|
||||
Revises: 7f726bad5367
|
||||
Create Date: 2024-01-25 17:12:31.813160
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import table, column, String, Integer, Boolean
|
||||
|
||||
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
|
||||
from danswer.configs.model_configs import DOC_EMBEDDING_DIM
|
||||
from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
|
||||
from danswer.configs.model_configs import ASYM_QUERY_PREFIX
|
||||
from danswer.configs.model_configs import ASYM_PASSAGE_PREFIX
|
||||
from danswer.db.models import IndexModelStatus
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "dbaa756c2ccf"
|
||||
down_revision = "7f726bad5367"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"embedding_model",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("model_name", sa.String(), nullable=False),
|
||||
sa.Column("model_dim", sa.Integer(), nullable=False),
|
||||
sa.Column("normalize", sa.Boolean(), nullable=False),
|
||||
sa.Column("query_prefix", sa.String(), nullable=False),
|
||||
sa.Column("passage_prefix", sa.String(), nullable=False),
|
||||
sa.Column("index_name", sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
"status",
|
||||
sa.Enum(IndexModelStatus, native=False),
|
||||
nullable=False,
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
EmbeddingModel = table(
|
||||
"embedding_model",
|
||||
column("id", Integer),
|
||||
column("model_name", String),
|
||||
column("model_dim", Integer),
|
||||
column("normalize", Boolean),
|
||||
column("query_prefix", String),
|
||||
column("passage_prefix", String),
|
||||
column("index_name", String),
|
||||
column(
|
||||
"status", sa.Enum(IndexModelStatus, name="indexmodelstatus", native=False)
|
||||
),
|
||||
)
|
||||
op.bulk_insert(
|
||||
EmbeddingModel,
|
||||
[
|
||||
{
|
||||
"model_name": DOCUMENT_ENCODER_MODEL,
|
||||
"model_dim": DOC_EMBEDDING_DIM,
|
||||
"normalize": NORMALIZE_EMBEDDINGS,
|
||||
"query_prefix": ASYM_QUERY_PREFIX,
|
||||
"passage_prefix": ASYM_PASSAGE_PREFIX,
|
||||
"index_name": "danswer_chunk",
|
||||
"status": IndexModelStatus.PRESENT,
|
||||
}
|
||||
],
|
||||
)
|
||||
op.add_column(
|
||||
"index_attempt",
|
||||
sa.Column("embedding_model_id", sa.Integer(), nullable=True),
|
||||
)
|
||||
op.execute(
|
||||
"UPDATE index_attempt SET embedding_model_id=1 WHERE embedding_model_id IS NULL"
|
||||
)
|
||||
op.alter_column(
|
||||
"index_attempt",
|
||||
"embedding_model_id",
|
||||
existing_type=sa.Integer(),
|
||||
nullable=False,
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"index_attempt__embedding_model_fk",
|
||||
"index_attempt",
|
||||
"embedding_model",
|
||||
["embedding_model_id"],
|
||||
["id"],
|
||||
)
|
||||
op.create_index(
|
||||
"ix_embedding_model_present_unique",
|
||||
"embedding_model",
|
||||
["status"],
|
||||
unique=True,
|
||||
postgresql_where=sa.text("status = 'PRESENT'"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_embedding_model_future_unique",
|
||||
"embedding_model",
|
||||
["status"],
|
||||
unique=True,
|
||||
postgresql_where=sa.text("status = 'FUTURE'"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_constraint(
|
||||
"index_attempt__embedding_model_fk", "index_attempt", type_="foreignkey"
|
||||
)
|
||||
op.drop_column("index_attempt", "embedding_model_id")
|
||||
op.drop_table("embedding_model")
|
||||
op.execute("DROP TYPE indexmodelstatus;")
|
||||
@@ -17,7 +17,6 @@ depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"document",
|
||||
sa.Column("id", sa.String(), nullable=False),
|
||||
@@ -32,6 +31,7 @@ def upgrade() -> None:
|
||||
"VECTOR",
|
||||
"KEYWORD",
|
||||
name="documentstoretype",
|
||||
native_enum=False,
|
||||
),
|
||||
nullable=False,
|
||||
),
|
||||
@@ -55,6 +55,7 @@ def upgrade() -> None:
|
||||
"SUCCESS",
|
||||
"FAILED",
|
||||
name="deletionstatus",
|
||||
native_enum=False,
|
||||
),
|
||||
nullable=False,
|
||||
),
|
||||
@@ -101,15 +102,10 @@ def upgrade() -> None:
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id", "connector_id", "credential_id"),
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table("document_by_connector_credential_pair")
|
||||
op.drop_table("deletion_attempt")
|
||||
op.drop_table("chunk")
|
||||
op.drop_table("document")
|
||||
sa.Enum(name="deletionstatus").drop(op.get_bind(), checkfirst=False)
|
||||
sa.Enum(name="documentstoretype").drop(op.get_bind(), checkfirst=False)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
44
backend/alembic/versions/e0a68a81d434_add_chat_feedback.py
Normal file
44
backend/alembic/versions/e0a68a81d434_add_chat_feedback.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""Add Chat Feedback
|
||||
|
||||
Revision ID: e0a68a81d434
|
||||
Revises: ae62505e3acc
|
||||
Create Date: 2023-10-04 20:22:33.380286
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "e0a68a81d434"
|
||||
down_revision = "ae62505e3acc"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"chat_feedback",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("chat_message_chat_session_id", sa.Integer(), nullable=False),
|
||||
sa.Column("chat_message_message_number", sa.Integer(), nullable=False),
|
||||
sa.Column("chat_message_edit_number", sa.Integer(), nullable=False),
|
||||
sa.Column("is_positive", sa.Boolean(), nullable=True),
|
||||
sa.Column("feedback_text", sa.Text(), nullable=True),
|
||||
sa.ForeignKeyConstraint(
|
||||
[
|
||||
"chat_message_chat_session_id",
|
||||
"chat_message_message_number",
|
||||
"chat_message_edit_number",
|
||||
],
|
||||
[
|
||||
"chat_message.chat_session_id",
|
||||
"chat_message.message_number",
|
||||
"chat_message.edit_number",
|
||||
],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("chat_feedback")
|
||||
@@ -0,0 +1,27 @@
|
||||
"""Add persona to chat_session
|
||||
|
||||
Revision ID: e86866a9c78a
|
||||
Revises: 80696cf850ae
|
||||
Create Date: 2023-11-26 02:51:47.657357
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "e86866a9c78a"
|
||||
down_revision = "80696cf850ae"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column("chat_session", sa.Column("persona_id", sa.Integer(), nullable=True))
|
||||
op.create_foreign_key(
|
||||
"fk_chat_session_persona_id", "chat_session", "persona", ["persona_id"], ["id"]
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_constraint("fk_chat_session_persona_id", "chat_session", type_="foreignkey")
|
||||
op.drop_column("chat_session", "persona_id")
|
||||
@@ -0,0 +1,27 @@
|
||||
"""Index From Beginning
|
||||
|
||||
Revision ID: ec3ec2eabf7b
|
||||
Revises: dbaa756c2ccf
|
||||
Create Date: 2024-02-06 22:03:28.098158
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "ec3ec2eabf7b"
|
||||
down_revision = "dbaa756c2ccf"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"index_attempt", sa.Column("from_beginning", sa.Boolean(), nullable=True)
|
||||
)
|
||||
op.execute("UPDATE index_attempt SET from_beginning = False")
|
||||
op.alter_column("index_attempt", "from_beginning", nullable=False)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("index_attempt", "from_beginning")
|
||||
@@ -0,0 +1,37 @@
|
||||
"""Basic Document Metadata
|
||||
|
||||
Revision ID: ffc707a226b4
|
||||
Revises: 30c1d5744104
|
||||
Create Date: 2023-10-18 16:52:25.967592
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "ffc707a226b4"
|
||||
down_revision = "30c1d5744104"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"document",
|
||||
sa.Column("doc_updated_at", sa.DateTime(timezone=True), nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
"document",
|
||||
sa.Column("primary_owners", postgresql.ARRAY(sa.String()), nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
"document",
|
||||
sa.Column("secondary_owners", postgresql.ARRAY(sa.String()), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("document", "secondary_owners")
|
||||
op.drop_column("document", "primary_owners")
|
||||
op.drop_column("document", "doc_updated_at")
|
||||
@@ -0,0 +1,3 @@
|
||||
import os
|
||||
|
||||
__version__ = os.environ.get("DANSWER_VERSION", "") or "0.3-dev"
|
||||
|
||||
@@ -1,15 +1,17 @@
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.access.models import DocumentAccess
|
||||
from danswer.configs.constants import PUBLIC_DOC_PAT
|
||||
from danswer.db.document import get_acccess_info_for_documents
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.server.models import ConnectorCredentialPairIdentifier
|
||||
from danswer.db.models import User
|
||||
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
|
||||
|
||||
def _get_access_for_documents(
|
||||
document_ids: list[str],
|
||||
cc_pair_to_delete: ConnectorCredentialPairIdentifier | None,
|
||||
db_session: Session,
|
||||
cc_pair_to_delete: ConnectorCredentialPairIdentifier | None = None,
|
||||
) -> dict[str, DocumentAccess]:
|
||||
document_access_info = get_acccess_info_for_documents(
|
||||
db_session=db_session,
|
||||
@@ -24,13 +26,37 @@ def _get_access_for_documents(
|
||||
|
||||
def get_access_for_documents(
|
||||
document_ids: list[str],
|
||||
db_session: Session,
|
||||
cc_pair_to_delete: ConnectorCredentialPairIdentifier | None = None,
|
||||
db_session: Session | None = None,
|
||||
) -> dict[str, DocumentAccess]:
|
||||
if db_session is None:
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
return _get_access_for_documents(
|
||||
document_ids, cc_pair_to_delete, db_session
|
||||
)
|
||||
"""Fetches all access information for the given documents."""
|
||||
versioned_get_access_for_documents_fn = fetch_versioned_implementation(
|
||||
"danswer.access.access", "_get_access_for_documents"
|
||||
)
|
||||
return versioned_get_access_for_documents_fn(
|
||||
document_ids, db_session, cc_pair_to_delete
|
||||
) # type: ignore
|
||||
|
||||
return _get_access_for_documents(document_ids, cc_pair_to_delete, db_session)
|
||||
|
||||
def prefix_user(user_id: str) -> str:
|
||||
"""Prefixes a user ID to eliminate collision with group names.
|
||||
This assumes that groups are prefixed with a different prefix."""
|
||||
return f"user_id:{user_id}"
|
||||
|
||||
|
||||
def _get_acl_for_user(user: User | None, db_session: Session) -> set[str]:
|
||||
"""Returns a list of ACL entries that the user has access to. This is meant to be
|
||||
used downstream to filter out documents that the user does not have access to. The
|
||||
user should have access to a document if at least one entry in the document's ACL
|
||||
matches one entry in the returned set.
|
||||
"""
|
||||
if user:
|
||||
return {prefix_user(str(user.id)), PUBLIC_DOC_PAT}
|
||||
return {PUBLIC_DOC_PAT}
|
||||
|
||||
|
||||
def get_acl_for_user(user: User | None, db_session: Session | None = None) -> set[str]:
|
||||
versioned_acl_for_user_fn = fetch_versioned_implementation(
|
||||
"danswer.access.access", "_get_acl_for_user"
|
||||
)
|
||||
return versioned_acl_for_user_fn(user, db_session) # type: ignore
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import contextlib
|
||||
import os
|
||||
import smtplib
|
||||
import uuid
|
||||
@@ -6,10 +5,13 @@ from collections.abc import AsyncGenerator
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from email.mime.text import MIMEText
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
from fastapi import Response
|
||||
from fastapi import status
|
||||
from fastapi_users import BaseUserManager
|
||||
from fastapi_users import FastAPIUsers
|
||||
@@ -18,21 +20,18 @@ from fastapi_users import schemas
|
||||
from fastapi_users import UUIDIDMixin
|
||||
from fastapi_users.authentication import AuthenticationBackend
|
||||
from fastapi_users.authentication import CookieTransport
|
||||
from fastapi_users.authentication import Strategy
|
||||
from fastapi_users.authentication.strategy.db import AccessTokenDatabase
|
||||
from fastapi_users.authentication.strategy.db import DatabaseStrategy
|
||||
from fastapi_users.db import SQLAlchemyUserDatabase
|
||||
from httpx_oauth.clients.google import GoogleOAuth2
|
||||
from httpx_oauth.clients.openid import OpenID
|
||||
from pydantic import EmailStr
|
||||
from fastapi_users.openapi import OpenAPIResponseType
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.schemas import UserCreate
|
||||
from danswer.auth.schemas import UserRole
|
||||
from danswer.configs.app_configs import AUTH_TYPE
|
||||
from danswer.configs.app_configs import DISABLE_AUTH
|
||||
from danswer.configs.app_configs import ENABLE_OAUTH
|
||||
from danswer.configs.app_configs import OAUTH_CLIENT_ID
|
||||
from danswer.configs.app_configs import OAUTH_CLIENT_SECRET
|
||||
from danswer.configs.app_configs import OAUTH_TYPE
|
||||
from danswer.configs.app_configs import OPENID_CONFIG_URL
|
||||
from danswer.configs.app_configs import EMAIL_FROM
|
||||
from danswer.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
|
||||
from danswer.configs.app_configs import SECRET
|
||||
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
|
||||
@@ -42,23 +41,40 @@ from danswer.configs.app_configs import SMTP_SERVER
|
||||
from danswer.configs.app_configs import SMTP_USER
|
||||
from danswer.configs.app_configs import VALID_EMAIL_DOMAINS
|
||||
from danswer.configs.app_configs import WEB_DOMAIN
|
||||
from danswer.configs.constants import AuthType
|
||||
from danswer.db.auth import get_access_token_db
|
||||
from danswer.db.auth import get_user_count
|
||||
from danswer.db.auth import get_user_db
|
||||
from danswer.db.engine import get_async_session
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.models import AccessToken
|
||||
from danswer.db.models import User
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.telemetry import optional_telemetry
|
||||
from danswer.utils.telemetry import RecordType
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
FAKE_USER_EMAIL = "fakeuser@fakedanswermail.com"
|
||||
FAKE_USER_PASS = "foobar"
|
||||
|
||||
USER_WHITELIST_FILE = "/home/danswer_whitelist.txt"
|
||||
_user_whitelist: list[str] | None = None
|
||||
|
||||
|
||||
def verify_auth_setting() -> None:
|
||||
if AUTH_TYPE not in [AuthType.DISABLED, AuthType.BASIC, AuthType.GOOGLE_OAUTH]:
|
||||
raise ValueError(
|
||||
"User must choose a valid user authentication method: "
|
||||
"disabled, basic, or google_oauth"
|
||||
)
|
||||
logger.info(f"Using Auth Type: {AUTH_TYPE.value}")
|
||||
|
||||
|
||||
def user_needs_to_be_verified() -> bool:
|
||||
# all other auth types besides basic should require users to be
|
||||
# verified
|
||||
return AUTH_TYPE != AuthType.BASIC or REQUIRE_EMAIL_VERIFICATION
|
||||
|
||||
|
||||
def get_user_whitelist() -> list[str]:
|
||||
global _user_whitelist
|
||||
if _user_whitelist is None:
|
||||
@@ -92,13 +108,18 @@ def verify_email_domain(email: str) -> None:
|
||||
)
|
||||
|
||||
|
||||
def send_user_verification_email(user_email: str, token: str) -> None:
|
||||
def send_user_verification_email(
|
||||
user_email: str,
|
||||
token: str,
|
||||
mail_from: str = EMAIL_FROM,
|
||||
) -> None:
|
||||
msg = MIMEMultipart()
|
||||
msg["Subject"] = "Danswer Email Verification"
|
||||
msg["From"] = "no-reply@danswer.dev"
|
||||
msg["To"] = user_email
|
||||
if mail_from:
|
||||
msg["From"] = mail_from
|
||||
|
||||
link = f"{WEB_DOMAIN}/verify-email?token={token}"
|
||||
link = f"{WEB_DOMAIN}/auth/verify-email?token={token}"
|
||||
|
||||
body = MIMEText(f"Click the following link to verify your email address: {link}")
|
||||
msg.attach(body)
|
||||
@@ -163,6 +184,11 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
self, user: User, request: Optional[Request] = None
|
||||
) -> None:
|
||||
logger.info(f"User {user.id} has registered.")
|
||||
optional_telemetry(
|
||||
record_type=RecordType.SIGN_UP,
|
||||
data={"action": "create"},
|
||||
user_id=str(user.id),
|
||||
)
|
||||
|
||||
async def on_after_forgot_password(
|
||||
self, user: User, token: str, request: Optional[Request] = None
|
||||
@@ -204,53 +230,92 @@ auth_backend = AuthenticationBackend(
|
||||
get_strategy=get_database_strategy,
|
||||
)
|
||||
|
||||
oauth_client = None # type: GoogleOAuth2 | OpenID | None
|
||||
if ENABLE_OAUTH:
|
||||
if OAUTH_TYPE == "google":
|
||||
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
|
||||
elif OAUTH_TYPE == "openid":
|
||||
oauth_client = OpenID(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET, OPENID_CONFIG_URL)
|
||||
else:
|
||||
raise AssertionError(f"Invalid OAUTH type {OAUTH_TYPE}")
|
||||
|
||||
class FastAPIUserWithLogoutRouter(FastAPIUsers[models.UP, models.ID]):
|
||||
def get_logout_router(
|
||||
self,
|
||||
backend: AuthenticationBackend,
|
||||
requires_verification: bool = REQUIRE_EMAIL_VERIFICATION,
|
||||
) -> APIRouter:
|
||||
"""
|
||||
Provide a router for logout only for OAuth/OIDC Flows.
|
||||
This way the login router does not need to be included
|
||||
"""
|
||||
router = APIRouter()
|
||||
get_current_user_token = self.authenticator.current_user_token(
|
||||
active=True, verified=requires_verification
|
||||
)
|
||||
logout_responses: OpenAPIResponseType = {
|
||||
**{
|
||||
status.HTTP_401_UNAUTHORIZED: {
|
||||
"description": "Missing token or inactive user."
|
||||
}
|
||||
},
|
||||
**backend.transport.get_openapi_logout_responses_success(),
|
||||
}
|
||||
|
||||
@router.post(
|
||||
"/logout", name=f"auth:{backend.name}.logout", responses=logout_responses
|
||||
)
|
||||
async def logout(
|
||||
user_token: Tuple[models.UP, str] = Depends(get_current_user_token),
|
||||
strategy: Strategy[models.UP, models.ID] = Depends(backend.get_strategy),
|
||||
) -> Response:
|
||||
user, token = user_token
|
||||
return await backend.logout(strategy, user, token)
|
||||
|
||||
return router
|
||||
|
||||
|
||||
fastapi_users = FastAPIUsers[User, uuid.UUID](get_user_manager, [auth_backend])
|
||||
|
||||
|
||||
# Currently unused, maybe useful later
|
||||
async def create_get_fake_user() -> User:
|
||||
get_async_session_context = contextlib.asynccontextmanager(
|
||||
get_async_session
|
||||
) # type:ignore
|
||||
get_user_db_context = contextlib.asynccontextmanager(get_user_db)
|
||||
get_user_manager_context = contextlib.asynccontextmanager(get_user_manager)
|
||||
|
||||
logger.info("Creating fake user due to Auth being turned off")
|
||||
async with get_async_session_context() as session:
|
||||
async with get_user_db_context(session) as user_db:
|
||||
async with get_user_manager_context(user_db) as user_manager:
|
||||
user = await user_manager.get_by_email(FAKE_USER_EMAIL)
|
||||
if user:
|
||||
return user
|
||||
user = await user_manager.create(
|
||||
UserCreate(email=EmailStr(FAKE_USER_EMAIL), password=FAKE_USER_PASS)
|
||||
)
|
||||
logger.info("Created fake user.")
|
||||
return user
|
||||
|
||||
|
||||
current_active_user = fastapi_users.current_user(
|
||||
active=True, verified=REQUIRE_EMAIL_VERIFICATION, optional=DISABLE_AUTH
|
||||
fastapi_users = FastAPIUserWithLogoutRouter[User, uuid.UUID](
|
||||
get_user_manager, [auth_backend]
|
||||
)
|
||||
|
||||
|
||||
async def current_user(user: User = Depends(current_active_user)) -> User | None:
|
||||
if DISABLE_AUTH:
|
||||
# NOTE: verified=REQUIRE_EMAIL_VERIFICATION is not used here since we
|
||||
# take care of that in `double_check_user` ourself. This is needed, since
|
||||
# we want the /me endpoint to still return a user even if they are not
|
||||
# yet verified, so that the frontend knows they exist
|
||||
optional_valid_user = fastapi_users.current_user(active=True, optional=True)
|
||||
|
||||
|
||||
async def double_check_user(
|
||||
request: Request,
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
optional: bool = DISABLE_AUTH,
|
||||
) -> User | None:
|
||||
if optional:
|
||||
return None
|
||||
|
||||
if user is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied. User is not authenticated.",
|
||||
)
|
||||
|
||||
if user_needs_to_be_verified() and not user.is_verified:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied. User is not verified.",
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
async def current_admin_user(user: User = Depends(current_user)) -> User | None:
|
||||
async def current_user(
|
||||
request: Request,
|
||||
user: User | None = Depends(optional_valid_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> User | None:
|
||||
double_check_user = fetch_versioned_implementation(
|
||||
"danswer.auth.users", "double_check_user"
|
||||
)
|
||||
user = await double_check_user(request, user, db_session)
|
||||
return user
|
||||
|
||||
|
||||
async def current_admin_user(user: User | None = Depends(current_user)) -> User | None:
|
||||
if DISABLE_AUTH:
|
||||
return None
|
||||
|
||||
|
||||
@@ -1,29 +1,233 @@
|
||||
from celery import Celery
|
||||
import os
|
||||
from datetime import timedelta
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
|
||||
from danswer.background.connector_deletion import cleanup_connector_credential_pair
|
||||
from celery import Celery # type: ignore
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.connector_deletion import delete_connector_credential_pair
|
||||
from danswer.background.task_utils import build_celery_task_wrapper
|
||||
from danswer.background.task_utils import name_cc_cleanup_task
|
||||
from danswer.background.task_utils import name_document_set_sync_task
|
||||
from danswer.configs.app_configs import FILE_CONNECTOR_TMP_STORAGE_PATH
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.connectors.file.utils import file_age_in_hours
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair
|
||||
from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
|
||||
from danswer.db.document import prepare_to_modify_documents
|
||||
from danswer.db.document_set import delete_document_set
|
||||
from danswer.db.document_set import fetch_document_sets
|
||||
from danswer.db.document_set import fetch_document_sets_for_documents
|
||||
from danswer.db.document_set import fetch_documents_for_document_set
|
||||
from danswer.db.document_set import get_document_set_by_id
|
||||
from danswer.db.document_set import mark_document_set_as_synced
|
||||
from danswer.db.engine import build_connection_string
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.engine import SYNC_DB_API
|
||||
from danswer.document_set.document_set import sync_document_set
|
||||
from danswer.db.models import DocumentSet
|
||||
from danswer.db.tasks import check_live_task_not_timed_out
|
||||
from danswer.db.tasks import get_latest_task
|
||||
from danswer.document_index.document_index_utils import get_both_index_names
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.document_index.interfaces import UpdateRequest
|
||||
from danswer.utils.batching import batch_generator
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
celery_broker_url = "sqla+" + build_connection_string(db_api=SYNC_DB_API)
|
||||
celery_backend_url = "db+" + build_connection_string(db_api=SYNC_DB_API)
|
||||
connection_string = build_connection_string(db_api=SYNC_DB_API)
|
||||
celery_broker_url = f"sqla+{connection_string}"
|
||||
celery_backend_url = f"db+{connection_string}"
|
||||
celery_app = Celery(__name__, broker=celery_broker_url, backend=celery_backend_url)
|
||||
|
||||
|
||||
@celery_app.task(soft_time_limit=60 * 60 * 6) # 6 hour time limit
|
||||
_SYNC_BATCH_SIZE = 1000
|
||||
|
||||
|
||||
#####
|
||||
# Tasks that need to be run in job queue, registered via APIs
|
||||
#
|
||||
# If imports from this module are needed, use local imports to avoid circular importing
|
||||
#####
|
||||
@build_celery_task_wrapper(name_cc_cleanup_task)
|
||||
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
|
||||
def cleanup_connector_credential_pair_task(
|
||||
connector_id: int, credential_id: int
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
) -> int:
|
||||
return cleanup_connector_credential_pair(connector_id, credential_id)
|
||||
"""Connector deletion task. This is run as an async task because it is a somewhat slow job.
|
||||
Needs to potentially update a large number of Postgres and Vespa docs, including deleting them
|
||||
or updating the ACL"""
|
||||
engine = get_sqlalchemy_engine()
|
||||
with Session(engine) as db_session:
|
||||
# validate that the connector / credential pair is deletable
|
||||
cc_pair = get_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
if not cc_pair or not check_deletion_attempt_is_allowed(
|
||||
connector_credential_pair=cc_pair
|
||||
):
|
||||
raise ValueError(
|
||||
"Cannot run deletion attempt - connector_credential_pair is not deletable. "
|
||||
"This is likely because there is an ongoing / planned indexing attempt OR the "
|
||||
"connector is not disabled."
|
||||
)
|
||||
|
||||
try:
|
||||
# The bulk of the work is in here, updates Postgres and Vespa
|
||||
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
|
||||
document_index = get_default_document_index(
|
||||
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
|
||||
)
|
||||
return delete_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
document_index=document_index,
|
||||
cc_pair=cc_pair,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to run connector_deletion due to {e}")
|
||||
raise e
|
||||
|
||||
|
||||
@celery_app.task(soft_time_limit=60 * 60 * 6) # 6 hour time limit
|
||||
@build_celery_task_wrapper(name_document_set_sync_task)
|
||||
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
|
||||
def sync_document_set_task(document_set_id: int) -> None:
|
||||
try:
|
||||
return sync_document_set(document_set_id=document_set_id)
|
||||
except Exception:
|
||||
logger.exception("Failed to sync document set %s", document_set_id)
|
||||
raise
|
||||
"""For document sets marked as not up to date, sync the state from postgres
|
||||
into the datastore. Also handles deletions."""
|
||||
|
||||
def _sync_document_batch(document_ids: list[str]) -> None:
|
||||
logger.debug(f"Syncing document sets for: {document_ids}")
|
||||
# begin a transaction, release lock at the end
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
# acquires a lock on the documents so that no other process can modify them
|
||||
prepare_to_modify_documents(
|
||||
db_session=db_session, document_ids=document_ids
|
||||
)
|
||||
|
||||
# get current state of document sets for these documents
|
||||
document_set_map = {
|
||||
document_id: document_sets
|
||||
for document_id, document_sets in fetch_document_sets_for_documents(
|
||||
document_ids=document_ids, db_session=db_session
|
||||
)
|
||||
}
|
||||
|
||||
# update Vespa
|
||||
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
|
||||
document_index = get_default_document_index(
|
||||
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
|
||||
)
|
||||
update_requests = [
|
||||
UpdateRequest(
|
||||
document_ids=[document_id],
|
||||
document_sets=set(document_set_map.get(document_id, [])),
|
||||
)
|
||||
for document_id in document_ids
|
||||
]
|
||||
document_index.update(update_requests=update_requests)
|
||||
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
try:
|
||||
documents_to_update = fetch_documents_for_document_set(
|
||||
document_set_id=document_set_id,
|
||||
db_session=db_session,
|
||||
current_only=False,
|
||||
)
|
||||
for document_batch in batch_generator(
|
||||
documents_to_update, _SYNC_BATCH_SIZE
|
||||
):
|
||||
_sync_document_batch(
|
||||
document_ids=[document.id for document in document_batch]
|
||||
)
|
||||
|
||||
# if there are no connectors, then delete the document set. Otherwise, just
|
||||
# mark it as successfully synced.
|
||||
document_set = cast(
|
||||
DocumentSet,
|
||||
get_document_set_by_id(
|
||||
db_session=db_session, document_set_id=document_set_id
|
||||
),
|
||||
) # casting since we "know" a document set with this ID exists
|
||||
if not document_set.connector_credential_pairs:
|
||||
delete_document_set(
|
||||
document_set_row=document_set, db_session=db_session
|
||||
)
|
||||
logger.info(
|
||||
f"Successfully deleted document set with ID: '{document_set_id}'!"
|
||||
)
|
||||
else:
|
||||
mark_document_set_as_synced(
|
||||
document_set_id=document_set_id, db_session=db_session
|
||||
)
|
||||
logger.info(f"Document set sync for '{document_set_id}' complete!")
|
||||
|
||||
except Exception:
|
||||
logger.exception("Failed to sync document set %s", document_set_id)
|
||||
raise
|
||||
|
||||
|
||||
#####
|
||||
# Periodic Tasks
|
||||
#####
|
||||
@celery_app.task(
|
||||
name="check_for_document_sets_sync_task",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
)
|
||||
def check_for_document_sets_sync_task() -> None:
|
||||
"""Runs periodically to check if any document sets are out of sync
|
||||
Creates a task to sync the set if needed"""
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
# check if any document sets are not synced
|
||||
document_set_info = fetch_document_sets(
|
||||
db_session=db_session, include_outdated=True
|
||||
)
|
||||
for document_set, _ in document_set_info:
|
||||
if not document_set.is_up_to_date:
|
||||
task_name = name_document_set_sync_task(document_set.id)
|
||||
latest_sync = get_latest_task(task_name, db_session)
|
||||
|
||||
if latest_sync and check_live_task_not_timed_out(
|
||||
latest_sync, db_session
|
||||
):
|
||||
logger.info(
|
||||
f"Document set '{document_set.id}' is already syncing. Skipping."
|
||||
)
|
||||
continue
|
||||
|
||||
logger.info(f"Document set {document_set.id} syncing now!")
|
||||
sync_document_set_task.apply_async(
|
||||
kwargs=dict(document_set_id=document_set.id),
|
||||
)
|
||||
|
||||
|
||||
@celery_app.task(name="clean_old_temp_files_task", soft_time_limit=JOB_TIMEOUT)
|
||||
def clean_old_temp_files_task(
|
||||
age_threshold_in_hours: float | int = 24 * 7, # 1 week,
|
||||
base_path: Path | str = FILE_CONNECTOR_TMP_STORAGE_PATH,
|
||||
) -> None:
|
||||
"""Files added via the File connector need to be deleted after ingestion
|
||||
Currently handled async of the indexing job"""
|
||||
os.makedirs(base_path, exist_ok=True)
|
||||
for file in os.listdir(base_path):
|
||||
full_file_path = Path(base_path) / file
|
||||
if file_age_in_hours(full_file_path) > age_threshold_in_hours:
|
||||
logger.info(f"Cleaning up uploaded file: {full_file_path}")
|
||||
os.remove(full_file_path)
|
||||
|
||||
|
||||
#####
|
||||
# Celery Beat (Periodic Tasks) Settings
|
||||
#####
|
||||
celery_app.conf.beat_schedule = {
|
||||
"check-for-document-set-sync": {
|
||||
"task": "check_for_document_sets_sync_task",
|
||||
"schedule": timedelta(seconds=5),
|
||||
},
|
||||
"clean-old-temp-files": {
|
||||
"task": "clean_old_temp_files_task",
|
||||
"schedule": timedelta(minutes=30),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1,37 +1,23 @@
|
||||
import json
|
||||
|
||||
from celery.result import AsyncResult
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.celery import celery_app
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.background.task_utils import name_cc_cleanup_task
|
||||
from danswer.db.tasks import get_latest_task
|
||||
from danswer.server.documents.models import DeletionAttemptSnapshot
|
||||
|
||||
|
||||
def get_celery_task(task_id: str) -> AsyncResult:
|
||||
"""NOTE: even if the task doesn't exist, celery will still return something
|
||||
with a `PENDING` state"""
|
||||
return AsyncResult(task_id, backend=celery_app.backend)
|
||||
def get_deletion_status(
|
||||
connector_id: int, credential_id: int, db_session: Session
|
||||
) -> DeletionAttemptSnapshot | None:
|
||||
cleanup_task_name = name_cc_cleanup_task(
|
||||
connector_id=connector_id, credential_id=credential_id
|
||||
)
|
||||
task_state = get_latest_task(task_name=cleanup_task_name, db_session=db_session)
|
||||
|
||||
if not task_state:
|
||||
return None
|
||||
|
||||
def get_celery_task_status(task_id: str) -> str | None:
|
||||
"""NOTE: is tightly coupled to the internals of kombu (which is the
|
||||
translation layer to allow us to use Postgres as a broker). If we change
|
||||
the broker, this will need to be updated.
|
||||
|
||||
This should not be called on any critical flows.
|
||||
"""
|
||||
# first check for any pending tasks
|
||||
with Session(get_sqlalchemy_engine()) as session:
|
||||
rows = session.execute(text("SELECT payload FROM kombu_message WHERE visible"))
|
||||
for row in rows:
|
||||
payload = json.loads(row[0])
|
||||
if payload["headers"]["id"] == task_id:
|
||||
return "PENDING"
|
||||
|
||||
task = get_celery_task(task_id)
|
||||
# if not pending, then we know the task really exists
|
||||
if task.status != "PENDING":
|
||||
return task.status
|
||||
|
||||
return None
|
||||
return DeletionAttemptSnapshot(
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
status=task_state.status,
|
||||
)
|
||||
|
||||
@@ -1,41 +0,0 @@
|
||||
from typing import cast
|
||||
|
||||
from danswer.background.celery.celery_utils import get_celery_task
|
||||
from danswer.background.celery.celery_utils import get_celery_task_status
|
||||
from danswer.background.connector_deletion import get_cleanup_task_id
|
||||
from danswer.db.models import DeletionStatus
|
||||
from danswer.server.models import DeletionAttemptSnapshot
|
||||
|
||||
|
||||
def get_deletion_status(
|
||||
connector_id: int, credential_id: int
|
||||
) -> DeletionAttemptSnapshot | None:
|
||||
cleanup_task_id = get_cleanup_task_id(
|
||||
connector_id=connector_id, credential_id=credential_id
|
||||
)
|
||||
deletion_task = get_celery_task(task_id=cleanup_task_id)
|
||||
deletion_task_status = get_celery_task_status(task_id=cleanup_task_id)
|
||||
|
||||
deletion_status = None
|
||||
error_msg = None
|
||||
num_docs_deleted = 0
|
||||
if deletion_task_status == "SUCCESS":
|
||||
deletion_status = DeletionStatus.SUCCESS
|
||||
num_docs_deleted = cast(int, deletion_task.get(propagate=False))
|
||||
elif deletion_task_status == "FAILURE":
|
||||
deletion_status = DeletionStatus.FAILED
|
||||
error_msg = deletion_task.get(propagate=False)
|
||||
elif deletion_task_status == "STARTED" or deletion_task_status == "PENDING":
|
||||
deletion_status = DeletionStatus.IN_PROGRESS
|
||||
|
||||
return (
|
||||
DeletionAttemptSnapshot(
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
status=deletion_status,
|
||||
error_msg=str(error_msg),
|
||||
num_docs_deleted=num_docs_deleted,
|
||||
)
|
||||
if deletion_status
|
||||
else None
|
||||
)
|
||||
@@ -10,24 +10,30 @@ are multiple connector / credential pairs that have indexed it
|
||||
connector / credential pair from the access list
|
||||
(6) delete all relevant entries from postgres
|
||||
"""
|
||||
import time
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.access.access import get_access_for_documents
|
||||
from danswer.datastores.document_index import get_default_document_index
|
||||
from danswer.datastores.interfaces import DocumentIndex
|
||||
from danswer.datastores.interfaces import UpdateRequest
|
||||
from danswer.db.connector import fetch_connector_by_id
|
||||
from danswer.db.connector_credential_pair import delete_connector_credential_pair
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair
|
||||
from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
|
||||
from danswer.db.connector_credential_pair import (
|
||||
delete_connector_credential_pair__no_commit,
|
||||
)
|
||||
from danswer.db.document import delete_document_by_connector_credential_pair
|
||||
from danswer.db.document import delete_documents_complete
|
||||
from danswer.db.document import get_document_connector_cnts
|
||||
from danswer.db.document import get_documents_for_connector_credential_pair
|
||||
from danswer.db.document import prepare_to_modify_documents
|
||||
from danswer.db.document_set import get_document_sets_by_ids
|
||||
from danswer.db.document_set import (
|
||||
mark_cc_pair__document_set_relationships_to_be_deleted__no_commit,
|
||||
)
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.index_attempt import delete_index_attempts
|
||||
from danswer.server.models import ConnectorCredentialPairIdentifier
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.document_index.interfaces import DocumentIndex
|
||||
from danswer.document_index.interfaces import UpdateRequest
|
||||
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -55,7 +61,9 @@ def _delete_connector_credential_pair_batch(
|
||||
document_id for document_id, cnt in document_connector_cnts if cnt == 1
|
||||
]
|
||||
logger.debug(f"Deleting documents: {document_ids_to_delete}")
|
||||
|
||||
document_index.delete(doc_ids=document_ids_to_delete)
|
||||
|
||||
delete_documents_complete(
|
||||
db_session=db_session,
|
||||
document_ids=document_ids_to_delete,
|
||||
@@ -81,7 +89,9 @@ def _delete_connector_credential_pair_batch(
|
||||
for document_id, access in access_for_documents.items()
|
||||
]
|
||||
logger.debug(f"Updating documents: {document_ids_to_update}")
|
||||
|
||||
document_index.update(update_requests=update_requests)
|
||||
|
||||
delete_document_by_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
document_ids=document_ids_to_update,
|
||||
@@ -93,12 +103,56 @@ def _delete_connector_credential_pair_batch(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def _delete_connector_credential_pair(
|
||||
def cleanup_synced_entities(
|
||||
cc_pair: ConnectorCredentialPair, db_session: Session
|
||||
) -> None:
|
||||
"""Updates the document sets associated with the connector / credential pair,
|
||||
then relies on the document set sync script to kick off Celery jobs which will
|
||||
sync these updates to Vespa.
|
||||
|
||||
Waits until the document sets are synced before returning."""
|
||||
logger.info(f"Cleaning up Document Sets for CC Pair with ID: '{cc_pair.id}'")
|
||||
document_sets_ids_to_sync = list(
|
||||
mark_cc_pair__document_set_relationships_to_be_deleted__no_commit(
|
||||
cc_pair_id=cc_pair.id,
|
||||
db_session=db_session,
|
||||
)
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
# wait till all document sets are synced before continuing
|
||||
while True:
|
||||
all_synced = True
|
||||
document_sets = get_document_sets_by_ids(
|
||||
db_session=db_session, document_set_ids=document_sets_ids_to_sync
|
||||
)
|
||||
for document_set in document_sets:
|
||||
if not document_set.is_up_to_date:
|
||||
all_synced = False
|
||||
|
||||
if all_synced:
|
||||
break
|
||||
|
||||
# wait for 30 seconds before checking again
|
||||
db_session.commit() # end transaction
|
||||
logger.info(
|
||||
f"Document sets '{document_sets_ids_to_sync}' not synced yet, waiting 30s"
|
||||
)
|
||||
time.sleep(30)
|
||||
|
||||
logger.info(
|
||||
f"Finished cleaning up Document Sets for CC Pair with ID: '{cc_pair.id}'"
|
||||
)
|
||||
|
||||
|
||||
def delete_connector_credential_pair(
|
||||
db_session: Session,
|
||||
document_index: DocumentIndex,
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> int:
|
||||
connector_id = cc_pair.connector_id
|
||||
credential_id = cc_pair.credential_id
|
||||
|
||||
num_docs_deleted = 0
|
||||
while True:
|
||||
documents = get_documents_for_connector_credential_pair(
|
||||
@@ -118,13 +172,18 @@ def _delete_connector_credential_pair(
|
||||
)
|
||||
num_docs_deleted += len(documents)
|
||||
|
||||
# cleanup everything else up
|
||||
# Clean up document sets / access information from Postgres
|
||||
# and sync these updates to Vespa
|
||||
# TODO: add user group cleanup with `fetch_versioned_implementation`
|
||||
cleanup_synced_entities(cc_pair, db_session)
|
||||
|
||||
# clean up the rest of the related Postgres entities
|
||||
delete_index_attempts(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
delete_connector_credential_pair(
|
||||
delete_connector_credential_pair__no_commit(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
@@ -144,37 +203,3 @@ def _delete_connector_credential_pair(
|
||||
f" '{connector_id}' and credential_id: '{credential_id}'. Deleted {num_docs_deleted} docs."
|
||||
)
|
||||
return num_docs_deleted
|
||||
|
||||
|
||||
def cleanup_connector_credential_pair(connector_id: int, credential_id: int) -> int:
|
||||
engine = get_sqlalchemy_engine()
|
||||
with Session(engine) as db_session:
|
||||
# validate that the connector / credential pair is deletable
|
||||
cc_pair = get_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
if not cc_pair or not check_deletion_attempt_is_allowed(
|
||||
connector_credential_pair=cc_pair
|
||||
):
|
||||
raise ValueError(
|
||||
"Cannot run deletion attempt - connector_credential_pair is not deletable. "
|
||||
"This is likely because there is an ongoing / planned indexing attempt OR the "
|
||||
"connector is not disabled."
|
||||
)
|
||||
|
||||
try:
|
||||
return _delete_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
document_index=get_default_document_index(),
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to run connector_deletion due to {e}")
|
||||
raise e
|
||||
|
||||
|
||||
def get_cleanup_task_id(connector_id: int, credential_id: int) -> str:
|
||||
return f"cleanup_connector_credential_pair_{connector_id}_{credential_id}"
|
||||
|
||||
@@ -1,53 +0,0 @@
|
||||
from celery.result import AsyncResult
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.celery import sync_document_set_task
|
||||
from danswer.background.utils import interval_run_job
|
||||
from danswer.db.document_set import (
|
||||
fetch_document_sets,
|
||||
)
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
_ExistingTaskCache: dict[int, AsyncResult] = {}
|
||||
|
||||
|
||||
def _document_sync_loop() -> None:
|
||||
# cleanup tasks
|
||||
existing_tasks = list(_ExistingTaskCache.items())
|
||||
for document_set_id, task in existing_tasks:
|
||||
if task.ready():
|
||||
logger.info(
|
||||
f"Document set '{document_set_id}' is complete with status "
|
||||
f"{task.status}. Cleaning up."
|
||||
)
|
||||
del _ExistingTaskCache[document_set_id]
|
||||
|
||||
# kick off new tasks
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
# check if any document sets are not synced
|
||||
document_set_info = fetch_document_sets(db_session=db_session)
|
||||
for document_set, _ in document_set_info:
|
||||
if not document_set.is_up_to_date:
|
||||
if document_set.id in _ExistingTaskCache:
|
||||
logger.info(
|
||||
f"Document set '{document_set.id}' is already syncing. Skipping."
|
||||
)
|
||||
continue
|
||||
|
||||
logger.info(
|
||||
f"Document set {document_set.id} is not synced. Syncing now!"
|
||||
)
|
||||
task = sync_document_set_task.apply_async(
|
||||
kwargs=dict(document_set_id=document_set.id),
|
||||
)
|
||||
_ExistingTaskCache[document_set.id] = task
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
interval_run_job(
|
||||
job=_document_sync_loop, delay=5, emit_job_start_log=False
|
||||
) # run every 5 seconds
|
||||
@@ -1,6 +0,0 @@
|
||||
from danswer.background.utils import interval_run_job
|
||||
from danswer.connectors.file.utils import clean_old_temp_files
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
interval_run_job(clean_old_temp_files, 30 * 60) # run every 30 minutes
|
||||
80
backend/danswer/background/indexing/checkpointing.py
Normal file
80
backend/danswer/background/indexing/checkpointing.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""Experimental functionality related to splitting up indexing
|
||||
into a series of checkpoints to better handle intermittent failures
|
||||
/ jobs being killed by cloud providers."""
|
||||
import datetime
|
||||
|
||||
from danswer.configs.app_configs import EXPERIMENTAL_CHECKPOINTING_ENABLED
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.cross_connector_utils.miscellaneous_utils import datetime_to_utc
|
||||
|
||||
|
||||
def _2010_dt() -> datetime.datetime:
|
||||
return datetime.datetime(year=2010, month=1, day=1, tzinfo=datetime.timezone.utc)
|
||||
|
||||
|
||||
def _2020_dt() -> datetime.datetime:
|
||||
return datetime.datetime(year=2020, month=1, day=1, tzinfo=datetime.timezone.utc)
|
||||
|
||||
|
||||
def _default_end_time(
|
||||
last_successful_run: datetime.datetime | None,
|
||||
) -> datetime.datetime:
|
||||
"""If year is before 2010, go to the beginning of 2010.
|
||||
If year is 2010-2020, go in 5 year increments.
|
||||
If year > 2020, then go in 180 day increments.
|
||||
|
||||
For connectors that don't support a `filter_by` and instead rely on `sort_by`
|
||||
for polling, then this will cause a massive duplication of fetches. For these
|
||||
connectors, you may want to override this function to return a more reasonable
|
||||
plan (e.g. extending the 2020+ windows to 6 months, 1 year, or higher)."""
|
||||
last_successful_run = (
|
||||
datetime_to_utc(last_successful_run) if last_successful_run else None
|
||||
)
|
||||
if last_successful_run is None or last_successful_run < _2010_dt():
|
||||
return _2010_dt()
|
||||
|
||||
if last_successful_run < _2020_dt():
|
||||
return min(last_successful_run + datetime.timedelta(days=365 * 5), _2020_dt())
|
||||
|
||||
return last_successful_run + datetime.timedelta(days=180)
|
||||
|
||||
|
||||
def find_end_time_for_indexing_attempt(
|
||||
last_successful_run: datetime.datetime | None,
|
||||
# source_type can be used to override the default for certain connectors, currently unused
|
||||
source_type: DocumentSource,
|
||||
) -> datetime.datetime | None:
|
||||
"""Is the current time unless the connector is run over a large period, in which case it is
|
||||
split up into large time segments that become smaller as it approaches the present
|
||||
"""
|
||||
# NOTE: source_type can be used to override the default for certain connectors
|
||||
end_of_window = _default_end_time(last_successful_run)
|
||||
now = datetime.datetime.now(tz=datetime.timezone.utc)
|
||||
if end_of_window < now:
|
||||
return end_of_window
|
||||
|
||||
# None signals that we should index up to current time
|
||||
return None
|
||||
|
||||
|
||||
def get_time_windows_for_index_attempt(
|
||||
last_successful_run: datetime.datetime, source_type: DocumentSource
|
||||
) -> list[tuple[datetime.datetime, datetime.datetime]]:
|
||||
if not EXPERIMENTAL_CHECKPOINTING_ENABLED:
|
||||
return [(last_successful_run, datetime.datetime.now(tz=datetime.timezone.utc))]
|
||||
|
||||
time_windows: list[tuple[datetime.datetime, datetime.datetime]] = []
|
||||
start_of_window: datetime.datetime | None = last_successful_run
|
||||
while start_of_window:
|
||||
end_of_window = find_end_time_for_indexing_attempt(
|
||||
last_successful_run=start_of_window, source_type=source_type
|
||||
)
|
||||
time_windows.append(
|
||||
(
|
||||
start_of_window,
|
||||
end_of_window or datetime.datetime.now(tz=datetime.timezone.utc),
|
||||
)
|
||||
)
|
||||
start_of_window = end_of_window
|
||||
|
||||
return time_windows
|
||||
33
backend/danswer/background/indexing/dask_utils.py
Normal file
33
backend/danswer/background/indexing/dask_utils.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import asyncio
|
||||
|
||||
import psutil
|
||||
from dask.distributed import WorkerPlugin
|
||||
from distributed import Worker
|
||||
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class ResourceLogger(WorkerPlugin):
|
||||
def __init__(self, log_interval: int = 60 * 5):
|
||||
self.log_interval = log_interval
|
||||
|
||||
def setup(self, worker: Worker) -> None:
|
||||
"""This method will be called when the plugin is attached to a worker."""
|
||||
self.worker = worker
|
||||
worker.loop.add_callback(self.log_resources)
|
||||
|
||||
async def log_resources(self) -> None:
|
||||
"""Periodically log CPU and memory usage.
|
||||
|
||||
NOTE: must be async or else will clog up the worker indefinitely due to the fact that
|
||||
Dask uses Tornado under the hood (which is async)"""
|
||||
while True:
|
||||
cpu_percent = psutil.cpu_percent(interval=None)
|
||||
memory_available_gb = psutil.virtual_memory().available / (1024.0**3)
|
||||
# You can now log these values or send them to a monitoring service
|
||||
logger.debug(
|
||||
f"Worker {self.worker.address}: CPU usage {cpu_percent}%, Memory available {memory_available_gb}GB"
|
||||
)
|
||||
await asyncio.sleep(self.log_interval)
|
||||
104
backend/danswer/background/indexing/job_client.py
Normal file
104
backend/danswer/background/indexing/job_client.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""Custom client that works similarly to Dask, but simpler and more lightweight.
|
||||
Dask jobs behaved very strangely - they would die all the time, retries would
|
||||
not follow the expected behavior, etc.
|
||||
|
||||
NOTE: cannot use Celery directly due to
|
||||
https://github.com/celery/celery/issues/7007#issuecomment-1740139367"""
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from typing import Literal
|
||||
|
||||
from torch import multiprocessing
|
||||
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
JobStatusType = (
|
||||
Literal["error"]
|
||||
| Literal["finished"]
|
||||
| Literal["pending"]
|
||||
| Literal["running"]
|
||||
| Literal["cancelled"]
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SimpleJob:
|
||||
"""Drop in replacement for `dask.distributed.Future`"""
|
||||
|
||||
id: int
|
||||
process: multiprocessing.Process | None = None
|
||||
|
||||
def cancel(self) -> bool:
|
||||
return self.release()
|
||||
|
||||
def release(self) -> bool:
|
||||
if self.process is not None and self.process.is_alive():
|
||||
self.process.terminate()
|
||||
return True
|
||||
return False
|
||||
|
||||
@property
|
||||
def status(self) -> JobStatusType:
|
||||
if not self.process:
|
||||
return "pending"
|
||||
elif self.process.is_alive():
|
||||
return "running"
|
||||
elif self.process.exitcode is None:
|
||||
return "cancelled"
|
||||
elif self.process.exitcode > 0:
|
||||
return "error"
|
||||
else:
|
||||
return "finished"
|
||||
|
||||
def done(self) -> bool:
|
||||
return (
|
||||
self.status == "finished"
|
||||
or self.status == "cancelled"
|
||||
or self.status == "error"
|
||||
)
|
||||
|
||||
def exception(self) -> str:
|
||||
"""Needed to match the Dask API, but not implemented since we don't currently
|
||||
have a way to get back the exception information from the child process."""
|
||||
return (
|
||||
f"Job with ID '{self.id}' was killed or encountered an unhandled exception."
|
||||
)
|
||||
|
||||
|
||||
class SimpleJobClient:
|
||||
"""Drop in replacement for `dask.distributed.Client`"""
|
||||
|
||||
def __init__(self, n_workers: int = 1) -> None:
|
||||
self.n_workers = n_workers
|
||||
self.job_id_counter = 0
|
||||
self.jobs: dict[int, SimpleJob] = {}
|
||||
|
||||
def _cleanup_completed_jobs(self) -> None:
|
||||
current_job_ids = list(self.jobs.keys())
|
||||
for job_id in current_job_ids:
|
||||
job = self.jobs.get(job_id)
|
||||
if job and job.done():
|
||||
logger.debug(f"Cleaning up job with id: '{job.id}'")
|
||||
del self.jobs[job.id]
|
||||
|
||||
def submit(self, func: Callable, *args: Any, pure: bool = True) -> SimpleJob | None:
|
||||
"""NOTE: `pure` arg is needed so this can be a drop in replacement for Dask"""
|
||||
self._cleanup_completed_jobs()
|
||||
if len(self.jobs) >= self.n_workers:
|
||||
logger.debug("No available workers to run job")
|
||||
return None
|
||||
|
||||
job_id = self.job_id_counter
|
||||
self.job_id_counter += 1
|
||||
|
||||
process = multiprocessing.Process(target=func, args=args, daemon=True)
|
||||
job = SimpleJob(id=job_id, process=process)
|
||||
process.start()
|
||||
|
||||
self.jobs[job_id] = job
|
||||
|
||||
return job
|
||||
323
backend/danswer/background/indexing/run_indexing.py
Normal file
323
backend/danswer/background/indexing/run_indexing.py
Normal file
@@ -0,0 +1,323 @@
|
||||
import time
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
|
||||
import torch
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.indexing.checkpointing import get_time_windows_for_index_attempt
|
||||
from danswer.configs.app_configs import POLL_CONNECTOR_OFFSET
|
||||
from danswer.connectors.factory import instantiate_connector
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.models import IndexAttemptMetadata
|
||||
from danswer.connectors.models import InputType
|
||||
from danswer.db.connector import disable_connector
|
||||
from danswer.db.connector_credential_pair import get_last_successful_attempt_time
|
||||
from danswer.db.connector_credential_pair import update_connector_credential_pair
|
||||
from danswer.db.credentials import backend_update_credential_json
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.index_attempt import get_index_attempt
|
||||
from danswer.db.index_attempt import mark_attempt_failed
|
||||
from danswer.db.index_attempt import mark_attempt_in_progress
|
||||
from danswer.db.index_attempt import mark_attempt_succeeded
|
||||
from danswer.db.index_attempt import update_docs_indexed
|
||||
from danswer.db.models import IndexAttempt
|
||||
from danswer.db.models import IndexingStatus
|
||||
from danswer.db.models import IndexModelStatus
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.indexing.embedder import DefaultIndexingEmbedder
|
||||
from danswer.indexing.indexing_pipeline import build_indexing_pipeline
|
||||
from danswer.utils.logger import IndexAttemptSingleton
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _get_document_generator(
|
||||
db_session: Session,
|
||||
attempt: IndexAttempt,
|
||||
start_time: datetime,
|
||||
end_time: datetime,
|
||||
) -> GenerateDocumentsOutput:
|
||||
"""NOTE: `start_time` and `end_time` are only used for poll connectors"""
|
||||
task = attempt.connector.input_type
|
||||
|
||||
try:
|
||||
runnable_connector, new_credential_json = instantiate_connector(
|
||||
attempt.connector.source,
|
||||
task,
|
||||
attempt.connector.connector_specific_config,
|
||||
attempt.credential.credential_json,
|
||||
)
|
||||
if new_credential_json is not None:
|
||||
backend_update_credential_json(
|
||||
attempt.credential, new_credential_json, db_session
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Unable to instantiate connector due to {e}")
|
||||
disable_connector(attempt.connector.id, db_session)
|
||||
raise e
|
||||
|
||||
if task == InputType.LOAD_STATE:
|
||||
assert isinstance(runnable_connector, LoadConnector)
|
||||
doc_batch_generator = runnable_connector.load_from_state()
|
||||
|
||||
elif task == InputType.POLL:
|
||||
assert isinstance(runnable_connector, PollConnector)
|
||||
if attempt.connector_id is None or attempt.credential_id is None:
|
||||
raise ValueError(
|
||||
f"Polling attempt {attempt.id} is missing connector_id or credential_id, "
|
||||
f"can't fetch time range."
|
||||
)
|
||||
|
||||
logger.info(f"Polling for updates between {start_time} and {end_time}")
|
||||
doc_batch_generator = runnable_connector.poll_source(
|
||||
start=start_time.timestamp(), end=end_time.timestamp()
|
||||
)
|
||||
|
||||
else:
|
||||
# Event types cannot be handled by a background type
|
||||
raise RuntimeError(f"Invalid task type: {task}")
|
||||
|
||||
return doc_batch_generator
|
||||
|
||||
|
||||
def _run_indexing(
|
||||
db_session: Session,
|
||||
index_attempt: IndexAttempt,
|
||||
) -> None:
|
||||
"""
|
||||
1. Get documents which are either new or updated from specified application
|
||||
2. Embed and index these documents into the chosen datastore (vespa)
|
||||
3. Updates Postgres to record the indexed documents + the outcome of this run
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
db_embedding_model = index_attempt.embedding_model
|
||||
index_name = db_embedding_model.index_name
|
||||
|
||||
# Only update cc-pair status for primary index jobs
|
||||
# Secondary index syncs at the end when swapping
|
||||
is_primary = index_attempt.embedding_model.status == IndexModelStatus.PRESENT
|
||||
|
||||
# Mark as started
|
||||
mark_attempt_in_progress(index_attempt, db_session)
|
||||
if is_primary:
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=index_attempt.connector.id,
|
||||
credential_id=index_attempt.credential.id,
|
||||
attempt_status=IndexingStatus.IN_PROGRESS,
|
||||
)
|
||||
|
||||
# Indexing is only done into one index at a time
|
||||
document_index = get_default_document_index(
|
||||
primary_index_name=index_name, secondary_index_name=None
|
||||
)
|
||||
|
||||
embedding_model = DefaultIndexingEmbedder(
|
||||
model_name=db_embedding_model.model_name,
|
||||
normalize=db_embedding_model.normalize,
|
||||
query_prefix=db_embedding_model.query_prefix,
|
||||
passage_prefix=db_embedding_model.passage_prefix,
|
||||
)
|
||||
|
||||
indexing_pipeline = build_indexing_pipeline(
|
||||
embedder=embedding_model,
|
||||
document_index=document_index,
|
||||
ignore_time_skip=index_attempt.from_beginning
|
||||
or (db_embedding_model.status == IndexModelStatus.FUTURE),
|
||||
)
|
||||
|
||||
db_connector = index_attempt.connector
|
||||
db_credential = index_attempt.credential
|
||||
last_successful_index_time = (
|
||||
0.0
|
||||
if index_attempt.from_beginning
|
||||
else get_last_successful_attempt_time(
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
embedding_model=index_attempt.embedding_model,
|
||||
db_session=db_session,
|
||||
)
|
||||
)
|
||||
|
||||
net_doc_change = 0
|
||||
document_count = 0
|
||||
chunk_count = 0
|
||||
run_end_dt = None
|
||||
for ind, (window_start, window_end) in enumerate(
|
||||
get_time_windows_for_index_attempt(
|
||||
last_successful_run=datetime.fromtimestamp(
|
||||
last_successful_index_time, tz=timezone.utc
|
||||
),
|
||||
source_type=db_connector.source,
|
||||
)
|
||||
):
|
||||
window_start = max(
|
||||
window_start - timedelta(minutes=POLL_CONNECTOR_OFFSET),
|
||||
datetime(1970, 1, 1, tzinfo=timezone.utc),
|
||||
)
|
||||
|
||||
doc_batch_generator = _get_document_generator(
|
||||
db_session=db_session,
|
||||
attempt=index_attempt,
|
||||
start_time=window_start,
|
||||
end_time=window_end,
|
||||
)
|
||||
|
||||
try:
|
||||
for doc_batch in doc_batch_generator:
|
||||
# Check if connector is disabled mid run and stop if so unless it's the secondary
|
||||
# index being built. We want to populate it even for paused connectors
|
||||
# Often paused connectors are sources that aren't updated frequently but the
|
||||
# contents still need to be initially pulled.
|
||||
db_session.refresh(db_connector)
|
||||
if (
|
||||
db_connector.disabled
|
||||
and db_embedding_model.status != IndexModelStatus.FUTURE
|
||||
):
|
||||
# let the `except` block handle this
|
||||
raise RuntimeError("Connector was disabled mid run")
|
||||
|
||||
db_session.refresh(index_attempt)
|
||||
if index_attempt.status != IndexingStatus.IN_PROGRESS:
|
||||
raise RuntimeError("Index Attempt was canceled")
|
||||
|
||||
logger.debug(
|
||||
f"Indexing batch of documents: {[doc.to_short_descriptor() for doc in doc_batch]}"
|
||||
)
|
||||
|
||||
new_docs, total_batch_chunks = indexing_pipeline(
|
||||
documents=doc_batch,
|
||||
index_attempt_metadata=IndexAttemptMetadata(
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
),
|
||||
)
|
||||
net_doc_change += new_docs
|
||||
chunk_count += total_batch_chunks
|
||||
document_count += len(doc_batch)
|
||||
|
||||
# commit transaction so that the `update` below begins
|
||||
# with a brand new transaction. Postgres uses the start
|
||||
# of the transactions when computing `NOW()`, so if we have
|
||||
# a long running transaction, the `time_updated` field will
|
||||
# be inaccurate
|
||||
db_session.commit()
|
||||
|
||||
# This new value is updated every batch, so UI can refresh per batch update
|
||||
update_docs_indexed(
|
||||
db_session=db_session,
|
||||
index_attempt=index_attempt,
|
||||
total_docs_indexed=document_count,
|
||||
new_docs_indexed=net_doc_change,
|
||||
)
|
||||
|
||||
run_end_dt = window_end
|
||||
if is_primary:
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
attempt_status=IndexingStatus.IN_PROGRESS,
|
||||
net_docs=net_doc_change,
|
||||
run_dt=run_end_dt,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.info(
|
||||
f"Connector run ran into exception after elapsed time: {time.time() - start_time} seconds"
|
||||
)
|
||||
# Only mark the attempt as a complete failure if this is the first indexing window.
|
||||
# Otherwise, some progress was made - the next run will not start from the beginning.
|
||||
# In this case, it is not accurate to mark it as a failure. When the next run begins,
|
||||
# if that fails immediately, it will be marked as a failure.
|
||||
#
|
||||
# NOTE: if the connector is manually disabled, we should mark it as a failure regardless
|
||||
# to give better clarity in the UI, as the next run will never happen.
|
||||
if (
|
||||
ind == 0
|
||||
or db_connector.disabled
|
||||
or index_attempt.status != IndexingStatus.IN_PROGRESS
|
||||
):
|
||||
mark_attempt_failed(
|
||||
index_attempt,
|
||||
db_session,
|
||||
failure_reason=str(e),
|
||||
full_exception_trace=traceback.format_exc(),
|
||||
)
|
||||
if is_primary:
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=index_attempt.connector.id,
|
||||
credential_id=index_attempt.credential.id,
|
||||
attempt_status=IndexingStatus.FAILED,
|
||||
net_docs=net_doc_change,
|
||||
)
|
||||
raise e
|
||||
|
||||
# break => similar to success case. As mentioned above, if the next run fails for the same
|
||||
# reason it will then be marked as a failure
|
||||
break
|
||||
|
||||
mark_attempt_succeeded(index_attempt, db_session)
|
||||
if is_primary:
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
attempt_status=IndexingStatus.SUCCESS,
|
||||
run_dt=run_end_dt,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Indexed or refreshed {document_count} total documents for a total of {chunk_count} indexed chunks"
|
||||
)
|
||||
logger.info(
|
||||
f"Connector successfully finished, elapsed time: {time.time() - start_time} seconds"
|
||||
)
|
||||
|
||||
|
||||
def run_indexing_entrypoint(index_attempt_id: int, num_threads: int) -> None:
|
||||
"""Entrypoint for indexing run when using dask distributed.
|
||||
Wraps the actual logic in a `try` block so that we can catch any exceptions
|
||||
and mark the attempt as failed."""
|
||||
try:
|
||||
# set the indexing attempt ID so that all log messages from this process
|
||||
# will have it added as a prefix
|
||||
IndexAttemptSingleton.set_index_attempt_id(index_attempt_id)
|
||||
|
||||
logger.info(f"Setting task to use {num_threads} threads")
|
||||
torch.set_num_threads(num_threads)
|
||||
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
attempt = get_index_attempt(
|
||||
db_session=db_session, index_attempt_id=index_attempt_id
|
||||
)
|
||||
if attempt is None:
|
||||
raise RuntimeError(
|
||||
f"Unable to find IndexAttempt for ID '{index_attempt_id}'"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Running indexing attempt for connector: '{attempt.connector.name}', "
|
||||
f"with config: '{attempt.connector.connector_specific_config}', and "
|
||||
f"with credentials: '{attempt.credential_id}'"
|
||||
)
|
||||
|
||||
_run_indexing(
|
||||
db_session=db_session,
|
||||
index_attempt=attempt,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Completed indexing attempt for connector: '{attempt.connector.name}', "
|
||||
f"with config: '{attempt.connector.connector_specific_config}', and "
|
||||
f"with credentials: '{attempt.credential_id}'"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Indexing job with ID '{index_attempt_id}' failed due to {e}")
|
||||
117
backend/danswer/background/task_utils.py
Normal file
117
backend/danswer/background/task_utils.py
Normal file
@@ -0,0 +1,117 @@
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import TypeVar
|
||||
|
||||
from celery import Task
|
||||
from celery.result import AsyncResult
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.tasks import mark_task_finished
|
||||
from danswer.db.tasks import mark_task_start
|
||||
from danswer.db.tasks import register_task
|
||||
|
||||
|
||||
def name_cc_cleanup_task(connector_id: int, credential_id: int) -> str:
|
||||
return f"cleanup_connector_credential_pair_{connector_id}_{credential_id}"
|
||||
|
||||
|
||||
def name_document_set_sync_task(document_set_id: int) -> str:
|
||||
return f"sync_doc_set_{document_set_id}"
|
||||
|
||||
|
||||
T = TypeVar("T", bound=Callable)
|
||||
|
||||
|
||||
def build_run_wrapper(build_name_fn: Callable[..., str]) -> Callable[[T], T]:
|
||||
"""Utility meant to wrap the celery task `run` function in order to
|
||||
automatically update our custom `task_queue_jobs` table appropriately"""
|
||||
|
||||
def wrap_task_fn(task_fn: T) -> T:
|
||||
@wraps(task_fn)
|
||||
def wrapped_task_fn(*args: list, **kwargs: dict) -> Any:
|
||||
engine = get_sqlalchemy_engine()
|
||||
|
||||
task_name = build_name_fn(*args, **kwargs)
|
||||
with Session(engine) as db_session:
|
||||
# mark the task as started
|
||||
mark_task_start(task_name=task_name, db_session=db_session)
|
||||
|
||||
result = None
|
||||
exception = None
|
||||
try:
|
||||
result = task_fn(*args, **kwargs)
|
||||
except Exception as e:
|
||||
exception = e
|
||||
|
||||
with Session(engine) as db_session:
|
||||
mark_task_finished(
|
||||
task_name=task_name,
|
||||
db_session=db_session,
|
||||
success=exception is None,
|
||||
)
|
||||
|
||||
if not exception:
|
||||
return result
|
||||
else:
|
||||
raise exception
|
||||
|
||||
return cast(T, wrapped_task_fn)
|
||||
|
||||
return wrap_task_fn
|
||||
|
||||
|
||||
# rough type signature for `apply_async`
|
||||
AA = TypeVar("AA", bound=Callable[..., AsyncResult])
|
||||
|
||||
|
||||
def build_apply_async_wrapper(build_name_fn: Callable[..., str]) -> Callable[[AA], AA]:
|
||||
"""Utility meant to wrap celery `apply_async` function in order to automatically
|
||||
update create an entry in our `task_queue_jobs` table"""
|
||||
|
||||
def wrapper(fn: AA) -> AA:
|
||||
@wraps(fn)
|
||||
def wrapped_fn(
|
||||
args: tuple | None = None,
|
||||
kwargs: dict[str, Any] | None = None,
|
||||
*other_args: list,
|
||||
**other_kwargs: dict[str, Any],
|
||||
) -> Any:
|
||||
# `apply_async` takes in args / kwargs directly as arguments
|
||||
args_for_build_name = args or tuple()
|
||||
kwargs_for_build_name = kwargs or {}
|
||||
task_name = build_name_fn(*args_for_build_name, **kwargs_for_build_name)
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
# mark the task as started
|
||||
task = fn(args, kwargs, *other_args, **other_kwargs)
|
||||
register_task(task.id, task_name, db_session)
|
||||
|
||||
return task
|
||||
|
||||
return cast(AA, wrapped_fn)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def build_celery_task_wrapper(
|
||||
build_name_fn: Callable[..., str]
|
||||
) -> Callable[[Task], Task]:
|
||||
"""Utility meant to wrap celery task functions in order to automatically
|
||||
update our custom `task_queue_jobs` table appropriately.
|
||||
|
||||
On task creation (e.g. `apply_async`), a row is inserted into the table with
|
||||
status `PENDING`.
|
||||
On task start, the latest row is updated to have status `STARTED`.
|
||||
On task success, the latest row is updated to have status `SUCCESS`.
|
||||
On the task raising an unhandled exception, the latest row is updated to have
|
||||
status `FAILURE`.
|
||||
"""
|
||||
|
||||
def wrap_task(task: Task) -> Task:
|
||||
task.run = build_run_wrapper(build_name_fn)(task.run) # type: ignore
|
||||
task.apply_async = build_apply_async_wrapper(build_name_fn)(task.apply_async) # type: ignore
|
||||
return task
|
||||
|
||||
return wrap_task
|
||||
@@ -1,65 +1,108 @@
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
import dask
|
||||
import torch
|
||||
from dask.distributed import Client
|
||||
from dask.distributed import Future
|
||||
from distributed import LocalCluster
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.indexing.dask_utils import ResourceLogger
|
||||
from danswer.background.indexing.job_client import SimpleJob
|
||||
from danswer.background.indexing.job_client import SimpleJobClient
|
||||
from danswer.background.indexing.run_indexing import run_indexing_entrypoint
|
||||
from danswer.configs.app_configs import CLEANUP_INDEXING_JOBS_TIMEOUT
|
||||
from danswer.configs.app_configs import DASK_JOB_CLIENT_ENABLED
|
||||
from danswer.configs.app_configs import LOG_LEVEL
|
||||
from danswer.configs.app_configs import NUM_INDEXING_WORKERS
|
||||
from danswer.connectors.factory import instantiate_connector
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.models import IndexAttemptMetadata
|
||||
from danswer.connectors.models import InputType
|
||||
from danswer.datastores.indexing_pipeline import build_indexing_pipeline
|
||||
from danswer.db.connector import disable_connector
|
||||
from danswer.configs.model_configs import MIN_THREADS_ML_MODELS
|
||||
from danswer.db.connector import fetch_connectors
|
||||
from danswer.db.connector_credential_pair import get_last_successful_attempt_time
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from danswer.db.connector_credential_pair import mark_all_in_progress_cc_pairs_failed
|
||||
from danswer.db.connector_credential_pair import resync_cc_pair
|
||||
from danswer.db.connector_credential_pair import update_connector_credential_pair
|
||||
from danswer.db.credentials import backend_update_credential_json
|
||||
from danswer.db.embedding_model import get_current_db_embedding_model
|
||||
from danswer.db.embedding_model import get_secondary_db_embedding_model
|
||||
from danswer.db.embedding_model import update_embedding_model_status
|
||||
from danswer.db.engine import get_db_current_time
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.index_attempt import count_unique_cc_pairs_with_index_attempts
|
||||
from danswer.db.index_attempt import create_index_attempt
|
||||
from danswer.db.index_attempt import get_index_attempt
|
||||
from danswer.db.index_attempt import get_inprogress_index_attempts
|
||||
from danswer.db.index_attempt import get_last_attempt
|
||||
from danswer.db.index_attempt import get_not_started_index_attempts
|
||||
from danswer.db.index_attempt import mark_attempt_failed
|
||||
from danswer.db.index_attempt import mark_attempt_in_progress
|
||||
from danswer.db.index_attempt import mark_attempt_succeeded
|
||||
from danswer.db.index_attempt import update_docs_indexed
|
||||
from danswer.db.models import Connector
|
||||
from danswer.db.models import EmbeddingModel
|
||||
from danswer.db.models import IndexAttempt
|
||||
from danswer.db.models import IndexingStatus
|
||||
from danswer.search.search_utils import warm_up_models
|
||||
from danswer.utils.logger import IndexAttemptSingleton
|
||||
from danswer.db.models import IndexModelStatus
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# If the indexing dies, it's most likely due to resource constraints,
|
||||
# restarting just delays the eventual failure, not useful to the user
|
||||
dask.config.set({"distributed.scheduler.allowed-failures": 0})
|
||||
|
||||
_UNEXPECTED_STATE_FAILURE_REASON = (
|
||||
"Stopped mid run, likely due to the background process being killed"
|
||||
)
|
||||
|
||||
|
||||
def should_create_new_indexing(
|
||||
connector: Connector, last_index: IndexAttempt | None, db_session: Session
|
||||
"""Util funcs"""
|
||||
|
||||
|
||||
def _get_num_threads() -> int:
|
||||
"""Get # of "threads" to use for ML models in an indexing job. By default uses
|
||||
the torch implementation, which returns the # of physical cores on the machine.
|
||||
"""
|
||||
return max(MIN_THREADS_ML_MODELS, torch.get_num_threads())
|
||||
|
||||
|
||||
def _should_create_new_indexing(
|
||||
connector: Connector,
|
||||
last_index: IndexAttempt | None,
|
||||
model: EmbeddingModel,
|
||||
db_session: Session,
|
||||
) -> bool:
|
||||
# When switching over models, always index at least once
|
||||
if model.status == IndexModelStatus.FUTURE and not last_index:
|
||||
if connector.id == 0: # Ingestion API
|
||||
return False
|
||||
return True
|
||||
|
||||
if connector.refresh_freq is None:
|
||||
return False
|
||||
if not last_index:
|
||||
return True
|
||||
|
||||
# Only one scheduled job per connector at a time
|
||||
# Can schedule another one if the current one is already running however
|
||||
# Because the currently running one will not be until the latest time
|
||||
# Note, this last index is for the given embedding model
|
||||
if last_index.status == IndexingStatus.NOT_STARTED:
|
||||
return False
|
||||
|
||||
current_db_time = get_db_current_time(db_session)
|
||||
time_since_index = current_db_time - last_index.time_updated
|
||||
return time_since_index.total_seconds() >= connector.refresh_freq
|
||||
|
||||
|
||||
def mark_run_failed(
|
||||
def _is_indexing_job_marked_as_finished(index_attempt: IndexAttempt | None) -> bool:
|
||||
if index_attempt is None:
|
||||
return False
|
||||
|
||||
return (
|
||||
index_attempt.status == IndexingStatus.FAILED
|
||||
or index_attempt.status == IndexingStatus.SUCCESS
|
||||
)
|
||||
|
||||
|
||||
def _mark_run_failed(
|
||||
db_session: Session, index_attempt: IndexAttempt, failure_reason: str
|
||||
) -> None:
|
||||
"""Marks the `index_attempt` row as failed + updates the `
|
||||
@@ -76,6 +119,7 @@ def mark_run_failed(
|
||||
if (
|
||||
index_attempt.connector_id is not None
|
||||
and index_attempt.credential_id is not None
|
||||
and index_attempt.embedding_model.status == IndexModelStatus.PRESENT
|
||||
):
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
@@ -85,366 +129,283 @@ def mark_run_failed(
|
||||
)
|
||||
|
||||
|
||||
def create_indexing_jobs(db_session: Session, existing_jobs: dict[int, Future]) -> None:
|
||||
"""Main funcs"""
|
||||
|
||||
|
||||
def create_indexing_jobs(existing_jobs: dict[int, Future | SimpleJob]) -> None:
|
||||
"""Creates new indexing jobs for each connector / credential pair which is:
|
||||
1. Enabled
|
||||
2. `refresh_frequency` time has passed since the last indexing run for this pair
|
||||
3. There is not already an ongoing indexing attempt for this pair
|
||||
"""
|
||||
ongoing_pairs: set[tuple[int | None, int | None]] = set()
|
||||
for attempt_id in existing_jobs:
|
||||
attempt = get_index_attempt(db_session=db_session, index_attempt_id=attempt_id)
|
||||
if attempt is None:
|
||||
logger.error(
|
||||
f"Unable to find IndexAttempt for ID '{attempt_id}' when creating "
|
||||
"indexing jobs"
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
ongoing: set[tuple[int | None, int | None, int]] = set()
|
||||
for attempt_id in existing_jobs:
|
||||
attempt = get_index_attempt(
|
||||
db_session=db_session, index_attempt_id=attempt_id
|
||||
)
|
||||
continue
|
||||
ongoing_pairs.add((attempt.connector_id, attempt.credential_id))
|
||||
|
||||
enabled_connectors = fetch_connectors(db_session, disabled_status=False)
|
||||
for connector in enabled_connectors:
|
||||
for association in connector.credentials:
|
||||
credential = association.credential
|
||||
|
||||
# check if there is an ogoing indexing attempt for this connector + credential pair
|
||||
if (connector.id, credential.id) in ongoing_pairs:
|
||||
if attempt is None:
|
||||
logger.error(
|
||||
f"Unable to find IndexAttempt for ID '{attempt_id}' when creating "
|
||||
"indexing jobs"
|
||||
)
|
||||
continue
|
||||
|
||||
last_attempt = get_last_attempt(connector.id, credential.id, db_session)
|
||||
if not should_create_new_indexing(connector, last_attempt, db_session):
|
||||
continue
|
||||
create_index_attempt(connector.id, credential.id, db_session)
|
||||
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=connector.id,
|
||||
credential_id=credential.id,
|
||||
attempt_status=IndexingStatus.NOT_STARTED,
|
||||
ongoing.add(
|
||||
(
|
||||
attempt.connector_id,
|
||||
attempt.credential_id,
|
||||
attempt.embedding_model_id,
|
||||
)
|
||||
)
|
||||
|
||||
embedding_models = [get_current_db_embedding_model(db_session)]
|
||||
secondary_embedding_model = get_secondary_db_embedding_model(db_session)
|
||||
if secondary_embedding_model is not None:
|
||||
embedding_models.append(secondary_embedding_model)
|
||||
|
||||
all_connectors = fetch_connectors(db_session)
|
||||
for connector in all_connectors:
|
||||
for association in connector.credentials:
|
||||
for model in embedding_models:
|
||||
credential = association.credential
|
||||
|
||||
# Check if there is an ongoing indexing attempt for this connector + credential pair
|
||||
if (connector.id, credential.id, model.id) in ongoing:
|
||||
continue
|
||||
|
||||
last_attempt = get_last_attempt(
|
||||
connector.id, credential.id, model.id, db_session
|
||||
)
|
||||
if not _should_create_new_indexing(
|
||||
connector, last_attempt, model, db_session
|
||||
):
|
||||
continue
|
||||
|
||||
create_index_attempt(
|
||||
connector.id, credential.id, model.id, db_session
|
||||
)
|
||||
|
||||
# CC-Pair will have the status that it should for the primary index
|
||||
# Will be re-sync-ed once the indices are swapped
|
||||
if model.status == IndexModelStatus.PRESENT:
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=connector.id,
|
||||
credential_id=credential.id,
|
||||
attempt_status=IndexingStatus.NOT_STARTED,
|
||||
)
|
||||
|
||||
|
||||
def cleanup_indexing_jobs(
|
||||
db_session: Session, existing_jobs: dict[int, Future]
|
||||
) -> dict[int, Future]:
|
||||
existing_jobs: dict[int, Future | SimpleJob],
|
||||
timeout_hours: int = CLEANUP_INDEXING_JOBS_TIMEOUT,
|
||||
) -> dict[int, Future | SimpleJob]:
|
||||
existing_jobs_copy = existing_jobs.copy()
|
||||
|
||||
# clean up completed jobs
|
||||
for attempt_id, job in existing_jobs.items():
|
||||
# do nothing for ongoing jobs
|
||||
if not job.done():
|
||||
continue
|
||||
|
||||
job.release()
|
||||
del existing_jobs_copy[attempt_id]
|
||||
index_attempt = get_index_attempt(
|
||||
db_session=db_session, index_attempt_id=attempt_id
|
||||
)
|
||||
if not index_attempt:
|
||||
logger.error(
|
||||
f"Unable to find IndexAttempt for ID '{attempt_id}' when cleaning "
|
||||
"up indexing jobs"
|
||||
)
|
||||
continue
|
||||
|
||||
if index_attempt.status == IndexingStatus.IN_PROGRESS:
|
||||
mark_run_failed(
|
||||
db_session=db_session,
|
||||
index_attempt=index_attempt,
|
||||
failure_reason=_UNEXPECTED_STATE_FAILURE_REASON,
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
for attempt_id, job in existing_jobs.items():
|
||||
index_attempt = get_index_attempt(
|
||||
db_session=db_session, index_attempt_id=attempt_id
|
||||
)
|
||||
|
||||
# clean up in-progress jobs that were never completed
|
||||
connectors = fetch_connectors(db_session)
|
||||
for connector in connectors:
|
||||
in_progress_indexing_attempts = get_inprogress_index_attempts(
|
||||
connector.id, db_session
|
||||
)
|
||||
for index_attempt in in_progress_indexing_attempts:
|
||||
if index_attempt.id in existing_jobs:
|
||||
# check to see if the job has been updated in the last hour, if not
|
||||
# assume it to frozen in some bad state and just mark it as failed. Note: this relies
|
||||
# on the fact that the `time_updated` field is constantly updated every
|
||||
# batch of documents indexed
|
||||
current_db_time = get_db_current_time(db_session=db_session)
|
||||
time_since_update = current_db_time - index_attempt.time_updated
|
||||
if time_since_update.seconds > 60 * 60:
|
||||
existing_jobs[index_attempt.id].cancel()
|
||||
mark_run_failed(
|
||||
db_session=db_session,
|
||||
index_attempt=index_attempt,
|
||||
failure_reason="Indexing run frozen - no updates in last hour. "
|
||||
"The run will be re-attempted at next scheduled indexing time.",
|
||||
)
|
||||
else:
|
||||
# If job isn't known, simply mark it as failed
|
||||
mark_run_failed(
|
||||
# do nothing for ongoing jobs that haven't been stopped
|
||||
if not job.done() and not _is_indexing_job_marked_as_finished(
|
||||
index_attempt
|
||||
):
|
||||
continue
|
||||
|
||||
if job.status == "error":
|
||||
logger.error(job.exception())
|
||||
|
||||
job.release()
|
||||
del existing_jobs_copy[attempt_id]
|
||||
|
||||
if not index_attempt:
|
||||
logger.error(
|
||||
f"Unable to find IndexAttempt for ID '{attempt_id}' when cleaning "
|
||||
"up indexing jobs"
|
||||
)
|
||||
continue
|
||||
|
||||
if (
|
||||
index_attempt.status == IndexingStatus.IN_PROGRESS
|
||||
or job.status == "error"
|
||||
):
|
||||
_mark_run_failed(
|
||||
db_session=db_session,
|
||||
index_attempt=index_attempt,
|
||||
failure_reason=_UNEXPECTED_STATE_FAILURE_REASON,
|
||||
)
|
||||
|
||||
# clean up in-progress jobs that were never completed
|
||||
connectors = fetch_connectors(db_session)
|
||||
for connector in connectors:
|
||||
in_progress_indexing_attempts = get_inprogress_index_attempts(
|
||||
connector.id, db_session
|
||||
)
|
||||
for index_attempt in in_progress_indexing_attempts:
|
||||
if index_attempt.id in existing_jobs:
|
||||
# check to see if the job has been updated in last `timeout_hours` hours, if not
|
||||
# assume it to frozen in some bad state and just mark it as failed. Note: this relies
|
||||
# on the fact that the `time_updated` field is constantly updated every
|
||||
# batch of documents indexed
|
||||
current_db_time = get_db_current_time(db_session=db_session)
|
||||
time_since_update = current_db_time - index_attempt.time_updated
|
||||
if time_since_update.total_seconds() > 60 * 60 * timeout_hours:
|
||||
existing_jobs[index_attempt.id].cancel()
|
||||
_mark_run_failed(
|
||||
db_session=db_session,
|
||||
index_attempt=index_attempt,
|
||||
failure_reason="Indexing run frozen - no updates in the last three hours. "
|
||||
"The run will be re-attempted at next scheduled indexing time.",
|
||||
)
|
||||
else:
|
||||
# If job isn't known, simply mark it as failed
|
||||
_mark_run_failed(
|
||||
db_session=db_session,
|
||||
index_attempt=index_attempt,
|
||||
failure_reason=_UNEXPECTED_STATE_FAILURE_REASON,
|
||||
)
|
||||
|
||||
return existing_jobs_copy
|
||||
|
||||
|
||||
def _run_indexing(
|
||||
db_session: Session,
|
||||
index_attempt: IndexAttempt,
|
||||
) -> None:
|
||||
"""
|
||||
1. Get documents which are either new or updated from specified application
|
||||
2. Embed and index these documents into the chosen datastores (e.g. Qdrant / Typesense or Vespa)
|
||||
3. Updates Postgres to record the indexed documents + the outcome of this run
|
||||
"""
|
||||
|
||||
def _get_document_generator(
|
||||
db_session: Session, attempt: IndexAttempt
|
||||
) -> tuple[GenerateDocumentsOutput, float]:
|
||||
# "official" timestamp for this run
|
||||
# used for setting time bounds when fetching updates from apps and
|
||||
# is stored in the DB as the last successful run time if this run succeeds
|
||||
run_time = time.time()
|
||||
run_dt = datetime.fromtimestamp(run_time, tz=timezone.utc)
|
||||
run_time_str = run_dt.strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
task = attempt.connector.input_type
|
||||
|
||||
try:
|
||||
runnable_connector, new_credential_json = instantiate_connector(
|
||||
attempt.connector.source,
|
||||
task,
|
||||
attempt.connector.connector_specific_config,
|
||||
attempt.credential.credential_json,
|
||||
)
|
||||
if new_credential_json is not None:
|
||||
backend_update_credential_json(
|
||||
attempt.credential, new_credential_json, db_session
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Unable to instantiate connector due to {e}")
|
||||
disable_connector(attempt.connector.id, db_session)
|
||||
raise e
|
||||
|
||||
if task == InputType.LOAD_STATE:
|
||||
assert isinstance(runnable_connector, LoadConnector)
|
||||
doc_batch_generator = runnable_connector.load_from_state()
|
||||
|
||||
elif task == InputType.POLL:
|
||||
assert isinstance(runnable_connector, PollConnector)
|
||||
if attempt.connector_id is None or attempt.credential_id is None:
|
||||
raise ValueError(
|
||||
f"Polling attempt {attempt.id} is missing connector_id or credential_id, "
|
||||
f"can't fetch time range."
|
||||
)
|
||||
last_run_time = get_last_successful_attempt_time(
|
||||
attempt.connector_id, attempt.credential_id, db_session
|
||||
)
|
||||
last_run_time_str = datetime.fromtimestamp(
|
||||
last_run_time, tz=timezone.utc
|
||||
).strftime("%Y-%m-%d %H:%M:%S")
|
||||
logger.info(
|
||||
f"Polling for updates between {last_run_time_str} and {run_time_str}"
|
||||
)
|
||||
doc_batch_generator = runnable_connector.poll_source(
|
||||
start=last_run_time, end=run_time
|
||||
)
|
||||
|
||||
else:
|
||||
# Event types cannot be handled by a background type
|
||||
raise RuntimeError(f"Invalid task type: {task}")
|
||||
|
||||
return doc_batch_generator, run_time
|
||||
|
||||
doc_batch_generator, run_time = _get_document_generator(db_session, index_attempt)
|
||||
|
||||
def _index(
|
||||
db_session: Session,
|
||||
attempt: IndexAttempt,
|
||||
doc_batch_generator: GenerateDocumentsOutput,
|
||||
run_time: float,
|
||||
) -> None:
|
||||
indexing_pipeline = build_indexing_pipeline()
|
||||
|
||||
run_dt = datetime.fromtimestamp(run_time, tz=timezone.utc)
|
||||
db_connector = attempt.connector
|
||||
db_credential = attempt.credential
|
||||
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
attempt_status=IndexingStatus.IN_PROGRESS,
|
||||
run_dt=run_dt,
|
||||
)
|
||||
|
||||
try:
|
||||
net_doc_change = 0
|
||||
document_count = 0
|
||||
chunk_count = 0
|
||||
for doc_batch in doc_batch_generator:
|
||||
logger.debug(
|
||||
f"Indexing batch of documents: {[doc.to_short_descriptor() for doc in doc_batch]}"
|
||||
)
|
||||
|
||||
index_user_id = (
|
||||
None if db_credential.public_doc else db_credential.user_id
|
||||
)
|
||||
new_docs, total_batch_chunks = indexing_pipeline(
|
||||
documents=doc_batch,
|
||||
index_attempt_metadata=IndexAttemptMetadata(
|
||||
user_id=index_user_id,
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
),
|
||||
)
|
||||
net_doc_change += new_docs
|
||||
chunk_count += total_batch_chunks
|
||||
document_count += len(doc_batch)
|
||||
update_docs_indexed(
|
||||
db_session=db_session,
|
||||
index_attempt=attempt,
|
||||
num_docs_indexed=document_count,
|
||||
)
|
||||
|
||||
# check if connector is disabled mid run and stop if so
|
||||
db_session.refresh(db_connector)
|
||||
if db_connector.disabled:
|
||||
# let the `except` block handle this
|
||||
raise RuntimeError("Connector was disabled mid run")
|
||||
|
||||
mark_attempt_succeeded(attempt, db_session)
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
attempt_status=IndexingStatus.SUCCESS,
|
||||
net_docs=net_doc_change,
|
||||
run_dt=run_dt,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Indexed or updated {document_count} total documents for a total of {chunk_count} chunks"
|
||||
)
|
||||
logger.info(
|
||||
f"Connector successfully finished, elapsed time: {time.time() - run_time} seconds"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.info(
|
||||
f"Failed connector elapsed time: {time.time() - run_time} seconds"
|
||||
)
|
||||
mark_attempt_failed(attempt, db_session, failure_reason=str(e))
|
||||
# The last attempt won't be marked failed until the next cycle's check for still in-progress attempts
|
||||
# The connector_credential_pair is marked failed here though to reflect correctly in UI asap
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=attempt.connector.id,
|
||||
credential_id=attempt.credential.id,
|
||||
attempt_status=IndexingStatus.FAILED,
|
||||
net_docs=net_doc_change,
|
||||
run_dt=run_dt,
|
||||
)
|
||||
raise e
|
||||
|
||||
_index(db_session, index_attempt, doc_batch_generator, run_time)
|
||||
|
||||
|
||||
def _run_indexing_entrypoint(index_attempt_id: int) -> None:
|
||||
"""Entrypoint for indexing run when using dask distributed.
|
||||
Wraps the actual logic in a `try` block so that we can catch any exceptions
|
||||
and mark the attempt as failed."""
|
||||
try:
|
||||
# set the indexing attempt ID so that all log messages from this process
|
||||
# will have it added as a prefix
|
||||
IndexAttemptSingleton.set_index_attempt_id(index_attempt_id)
|
||||
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
attempt = get_index_attempt(
|
||||
db_session=db_session, index_attempt_id=index_attempt_id
|
||||
)
|
||||
if attempt is None:
|
||||
raise RuntimeError(
|
||||
f"Unable to find IndexAttempt for ID '{index_attempt_id}'"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Running indexing attempt for connector: '{attempt.connector.name}', "
|
||||
f"with config: '{attempt.connector.connector_specific_config}', and "
|
||||
f"with credentials: '{attempt.credential_id}'"
|
||||
)
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=attempt.connector.id,
|
||||
credential_id=attempt.credential.id,
|
||||
attempt_status=IndexingStatus.IN_PROGRESS,
|
||||
)
|
||||
|
||||
_run_indexing(
|
||||
db_session=db_session,
|
||||
index_attempt=attempt,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Completed indexing attempt for connector: '{attempt.connector.name}', "
|
||||
f"with config: '{attempt.connector.connector_specific_config}', and "
|
||||
f"with credentials: '{attempt.credential_id}'"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Indexing job with ID '{index_attempt_id}' failed due to {e}")
|
||||
|
||||
|
||||
def kickoff_indexing_jobs(
|
||||
db_session: Session,
|
||||
existing_jobs: dict[int, Future],
|
||||
client: Client,
|
||||
) -> dict[int, Future]:
|
||||
existing_jobs: dict[int, Future | SimpleJob],
|
||||
client: Client | SimpleJobClient,
|
||||
secondary_client: Client | SimpleJobClient,
|
||||
) -> dict[int, Future | SimpleJob]:
|
||||
existing_jobs_copy = existing_jobs.copy()
|
||||
engine = get_sqlalchemy_engine()
|
||||
|
||||
# Don't include jobs waiting in the Dask queue that just haven't started running
|
||||
# Also (rarely) don't include for jobs that started but haven't updated the indexing tables yet
|
||||
with Session(engine) as db_session:
|
||||
new_indexing_attempts = [
|
||||
(attempt, attempt.embedding_model)
|
||||
for attempt in get_not_started_index_attempts(db_session)
|
||||
if attempt.id not in existing_jobs
|
||||
]
|
||||
|
||||
new_indexing_attempts = get_not_started_index_attempts(db_session)
|
||||
logger.info(f"Found {len(new_indexing_attempts)} new indexing tasks.")
|
||||
|
||||
if not new_indexing_attempts:
|
||||
return existing_jobs
|
||||
|
||||
for attempt in new_indexing_attempts:
|
||||
for attempt, embedding_model in new_indexing_attempts:
|
||||
use_secondary_index = (
|
||||
embedding_model.status == IndexModelStatus.FUTURE
|
||||
if embedding_model is not None
|
||||
else False
|
||||
)
|
||||
if attempt.connector is None:
|
||||
logger.warning(
|
||||
f"Skipping index attempt as Connector has been deleted: {attempt}"
|
||||
)
|
||||
mark_attempt_failed(attempt, db_session, failure_reason="Connector is null")
|
||||
with Session(engine) as db_session:
|
||||
mark_attempt_failed(
|
||||
attempt, db_session, failure_reason="Connector is null"
|
||||
)
|
||||
continue
|
||||
if attempt.credential is None:
|
||||
logger.warning(
|
||||
f"Skipping index attempt as Credential has been deleted: {attempt}"
|
||||
)
|
||||
mark_attempt_failed(
|
||||
attempt, db_session, failure_reason="Credential is null"
|
||||
)
|
||||
with Session(engine) as db_session:
|
||||
mark_attempt_failed(
|
||||
attempt, db_session, failure_reason="Credential is null"
|
||||
)
|
||||
continue
|
||||
|
||||
logger.info(
|
||||
f"Kicking off indexing attempt for connector: '{attempt.connector.name}', "
|
||||
f"with config: '{attempt.connector.connector_specific_config}', and "
|
||||
f"with credentials: '{attempt.credential_id}'"
|
||||
)
|
||||
mark_attempt_in_progress(attempt, db_session)
|
||||
run = client.submit(_run_indexing_entrypoint, attempt.id, pure=False)
|
||||
existing_jobs_copy[attempt.id] = run
|
||||
if use_secondary_index:
|
||||
run = secondary_client.submit(
|
||||
run_indexing_entrypoint, attempt.id, _get_num_threads(), pure=False
|
||||
)
|
||||
else:
|
||||
run = client.submit(
|
||||
run_indexing_entrypoint, attempt.id, _get_num_threads(), pure=False
|
||||
)
|
||||
|
||||
if run:
|
||||
secondary_str = "(secondary index) " if use_secondary_index else ""
|
||||
logger.info(
|
||||
f"Kicked off {secondary_str}"
|
||||
f"indexing attempt for connector: '{attempt.connector.name}', "
|
||||
f"with config: '{attempt.connector.connector_specific_config}', and "
|
||||
f"with credentials: '{attempt.credential_id}'"
|
||||
)
|
||||
existing_jobs_copy[attempt.id] = run
|
||||
|
||||
return existing_jobs_copy
|
||||
|
||||
|
||||
def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> None:
|
||||
cluster = LocalCluster(
|
||||
n_workers=num_workers,
|
||||
threads_per_worker=1,
|
||||
# there are warning about high memory usage + "Event loop unresponsive"
|
||||
# which are not relevant to us since our workers are expected to use a
|
||||
# lot of memory + involve CPU intensive tasks that will not relinquish
|
||||
# the event loop
|
||||
silence_logs=logging.ERROR,
|
||||
def check_index_swap(db_session: Session) -> None:
|
||||
"""Get count of cc-pairs and count of index_attempts for the new model grouped by
|
||||
connector + credential, if it's the same, then assume new index is done building.
|
||||
This does not take into consideration if the attempt failed or not"""
|
||||
# Default CC-pair created for Ingestion API unused here
|
||||
all_cc_pairs = get_connector_credential_pairs(db_session)
|
||||
cc_pair_count = len(all_cc_pairs) - 1
|
||||
embedding_model = get_secondary_db_embedding_model(db_session)
|
||||
|
||||
if not embedding_model:
|
||||
return
|
||||
|
||||
unique_cc_indexings = count_unique_cc_pairs_with_index_attempts(
|
||||
embedding_model_id=embedding_model.id, db_session=db_session
|
||||
)
|
||||
client = Client(cluster)
|
||||
existing_jobs: dict[int, Future] = {}
|
||||
|
||||
if unique_cc_indexings > cc_pair_count:
|
||||
raise RuntimeError("More unique indexings than cc pairs, should not occur")
|
||||
|
||||
if cc_pair_count == unique_cc_indexings:
|
||||
# Swap indices
|
||||
now_old_embedding_model = get_current_db_embedding_model(db_session)
|
||||
update_embedding_model_status(
|
||||
embedding_model=now_old_embedding_model,
|
||||
new_status=IndexModelStatus.PAST,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
update_embedding_model_status(
|
||||
embedding_model=embedding_model,
|
||||
new_status=IndexModelStatus.PRESENT,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Recount aggregates
|
||||
for cc_pair in all_cc_pairs:
|
||||
resync_cc_pair(cc_pair, db_session=db_session)
|
||||
|
||||
|
||||
def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> None:
|
||||
client_primary: Client | SimpleJobClient
|
||||
client_secondary: Client | SimpleJobClient
|
||||
if DASK_JOB_CLIENT_ENABLED:
|
||||
cluster_primary = LocalCluster(
|
||||
n_workers=num_workers,
|
||||
threads_per_worker=1,
|
||||
# there are warning about high memory usage + "Event loop unresponsive"
|
||||
# which are not relevant to us since our workers are expected to use a
|
||||
# lot of memory + involve CPU intensive tasks that will not relinquish
|
||||
# the event loop
|
||||
silence_logs=logging.ERROR,
|
||||
)
|
||||
cluster_secondary = LocalCluster(
|
||||
n_workers=num_workers,
|
||||
threads_per_worker=1,
|
||||
silence_logs=logging.ERROR,
|
||||
)
|
||||
client_primary = Client(cluster_primary)
|
||||
client_secondary = Client(cluster_secondary)
|
||||
if LOG_LEVEL.lower() == "debug":
|
||||
client_primary.register_worker_plugin(ResourceLogger())
|
||||
else:
|
||||
client_primary = SimpleJobClient(n_workers=num_workers)
|
||||
client_secondary = SimpleJobClient(n_workers=num_workers)
|
||||
|
||||
existing_jobs: dict[int, Future | SimpleJob] = {}
|
||||
engine = get_sqlalchemy_engine()
|
||||
|
||||
with Session(engine) as db_session:
|
||||
@@ -456,15 +417,24 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non
|
||||
start = time.time()
|
||||
start_time_utc = datetime.utcfromtimestamp(start).strftime("%Y-%m-%d %H:%M:%S")
|
||||
logger.info(f"Running update, current UTC time: {start_time_utc}")
|
||||
|
||||
if existing_jobs:
|
||||
# TODO: make this debug level once the "no jobs are being scheduled" issue is resolved
|
||||
logger.info(
|
||||
"Found existing indexing jobs: "
|
||||
f"{[(attempt_id, job.status) for attempt_id, job in existing_jobs.items()]}"
|
||||
)
|
||||
|
||||
try:
|
||||
with Session(engine, expire_on_commit=False) as db_session:
|
||||
existing_jobs = cleanup_indexing_jobs(
|
||||
db_session=db_session, existing_jobs=existing_jobs
|
||||
)
|
||||
create_indexing_jobs(db_session=db_session, existing_jobs=existing_jobs)
|
||||
existing_jobs = kickoff_indexing_jobs(
|
||||
db_session=db_session, existing_jobs=existing_jobs, client=client
|
||||
)
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
check_index_swap(db_session)
|
||||
existing_jobs = cleanup_indexing_jobs(existing_jobs=existing_jobs)
|
||||
create_indexing_jobs(existing_jobs=existing_jobs)
|
||||
existing_jobs = kickoff_indexing_jobs(
|
||||
existing_jobs=existing_jobs,
|
||||
client=client_primary,
|
||||
secondary_client=client_secondary,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to run update due to {e}")
|
||||
sleep_time = delay - (time.time() - start)
|
||||
@@ -472,8 +442,16 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non
|
||||
time.sleep(sleep_time)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logger.info("Warming up Embedding Model(s)")
|
||||
warm_up_models(indexer_only=True)
|
||||
def update__main() -> None:
|
||||
# needed for CUDA to work with multiprocessing
|
||||
# NOTE: needs to be done on application startup
|
||||
# before any other torch code has been run
|
||||
if not DASK_JOB_CLIENT_ENABLED:
|
||||
torch.multiprocessing.set_start_method("spawn")
|
||||
|
||||
logger.info("Starting Indexing Loop")
|
||||
update_loop()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
update__main()
|
||||
|
||||
@@ -1,24 +0,0 @@
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def interval_run_job(
|
||||
job: Callable[[], Any], delay: int | float, emit_job_start_log: bool = True
|
||||
) -> None:
|
||||
while True:
|
||||
start = time.time()
|
||||
if emit_job_start_log:
|
||||
logger.info(f"Running '{job.__name__}', current time: {time.ctime(start)}")
|
||||
try:
|
||||
job()
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to run update due to {e}")
|
||||
sleep_time = delay - (time.time() - start)
|
||||
if sleep_time > 0:
|
||||
time.sleep(sleep_time)
|
||||
@@ -1,197 +0,0 @@
|
||||
from slack_sdk.models.blocks import ActionsBlock
|
||||
from slack_sdk.models.blocks import Block
|
||||
from slack_sdk.models.blocks import ButtonElement
|
||||
from slack_sdk.models.blocks import ConfirmObject
|
||||
from slack_sdk.models.blocks import DividerBlock
|
||||
from slack_sdk.models.blocks import HeaderBlock
|
||||
from slack_sdk.models.blocks import SectionBlock
|
||||
|
||||
from danswer.bots.slack.constants import DISLIKE_BLOCK_ACTION_ID
|
||||
from danswer.bots.slack.constants import LIKE_BLOCK_ACTION_ID
|
||||
from danswer.bots.slack.utils import build_feedback_block_id
|
||||
from danswer.bots.slack.utils import remove_slack_text_interactions
|
||||
from danswer.bots.slack.utils import translate_vespa_highlight_to_slack
|
||||
from danswer.configs.app_configs import DANSWER_BOT_NUM_DOCS_TO_DISPLAY
|
||||
from danswer.configs.app_configs import ENABLE_SLACK_DOC_FEEDBACK
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import SearchFeedbackType
|
||||
from danswer.direct_qa.interfaces import DanswerQuote
|
||||
from danswer.server.models import SearchDoc
|
||||
from danswer.utils.text_processing import replace_whitespaces_w_space
|
||||
|
||||
|
||||
_MAX_BLURB_LEN = 75
|
||||
|
||||
|
||||
def build_qa_feedback_block(query_event_id: int) -> Block:
|
||||
return ActionsBlock(
|
||||
block_id=build_feedback_block_id(query_event_id),
|
||||
elements=[
|
||||
ButtonElement(
|
||||
action_id=LIKE_BLOCK_ACTION_ID,
|
||||
text="👍",
|
||||
style="primary",
|
||||
),
|
||||
ButtonElement(
|
||||
action_id=DISLIKE_BLOCK_ACTION_ID,
|
||||
text="👎",
|
||||
style="danger",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def build_doc_feedback_block(
|
||||
query_event_id: int,
|
||||
document_id: str,
|
||||
document_rank: int,
|
||||
) -> Block:
|
||||
return ActionsBlock(
|
||||
block_id=build_feedback_block_id(query_event_id, document_id, document_rank),
|
||||
elements=[
|
||||
ButtonElement(
|
||||
action_id=SearchFeedbackType.ENDORSE.value,
|
||||
text="⬆",
|
||||
style="primary",
|
||||
confirm=ConfirmObject(
|
||||
title="Endorse this Document",
|
||||
text="This is a good source of information and should be shown more often!",
|
||||
),
|
||||
),
|
||||
ButtonElement(
|
||||
action_id=SearchFeedbackType.REJECT.value,
|
||||
text="⬇",
|
||||
style="danger",
|
||||
confirm=ConfirmObject(
|
||||
title="Reject this Document",
|
||||
text="This is a bad source of information and should be shown less often.",
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def build_documents_blocks(
|
||||
documents: list[SearchDoc],
|
||||
query_event_id: int,
|
||||
num_docs_to_display: int = DANSWER_BOT_NUM_DOCS_TO_DISPLAY,
|
||||
include_feedback: bool = ENABLE_SLACK_DOC_FEEDBACK,
|
||||
) -> list[Block]:
|
||||
seen_docs_identifiers = set()
|
||||
section_blocks: list[Block] = [HeaderBlock(text="Reference Documents")]
|
||||
included_docs = 0
|
||||
for rank, d in enumerate(documents):
|
||||
if d.document_id in seen_docs_identifiers:
|
||||
continue
|
||||
seen_docs_identifiers.add(d.document_id)
|
||||
|
||||
doc_sem_id = d.semantic_identifier
|
||||
if d.source_type == DocumentSource.SLACK.value:
|
||||
doc_sem_id = "#" + doc_sem_id
|
||||
|
||||
used_chars = len(doc_sem_id) + 3
|
||||
match_str = translate_vespa_highlight_to_slack(d.match_highlights, used_chars)
|
||||
|
||||
included_docs += 1
|
||||
|
||||
section_blocks.append(
|
||||
SectionBlock(
|
||||
text=f"<{d.link}|{doc_sem_id}>:\n>{remove_slack_text_interactions(match_str)}"
|
||||
),
|
||||
)
|
||||
|
||||
if include_feedback:
|
||||
section_blocks.append(
|
||||
build_doc_feedback_block(
|
||||
query_event_id=query_event_id,
|
||||
document_id=d.document_id,
|
||||
document_rank=rank,
|
||||
),
|
||||
)
|
||||
|
||||
section_blocks.append(DividerBlock())
|
||||
|
||||
if included_docs >= num_docs_to_display:
|
||||
break
|
||||
|
||||
return section_blocks
|
||||
|
||||
|
||||
def build_quotes_block(
|
||||
quotes: list[DanswerQuote],
|
||||
) -> list[Block]:
|
||||
quote_lines: list[str] = []
|
||||
doc_to_quotes: dict[str, list[str]] = {}
|
||||
doc_to_link: dict[str, str] = {}
|
||||
doc_to_sem_id: dict[str, str] = {}
|
||||
for q in quotes:
|
||||
quote = q.quote
|
||||
doc_id = q.document_id
|
||||
doc_link = q.link
|
||||
doc_name = q.semantic_identifier
|
||||
if doc_link and doc_name and doc_id and quote:
|
||||
if doc_id not in doc_to_quotes:
|
||||
doc_to_quotes[doc_id] = [quote]
|
||||
doc_to_link[doc_id] = doc_link
|
||||
doc_to_sem_id[doc_id] = (
|
||||
doc_name
|
||||
if q.source_type != DocumentSource.SLACK.value
|
||||
else "#" + doc_name
|
||||
)
|
||||
else:
|
||||
doc_to_quotes[doc_id].append(quote)
|
||||
|
||||
for doc_id, quote_strs in doc_to_quotes.items():
|
||||
quotes_str_clean = [
|
||||
replace_whitespaces_w_space(q_str).strip() for q_str in quote_strs
|
||||
]
|
||||
longest_quotes = sorted(quotes_str_clean, key=len, reverse=True)[:5]
|
||||
single_quote_str = "\n".join([f"```{q_str}```" for q_str in longest_quotes])
|
||||
link = doc_to_link[doc_id]
|
||||
sem_id = doc_to_sem_id[doc_id]
|
||||
quote_lines.append(
|
||||
f"<{link}|{sem_id}>:\n{remove_slack_text_interactions(single_quote_str)}"
|
||||
)
|
||||
|
||||
if not doc_to_quotes:
|
||||
return []
|
||||
|
||||
return [SectionBlock(text="*Relevant Snippets*\n" + "\n".join(quote_lines))]
|
||||
|
||||
|
||||
def build_qa_response_blocks(
|
||||
query_event_id: int,
|
||||
answer: str | None,
|
||||
quotes: list[DanswerQuote] | None,
|
||||
) -> list[Block]:
|
||||
quotes_blocks: list[Block] = []
|
||||
|
||||
ai_answer_header = HeaderBlock(text="AI Answer")
|
||||
|
||||
if not answer:
|
||||
answer_block = SectionBlock(
|
||||
text="Sorry, I was unable to find an answer, but I did find some potentially relevant docs 🤓"
|
||||
)
|
||||
else:
|
||||
answer_block = SectionBlock(text=remove_slack_text_interactions(answer))
|
||||
if quotes:
|
||||
quotes_blocks = build_quotes_block(quotes)
|
||||
|
||||
# if no quotes OR `build_quotes_block()` did not give back any blocks
|
||||
if not quotes_blocks:
|
||||
quotes_blocks = [
|
||||
SectionBlock(
|
||||
text="*Warning*: no sources were quoted for this answer, so it may be unreliable 😔"
|
||||
)
|
||||
]
|
||||
|
||||
feedback_block = build_qa_feedback_block(query_event_id=query_event_id)
|
||||
return (
|
||||
[
|
||||
ai_answer_header,
|
||||
answer_block,
|
||||
feedback_block,
|
||||
]
|
||||
+ quotes_blocks
|
||||
+ [DividerBlock()]
|
||||
)
|
||||
@@ -1,2 +0,0 @@
|
||||
LIKE_BLOCK_ACTION_ID = "feedback-like"
|
||||
DISLIKE_BLOCK_ACTION_ID = "feedback-dislike"
|
||||
@@ -1,61 +0,0 @@
|
||||
from slack_sdk import WebClient
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.bots.slack.constants import DISLIKE_BLOCK_ACTION_ID
|
||||
from danswer.bots.slack.constants import LIKE_BLOCK_ACTION_ID
|
||||
from danswer.bots.slack.utils import decompose_block_id
|
||||
from danswer.configs.constants import QAFeedbackType
|
||||
from danswer.configs.constants import SearchFeedbackType
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.feedback import create_doc_retrieval_feedback
|
||||
from danswer.db.feedback import update_query_event_feedback
|
||||
|
||||
|
||||
def handle_slack_feedback(
|
||||
block_id: str,
|
||||
feedback_type: str,
|
||||
client: WebClient,
|
||||
user_id_to_post_confirmation: str,
|
||||
channel_id_to_post_confirmation: str,
|
||||
thread_ts_to_post_confirmation: str,
|
||||
) -> None:
|
||||
engine = get_sqlalchemy_engine()
|
||||
|
||||
query_id, doc_id, doc_rank = decompose_block_id(block_id)
|
||||
|
||||
with Session(engine) as db_session:
|
||||
if feedback_type in [LIKE_BLOCK_ACTION_ID, DISLIKE_BLOCK_ACTION_ID]:
|
||||
update_query_event_feedback(
|
||||
feedback=QAFeedbackType.LIKE
|
||||
if feedback_type == LIKE_BLOCK_ACTION_ID
|
||||
else QAFeedbackType.DISLIKE,
|
||||
query_id=query_id,
|
||||
user_id=None, # no "user" for Slack bot for now
|
||||
db_session=db_session,
|
||||
)
|
||||
if feedback_type in [
|
||||
SearchFeedbackType.ENDORSE.value,
|
||||
SearchFeedbackType.REJECT.value,
|
||||
]:
|
||||
if doc_id is None or doc_rank is None:
|
||||
raise ValueError("Missing information for Document Feedback")
|
||||
|
||||
create_doc_retrieval_feedback(
|
||||
qa_event_id=query_id,
|
||||
document_id=doc_id,
|
||||
document_rank=doc_rank,
|
||||
user_id=None,
|
||||
db_session=db_session,
|
||||
clicked=False, # Not tracking this for Slack
|
||||
feedback=SearchFeedbackType.ENDORSE
|
||||
if feedback_type == SearchFeedbackType.ENDORSE.value
|
||||
else SearchFeedbackType.REJECT,
|
||||
)
|
||||
|
||||
# post message to slack confirming that feedback was received
|
||||
client.chat_postEphemeral(
|
||||
channel=channel_id_to_post_confirmation,
|
||||
user=user_id_to_post_confirmation,
|
||||
thread_ts=thread_ts_to_post_confirmation,
|
||||
text="Thanks for your feedback!",
|
||||
)
|
||||
@@ -1,161 +0,0 @@
|
||||
import logging
|
||||
|
||||
from retry import retry
|
||||
from slack_sdk import WebClient
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.bots.slack.blocks import build_documents_blocks
|
||||
from danswer.bots.slack.blocks import build_qa_response_blocks
|
||||
from danswer.bots.slack.config import get_slack_bot_config_for_channel
|
||||
from danswer.bots.slack.utils import get_channel_name_from_id
|
||||
from danswer.bots.slack.utils import respond_in_thread
|
||||
from danswer.configs.app_configs import DANSWER_BOT_ANSWER_GENERATION_TIMEOUT
|
||||
from danswer.configs.app_configs import DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER
|
||||
from danswer.configs.app_configs import DANSWER_BOT_DISPLAY_ERROR_MSGS
|
||||
from danswer.configs.app_configs import DANSWER_BOT_NUM_RETRIES
|
||||
from danswer.configs.app_configs import DOCUMENT_INDEX_NAME
|
||||
from danswer.configs.app_configs import ENABLE_DANSWERBOT_REFLEXION
|
||||
from danswer.configs.constants import DOCUMENT_SETS
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.direct_qa.answer_question import answer_qa_query
|
||||
from danswer.server.models import QAResponse
|
||||
from danswer.server.models import QuestionRequest
|
||||
|
||||
|
||||
def handle_message(
|
||||
msg: str,
|
||||
channel: str,
|
||||
message_ts_to_respond_to: str,
|
||||
client: WebClient,
|
||||
logger: logging.Logger,
|
||||
num_retries: int = DANSWER_BOT_NUM_RETRIES,
|
||||
answer_generation_timeout: int = DANSWER_BOT_ANSWER_GENERATION_TIMEOUT,
|
||||
should_respond_with_error_msgs: bool = DANSWER_BOT_DISPLAY_ERROR_MSGS,
|
||||
disable_docs_only_answer: bool = DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER,
|
||||
) -> None:
|
||||
engine = get_sqlalchemy_engine()
|
||||
with Session(engine) as db_session:
|
||||
channel_name = get_channel_name_from_id(client=client, channel_id=channel)
|
||||
slack_bot_config = get_slack_bot_config_for_channel(
|
||||
channel_name=channel_name, db_session=db_session
|
||||
)
|
||||
document_set_names: list[str] | None = None
|
||||
validity_check_enabled = ENABLE_DANSWERBOT_REFLEXION
|
||||
if slack_bot_config and slack_bot_config.persona:
|
||||
document_set_names = [
|
||||
document_set.name
|
||||
for document_set in slack_bot_config.persona.document_sets
|
||||
]
|
||||
validity_check_enabled = slack_bot_config.channel_config.get(
|
||||
"answer_validity_check_enabled", validity_check_enabled
|
||||
)
|
||||
logger.info(
|
||||
"Found slack bot config for channel. Restricting bot to use document "
|
||||
f"sets: {document_set_names}, validity check enabled: {validity_check_enabled}"
|
||||
)
|
||||
|
||||
@retry(
|
||||
tries=num_retries,
|
||||
delay=0.25,
|
||||
backoff=2,
|
||||
logger=logger,
|
||||
)
|
||||
def _get_answer(question: QuestionRequest) -> QAResponse:
|
||||
engine = get_sqlalchemy_engine()
|
||||
with Session(engine, expire_on_commit=False) as db_session:
|
||||
# This also handles creating the query event in postgres
|
||||
answer = answer_qa_query(
|
||||
question=question,
|
||||
user=None,
|
||||
db_session=db_session,
|
||||
answer_generation_timeout=answer_generation_timeout,
|
||||
real_time_flow=False,
|
||||
enable_reflexion=validity_check_enabled,
|
||||
)
|
||||
if not answer.error_msg:
|
||||
return answer
|
||||
else:
|
||||
raise RuntimeError(answer.error_msg)
|
||||
|
||||
try:
|
||||
answer = _get_answer(
|
||||
QuestionRequest(
|
||||
query=msg,
|
||||
collection=DOCUMENT_INDEX_NAME,
|
||||
use_keyword=False, # always use semantic search when handling Slack messages
|
||||
filters=[{DOCUMENT_SETS: document_set_names}]
|
||||
if document_set_names
|
||||
else None,
|
||||
offset=None,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Unable to process message - did not successfully answer "
|
||||
f"in {num_retries} attempts"
|
||||
)
|
||||
# Optionally, respond in thread with the error message, Used primarily
|
||||
# for debugging purposes
|
||||
if should_respond_with_error_msgs:
|
||||
respond_in_thread(
|
||||
client=client,
|
||||
channel=channel,
|
||||
text=f"Encountered exception when trying to answer: \n\n```{e}```",
|
||||
thread_ts=message_ts_to_respond_to,
|
||||
)
|
||||
return
|
||||
|
||||
if answer.eval_res_valid is False:
|
||||
logger.info(
|
||||
"Answer was evaluated to be invalid, throwing it away without responding."
|
||||
)
|
||||
if answer.answer:
|
||||
logger.debug(answer.answer)
|
||||
return
|
||||
|
||||
if not answer.top_ranked_docs:
|
||||
logger.error(f"Unable to answer question: '{msg}' - no documents found")
|
||||
# Optionally, respond in thread with the error message, Used primarily
|
||||
# for debugging purposes
|
||||
if should_respond_with_error_msgs:
|
||||
respond_in_thread(
|
||||
client=client,
|
||||
channel=channel,
|
||||
text="Found no documents when trying to answer. Did you index any documents?",
|
||||
thread_ts=message_ts_to_respond_to,
|
||||
)
|
||||
return
|
||||
|
||||
if not answer.answer and disable_docs_only_answer:
|
||||
logger.info(
|
||||
"Unable to find answer - not responding since the "
|
||||
"`DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER` env variable is set"
|
||||
)
|
||||
return
|
||||
|
||||
# convert raw response into "nicely" formatted Slack message
|
||||
answer_blocks = build_qa_response_blocks(
|
||||
query_event_id=answer.query_event_id,
|
||||
answer=answer.answer,
|
||||
quotes=answer.quotes,
|
||||
)
|
||||
|
||||
document_blocks = build_documents_blocks(
|
||||
documents=answer.top_ranked_docs, query_event_id=answer.query_event_id
|
||||
)
|
||||
|
||||
try:
|
||||
respond_in_thread(
|
||||
client=client,
|
||||
channel=channel,
|
||||
blocks=answer_blocks + document_blocks,
|
||||
thread_ts=message_ts_to_respond_to,
|
||||
# don't unfurl, since otherwise we will have 5+ previews which makes the message very long
|
||||
unfurl=False,
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Unable to process message - could not respond in slack in {num_retries} attempts"
|
||||
)
|
||||
return
|
||||
@@ -1,179 +0,0 @@
|
||||
import logging
|
||||
from collections.abc import MutableMapping
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from slack_sdk import WebClient
|
||||
from slack_sdk.socket_mode import SocketModeClient
|
||||
from slack_sdk.socket_mode.request import SocketModeRequest
|
||||
from slack_sdk.socket_mode.response import SocketModeResponse
|
||||
|
||||
from danswer.bots.slack.handlers.handle_feedback import handle_slack_feedback
|
||||
from danswer.bots.slack.handlers.handle_message import handle_message
|
||||
from danswer.bots.slack.utils import decompose_block_id
|
||||
from danswer.dynamic_configs.interface import ConfigNotFoundError
|
||||
from danswer.server.slack_bot_management import get_tokens
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
_CHANNEL_ID = "channel_id"
|
||||
|
||||
|
||||
class _ChannelIdAdapter(logging.LoggerAdapter):
|
||||
"""This is used to add the channel ID to all log messages
|
||||
emitted in this file"""
|
||||
|
||||
def process(
|
||||
self, msg: str, kwargs: MutableMapping[str, Any]
|
||||
) -> tuple[str, MutableMapping[str, Any]]:
|
||||
channel_id = self.extra.get(_CHANNEL_ID) if self.extra else None
|
||||
if channel_id:
|
||||
return f"[Channel ID: {channel_id}] {msg}", kwargs
|
||||
else:
|
||||
return msg, kwargs
|
||||
|
||||
|
||||
def _get_socket_client() -> SocketModeClient:
|
||||
# For more info on how to set this up, checkout the docs:
|
||||
# https://docs.danswer.dev/slack_bot_setup
|
||||
try:
|
||||
slack_bot_tokens = get_tokens()
|
||||
except ConfigNotFoundError:
|
||||
raise RuntimeError("Slack tokens not found")
|
||||
return SocketModeClient(
|
||||
# This app-level token will be used only for establishing a connection
|
||||
app_token=slack_bot_tokens.app_token,
|
||||
web_client=WebClient(token=slack_bot_tokens.bot_token),
|
||||
)
|
||||
|
||||
|
||||
def _process_slack_event(client: SocketModeClient, req: SocketModeRequest) -> None:
|
||||
logger.info(f"Received Slack request of type: '{req.type}'")
|
||||
if req.type == "events_api":
|
||||
# Acknowledge the request immediately
|
||||
response = SocketModeResponse(envelope_id=req.envelope_id)
|
||||
client.send_socket_mode_response(response)
|
||||
|
||||
event = cast(dict[str, Any], req.payload.get("event", {}))
|
||||
channel = cast(str | None, event.get("channel"))
|
||||
channel_specific_logger = _ChannelIdAdapter(
|
||||
logger, extra={_CHANNEL_ID: channel}
|
||||
)
|
||||
|
||||
# Ensure that the message is a new message + of expected type
|
||||
event_type = event.get("type")
|
||||
if event_type != "message":
|
||||
channel_specific_logger.info(
|
||||
f"Ignoring non-message event of type '{event_type}' for channel '{channel}'"
|
||||
)
|
||||
|
||||
# this should never happen, but we can't continue without a channel since
|
||||
# we can't send a response without it
|
||||
if not channel:
|
||||
channel_specific_logger.error("Found message without channel - skipping")
|
||||
return
|
||||
|
||||
message_subtype = event.get("subtype")
|
||||
# ignore things like channel_join, channel_leave, etc.
|
||||
# NOTE: "file_share" is just a message with a file attachment, so we
|
||||
# should not ignore it
|
||||
if message_subtype not in [None, "file_share"]:
|
||||
channel_specific_logger.info(
|
||||
f"Ignoring message with subtype '{message_subtype}' since is is a special message type"
|
||||
)
|
||||
return
|
||||
|
||||
if event.get("bot_profile"):
|
||||
channel_specific_logger.info("Ignoring message from bot")
|
||||
return
|
||||
|
||||
message_ts = event.get("ts")
|
||||
thread_ts = event.get("thread_ts")
|
||||
# Pick the root of the thread (if a thread exists)
|
||||
message_ts_to_respond_to = cast(str, thread_ts or message_ts)
|
||||
if thread_ts and message_ts != thread_ts:
|
||||
channel_specific_logger.info(
|
||||
"Skipping message since it is not the root of a thread"
|
||||
)
|
||||
return
|
||||
|
||||
msg = cast(str | None, event.get("text"))
|
||||
if not msg:
|
||||
channel_specific_logger.error("Unable to process empty message")
|
||||
return
|
||||
|
||||
# TODO: message should be enqueued and processed elsewhere,
|
||||
# but doing it here for now for simplicity
|
||||
handle_message(
|
||||
msg=msg,
|
||||
channel=channel,
|
||||
message_ts_to_respond_to=message_ts_to_respond_to,
|
||||
client=client.web_client,
|
||||
logger=cast(logging.Logger, channel_specific_logger),
|
||||
)
|
||||
|
||||
channel_specific_logger.info(
|
||||
f"Successfully processed message with ts: '{message_ts}'"
|
||||
)
|
||||
|
||||
# Handle button clicks
|
||||
if req.type == "interactive" and req.payload.get("type") == "block_actions":
|
||||
# Acknowledge the request immediately
|
||||
response = SocketModeResponse(envelope_id=req.envelope_id)
|
||||
client.send_socket_mode_response(response)
|
||||
|
||||
actions = req.payload.get("actions")
|
||||
if not actions:
|
||||
logger.error("Unable to process block actions - no actions found")
|
||||
return
|
||||
|
||||
action = cast(dict[str, Any], actions[0])
|
||||
action_id = cast(str, action.get("action_id"))
|
||||
block_id = cast(str, action.get("block_id"))
|
||||
user_id = cast(str, req.payload["user"]["id"])
|
||||
channel_id = cast(str, req.payload["container"]["channel_id"])
|
||||
thread_ts = cast(str, req.payload["container"]["thread_ts"])
|
||||
|
||||
handle_slack_feedback(
|
||||
block_id=block_id,
|
||||
feedback_type=action_id,
|
||||
client=client.web_client,
|
||||
user_id_to_post_confirmation=user_id,
|
||||
channel_id_to_post_confirmation=channel_id,
|
||||
thread_ts_to_post_confirmation=thread_ts,
|
||||
)
|
||||
|
||||
query_event_id, _, _ = decompose_block_id(block_id)
|
||||
logger.info(f"Successfully handled QA feedback for event: {query_event_id}")
|
||||
|
||||
|
||||
def process_slack_event(client: SocketModeClient, req: SocketModeRequest) -> None:
|
||||
try:
|
||||
_process_slack_event(client=client, req=req)
|
||||
except Exception:
|
||||
logger.exception("Failed to process slack event")
|
||||
|
||||
|
||||
# Follow the guide (https://docs.danswer.dev/slack_bot_setup) to set up
|
||||
# the slack bot in your workspace, and then add the bot to any channels you want to
|
||||
# try and answer questions for. Running this file will setup Danswer to listen to all
|
||||
# messages in those channels and attempt to answer them. As of now, it will only respond
|
||||
# to messages sent directly in the channel - it will not respond to messages sent within a
|
||||
# thread.
|
||||
#
|
||||
# NOTE: we are using Web Sockets so that you can run this from within a firewalled VPC
|
||||
# without issue.
|
||||
if __name__ == "__main__":
|
||||
socket_client = _get_socket_client()
|
||||
socket_client.socket_mode_request_listeners.append(process_slack_event) # type: ignore
|
||||
|
||||
# Establish a WebSocket connection to the Socket Mode servers
|
||||
logger.info("Listening for messages from Slack...")
|
||||
socket_client.connect()
|
||||
|
||||
# Just not to stop this process
|
||||
from threading import Event
|
||||
|
||||
Event().wait()
|
||||
@@ -1,138 +0,0 @@
|
||||
import logging
|
||||
import random
|
||||
import re
|
||||
import string
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from retry import retry
|
||||
from slack_sdk import WebClient
|
||||
from slack_sdk.models.blocks import Block
|
||||
from slack_sdk.models.metadata import Metadata
|
||||
|
||||
from danswer.bots.slack.tokens import fetch_tokens
|
||||
from danswer.configs.app_configs import DANSWER_BOT_NUM_RETRIES
|
||||
from danswer.configs.constants import ID_SEPARATOR
|
||||
from danswer.connectors.slack.utils import make_slack_api_rate_limited
|
||||
from danswer.connectors.slack.utils import UserIdReplacer
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.text_processing import replace_whitespaces_w_space
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def get_web_client() -> WebClient:
|
||||
slack_tokens = fetch_tokens()
|
||||
return WebClient(token=slack_tokens.bot_token)
|
||||
|
||||
|
||||
@retry(
|
||||
tries=DANSWER_BOT_NUM_RETRIES,
|
||||
delay=0.25,
|
||||
backoff=2,
|
||||
logger=cast(logging.Logger, logger),
|
||||
)
|
||||
def respond_in_thread(
|
||||
client: WebClient,
|
||||
channel: str,
|
||||
thread_ts: str,
|
||||
text: str | None = None,
|
||||
blocks: list[Block] | None = None,
|
||||
metadata: Metadata | None = None,
|
||||
unfurl: bool = True,
|
||||
) -> None:
|
||||
if not text and not blocks:
|
||||
raise ValueError("One of `text` or `blocks` must be provided")
|
||||
|
||||
slack_call = make_slack_api_rate_limited(client.chat_postMessage)
|
||||
response = slack_call(
|
||||
channel=channel,
|
||||
text=text,
|
||||
blocks=blocks,
|
||||
thread_ts=thread_ts,
|
||||
metadata=metadata,
|
||||
unfurl_links=unfurl,
|
||||
unfurl_media=unfurl,
|
||||
)
|
||||
if not response.get("ok"):
|
||||
raise RuntimeError(f"Unable to post message: {response}")
|
||||
|
||||
|
||||
def build_feedback_block_id(
|
||||
query_event_id: int,
|
||||
document_id: str | None = None,
|
||||
document_rank: int | None = None,
|
||||
) -> str:
|
||||
unique_prefix = "".join(random.choice(string.ascii_letters) for _ in range(10))
|
||||
if document_id is not None:
|
||||
if not document_id or document_rank is None:
|
||||
raise ValueError("Invalid document, missing information")
|
||||
if ID_SEPARATOR in document_id:
|
||||
raise ValueError(
|
||||
"Separator pattern should not already exist in document id"
|
||||
)
|
||||
block_id = ID_SEPARATOR.join(
|
||||
[str(query_event_id), document_id, str(document_rank)]
|
||||
)
|
||||
else:
|
||||
block_id = str(query_event_id)
|
||||
|
||||
return unique_prefix + ID_SEPARATOR + block_id
|
||||
|
||||
|
||||
def decompose_block_id(block_id: str) -> tuple[int, str | None, int | None]:
|
||||
"""Decompose into query_id, document_id, document_rank, see above function"""
|
||||
try:
|
||||
components = block_id.split(ID_SEPARATOR)
|
||||
if len(components) != 2 and len(components) != 4:
|
||||
raise ValueError("Block ID does not contain right number of elements")
|
||||
|
||||
if len(components) == 2:
|
||||
return int(components[-1]), None, None
|
||||
|
||||
return int(components[1]), components[2], int(components[3])
|
||||
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
raise ValueError("Received invalid Feedback Block Identifier")
|
||||
|
||||
|
||||
def translate_vespa_highlight_to_slack(match_strs: list[str], used_chars: int) -> str:
|
||||
def _replace_highlight(s: str) -> str:
|
||||
s = re.sub(r"(?<=[^\s])<hi>(.*?)</hi>", r"\1", s)
|
||||
s = s.replace("</hi>", "*").replace("<hi>", "*")
|
||||
return s
|
||||
|
||||
final_matches = [
|
||||
replace_whitespaces_w_space(_replace_highlight(match_str)).strip()
|
||||
for match_str in match_strs
|
||||
if match_str
|
||||
]
|
||||
combined = "... ".join(final_matches)
|
||||
|
||||
# Slack introduces "Show More" after 300 on desktop which is ugly
|
||||
# But don't trim the message if there is still a highlight after 300 chars
|
||||
remaining = 300 - used_chars
|
||||
if len(combined) > remaining and "*" not in combined[remaining:]:
|
||||
combined = combined[: remaining - 3] + "..."
|
||||
|
||||
return combined
|
||||
|
||||
|
||||
def remove_slack_text_interactions(slack_str: str) -> str:
|
||||
slack_str = UserIdReplacer.replace_tags_basic(slack_str)
|
||||
slack_str = UserIdReplacer.replace_channels_basic(slack_str)
|
||||
slack_str = UserIdReplacer.replace_special_mentions(slack_str)
|
||||
slack_str = UserIdReplacer.replace_links(slack_str)
|
||||
slack_str = UserIdReplacer.add_zero_width_whitespace_after_tag(slack_str)
|
||||
return slack_str
|
||||
|
||||
|
||||
def get_channel_from_id(client: WebClient, channel_id: str) -> dict[str, Any]:
|
||||
response = client.conversations_info(channel=channel_id)
|
||||
response.validate()
|
||||
return response["channel"]
|
||||
|
||||
|
||||
def get_channel_name_from_id(client: WebClient, channel_id: str) -> str:
|
||||
return get_channel_from_id(client, channel_id)["name"]
|
||||
@@ -1,363 +0,0 @@
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from uuid import UUID
|
||||
|
||||
from langchain.schema.messages import AIMessage
|
||||
from langchain.schema.messages import BaseMessage
|
||||
from langchain.schema.messages import HumanMessage
|
||||
from langchain.schema.messages import SystemMessage
|
||||
|
||||
from danswer.chat.chat_prompts import build_combined_query
|
||||
from danswer.chat.chat_prompts import DANSWER_TOOL_NAME
|
||||
from danswer.chat.chat_prompts import form_tool_followup_text
|
||||
from danswer.chat.chat_prompts import form_user_prompt_text
|
||||
from danswer.chat.chat_prompts import format_danswer_chunks_for_chat
|
||||
from danswer.chat.tools import call_tool
|
||||
from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_CHAT
|
||||
from danswer.configs.constants import IGNORE_FOR_QA
|
||||
from danswer.configs.model_configs import GEN_AI_MAX_INPUT_TOKENS
|
||||
from danswer.datastores.document_index import get_default_document_index
|
||||
from danswer.db.models import ChatMessage
|
||||
from danswer.db.models import Persona
|
||||
from danswer.direct_qa.interfaces import DanswerAnswerPiece
|
||||
from danswer.direct_qa.interfaces import DanswerChatModelOut
|
||||
from danswer.direct_qa.qa_utils import get_usable_chunks
|
||||
from danswer.llm.build import get_default_llm
|
||||
from danswer.llm.llm import LLM
|
||||
from danswer.llm.utils import get_default_llm_tokenizer
|
||||
from danswer.llm.utils import translate_danswer_msg_to_langchain
|
||||
from danswer.search.semantic_search import retrieve_ranked_documents
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.text_processing import extract_embedded_json
|
||||
from danswer.utils.text_processing import has_unescaped_quote
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
LLM_CHAT_FAILURE_MSG = "The large-language-model failed to generate a valid response."
|
||||
|
||||
|
||||
def _parse_embedded_json_streamed_response(
|
||||
tokens: Iterator[str],
|
||||
) -> Iterator[DanswerAnswerPiece | DanswerChatModelOut]:
|
||||
final_answer = False
|
||||
just_start_stream = False
|
||||
model_output = ""
|
||||
hold = ""
|
||||
finding_end = 0
|
||||
for token in tokens:
|
||||
model_output += token
|
||||
hold += token
|
||||
|
||||
if (
|
||||
final_answer is False
|
||||
and '"action":"finalanswer",' in model_output.lower().replace(" ", "")
|
||||
):
|
||||
final_answer = True
|
||||
|
||||
if final_answer and '"actioninput":"' in model_output.lower().replace(
|
||||
" ", ""
|
||||
).replace("_", ""):
|
||||
if not just_start_stream:
|
||||
just_start_stream = True
|
||||
hold = ""
|
||||
|
||||
if has_unescaped_quote(hold):
|
||||
finding_end += 1
|
||||
hold = hold[: hold.find('"')]
|
||||
|
||||
if finding_end <= 1:
|
||||
if finding_end == 1:
|
||||
finding_end += 1
|
||||
|
||||
yield DanswerAnswerPiece(answer_piece=hold)
|
||||
hold = ""
|
||||
|
||||
logger.debug(model_output)
|
||||
|
||||
model_final = extract_embedded_json(model_output)
|
||||
if "action" not in model_final or "action_input" not in model_final:
|
||||
raise ValueError("Model did not provide all required action values")
|
||||
|
||||
yield DanswerChatModelOut(
|
||||
model_raw=model_output,
|
||||
action=model_final["action"],
|
||||
action_input=model_final["action_input"],
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
def _find_last_index(
|
||||
lst: list[int], max_prompt_tokens: int = GEN_AI_MAX_INPUT_TOKENS
|
||||
) -> int:
|
||||
"""From the back, find the index of the last element to include
|
||||
before the list exceeds the maximum"""
|
||||
running_sum = 0
|
||||
|
||||
last_ind = 0
|
||||
for i in range(len(lst) - 1, -1, -1):
|
||||
running_sum += lst[i]
|
||||
if running_sum > max_prompt_tokens:
|
||||
last_ind = i + 1
|
||||
break
|
||||
if last_ind >= len(lst):
|
||||
raise ValueError("Last message alone is too large!")
|
||||
return last_ind
|
||||
|
||||
|
||||
def danswer_chat_retrieval(
|
||||
query_message: ChatMessage,
|
||||
history: list[ChatMessage],
|
||||
llm: LLM,
|
||||
user_id: UUID | None,
|
||||
) -> str:
|
||||
if history:
|
||||
query_combination_msgs = build_combined_query(query_message, history)
|
||||
reworded_query = llm.invoke(query_combination_msgs)
|
||||
else:
|
||||
reworded_query = query_message.message
|
||||
|
||||
# Good Debug/Breakpoint
|
||||
ranked_chunks, unranked_chunks = retrieve_ranked_documents(
|
||||
reworded_query,
|
||||
user_id=user_id,
|
||||
filters=None,
|
||||
datastore=get_default_document_index(),
|
||||
)
|
||||
if not ranked_chunks:
|
||||
return "No results found"
|
||||
|
||||
if unranked_chunks:
|
||||
ranked_chunks.extend(unranked_chunks)
|
||||
|
||||
filtered_ranked_chunks = [
|
||||
chunk for chunk in ranked_chunks if not chunk.metadata.get(IGNORE_FOR_QA)
|
||||
]
|
||||
|
||||
# get all chunks that fit into the token limit
|
||||
usable_chunks = get_usable_chunks(
|
||||
chunks=filtered_ranked_chunks,
|
||||
token_limit=NUM_DOCUMENT_TOKENS_FED_TO_CHAT,
|
||||
)
|
||||
|
||||
return format_danswer_chunks_for_chat(usable_chunks)
|
||||
|
||||
|
||||
def _drop_messages_history_overflow(
|
||||
system_msg: BaseMessage | None,
|
||||
system_token_count: int,
|
||||
history_msgs: list[BaseMessage],
|
||||
history_token_counts: list[int],
|
||||
final_msg: BaseMessage,
|
||||
final_msg_token_count: int,
|
||||
) -> list[BaseMessage]:
|
||||
"""As message history grows, messages need to be dropped starting from the furthest in the past.
|
||||
The System message should be kept if at all possible and the latest user input which is inserted in the
|
||||
prompt template must be included"""
|
||||
|
||||
if len(history_msgs) != len(history_token_counts):
|
||||
# This should never happen
|
||||
raise ValueError("Need exactly 1 token count per message for tracking overflow")
|
||||
|
||||
prompt: list[BaseMessage] = []
|
||||
|
||||
# Start dropping from the history if necessary
|
||||
all_tokens = history_token_counts + [system_token_count, final_msg_token_count]
|
||||
ind_prev_msg_start = _find_last_index(all_tokens)
|
||||
|
||||
if system_msg and ind_prev_msg_start <= len(history_msgs):
|
||||
prompt.append(system_msg)
|
||||
|
||||
prompt.extend(history_msgs[ind_prev_msg_start:])
|
||||
|
||||
prompt.append(final_msg)
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
def llm_contextless_chat_answer(
|
||||
messages: list[ChatMessage],
|
||||
tokenizer: Callable | None = None,
|
||||
system_text: str | None = None,
|
||||
) -> Iterator[str]:
|
||||
try:
|
||||
prompt_msgs = [translate_danswer_msg_to_langchain(msg) for msg in messages]
|
||||
|
||||
if system_text:
|
||||
tokenizer = tokenizer or get_default_llm_tokenizer()
|
||||
system_tokens = len(tokenizer(system_text))
|
||||
system_msg = SystemMessage(content=system_text)
|
||||
|
||||
message_tokens = [msg.token_count for msg in messages] + [system_tokens]
|
||||
else:
|
||||
message_tokens = [msg.token_count for msg in messages]
|
||||
|
||||
last_msg_ind = _find_last_index(message_tokens)
|
||||
|
||||
remaining_user_msgs = prompt_msgs[last_msg_ind:]
|
||||
if not remaining_user_msgs:
|
||||
raise ValueError("Last user message is too long!")
|
||||
|
||||
if system_text:
|
||||
all_msgs = [system_msg] + remaining_user_msgs
|
||||
else:
|
||||
all_msgs = remaining_user_msgs
|
||||
|
||||
return get_default_llm().stream(all_msgs)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"LLM failed to produce valid chat message, error: {e}")
|
||||
return (msg for msg in [LLM_CHAT_FAILURE_MSG]) # needs to be an Iterator
|
||||
|
||||
|
||||
def llm_contextual_chat_answer(
|
||||
messages: list[ChatMessage],
|
||||
persona: Persona,
|
||||
user_id: UUID | None,
|
||||
tokenizer: Callable,
|
||||
) -> Iterator[str]:
|
||||
retrieval_enabled = persona.retrieval_enabled
|
||||
system_text = persona.system_text
|
||||
tool_text = persona.tools_text
|
||||
hint_text = persona.hint_text
|
||||
|
||||
last_message = messages[-1]
|
||||
previous_messages = messages[:-1]
|
||||
previous_msgs_as_basemessage = [
|
||||
translate_danswer_msg_to_langchain(msg) for msg in previous_messages
|
||||
]
|
||||
|
||||
# Failure reasons include:
|
||||
# - Invalid LLM output, wrong format or wrong/missing keys
|
||||
# - No "Final Answer" from model after tool calling
|
||||
# - LLM times out or is otherwise unavailable
|
||||
# - Calling invalid tool or tool call fails
|
||||
# - Last message has more tokens than model is set to accept
|
||||
# - Missing user input
|
||||
try:
|
||||
if not last_message.message:
|
||||
raise ValueError("User chat message is empty.")
|
||||
|
||||
# Build the prompt using the last user message
|
||||
user_text = form_user_prompt_text(
|
||||
query=last_message.message,
|
||||
tool_text=tool_text,
|
||||
hint_text=hint_text,
|
||||
)
|
||||
last_user_msg = HumanMessage(content=user_text)
|
||||
|
||||
# Count tokens once to reuse
|
||||
previous_msg_token_counts = [msg.token_count for msg in previous_messages]
|
||||
system_tokens = len(tokenizer(system_text)) if system_text else 0
|
||||
last_user_msg_tokens = len(tokenizer(user_text))
|
||||
|
||||
prompt = _drop_messages_history_overflow(
|
||||
system_msg=SystemMessage(content=system_text) if system_text else None,
|
||||
system_token_count=system_tokens,
|
||||
history_msgs=previous_msgs_as_basemessage,
|
||||
history_token_counts=previous_msg_token_counts,
|
||||
final_msg=last_user_msg,
|
||||
final_msg_token_count=last_user_msg_tokens,
|
||||
)
|
||||
|
||||
llm = get_default_llm()
|
||||
|
||||
# Good Debug/Breakpoint
|
||||
tokens = llm.stream(prompt)
|
||||
|
||||
final_result: DanswerChatModelOut | None = None
|
||||
final_answer_streamed = False
|
||||
|
||||
for result in _parse_embedded_json_streamed_response(tokens):
|
||||
if isinstance(result, DanswerAnswerPiece) and result.answer_piece:
|
||||
yield result.answer_piece
|
||||
final_answer_streamed = True
|
||||
|
||||
if isinstance(result, DanswerChatModelOut):
|
||||
final_result = result
|
||||
break
|
||||
|
||||
if final_answer_streamed:
|
||||
return
|
||||
|
||||
if final_result is None:
|
||||
raise RuntimeError("Model output finished without final output parsing.")
|
||||
|
||||
if (
|
||||
retrieval_enabled
|
||||
and final_result.action.lower() == DANSWER_TOOL_NAME.lower()
|
||||
):
|
||||
tool_result_str = danswer_chat_retrieval(
|
||||
query_message=last_message,
|
||||
history=previous_messages,
|
||||
llm=llm,
|
||||
user_id=user_id,
|
||||
)
|
||||
else:
|
||||
tool_result_str = call_tool(final_result, user_id=user_id)
|
||||
|
||||
# The AI's tool calling message
|
||||
tool_call_msg_text = final_result.model_raw
|
||||
tool_call_msg_token_count = len(tokenizer(tool_call_msg_text))
|
||||
|
||||
# Create the new message to use the results of the tool call
|
||||
tool_followup_text = form_tool_followup_text(
|
||||
tool_output=tool_result_str,
|
||||
query=last_message.message,
|
||||
hint_text=hint_text,
|
||||
)
|
||||
tool_followup_msg = HumanMessage(content=tool_followup_text)
|
||||
tool_followup_tokens = len(tokenizer(tool_followup_text))
|
||||
|
||||
# Drop previous messages, the drop order goes: previous messages in the history,
|
||||
# the last user prompt and generated intermediate messages from this recent prompt,
|
||||
# the system message, then finally the tool message that was the last thing generated
|
||||
follow_up_prompt = _drop_messages_history_overflow(
|
||||
system_msg=SystemMessage(content=system_text) if system_text else None,
|
||||
system_token_count=system_tokens,
|
||||
history_msgs=previous_msgs_as_basemessage
|
||||
+ [last_user_msg, AIMessage(content=tool_call_msg_text)],
|
||||
history_token_counts=previous_msg_token_counts
|
||||
+ [last_user_msg_tokens, tool_call_msg_token_count],
|
||||
final_msg=tool_followup_msg,
|
||||
final_msg_token_count=tool_followup_tokens,
|
||||
)
|
||||
|
||||
# Good Debug/Breakpoint
|
||||
tokens = llm.stream(follow_up_prompt)
|
||||
|
||||
for result in _parse_embedded_json_streamed_response(tokens):
|
||||
if isinstance(result, DanswerAnswerPiece) and result.answer_piece:
|
||||
yield result.answer_piece
|
||||
final_answer_streamed = True
|
||||
|
||||
if final_answer_streamed is False:
|
||||
raise RuntimeError("LLM did not to produce a Final Answer after tool call")
|
||||
except Exception as e:
|
||||
logger.error(f"LLM failed to produce valid chat message, error: {e}")
|
||||
yield LLM_CHAT_FAILURE_MSG
|
||||
|
||||
|
||||
def llm_chat_answer(
|
||||
messages: list[ChatMessage],
|
||||
persona: Persona | None,
|
||||
user_id: UUID | None,
|
||||
tokenizer: Callable,
|
||||
) -> Iterator[str]:
|
||||
# Common error cases to keep in mind:
|
||||
# - User asks question about something long ago, due to context limit, the message is dropped
|
||||
# - Tool use gives wrong/irrelevant results, model gets confused by the noise
|
||||
# - Model is too weak of an LLM, fails to follow instructions
|
||||
# - Bad persona design leads to confusing instructions to the model
|
||||
# - Bad configurations, too small token limit, mismatched tokenizer to LLM, etc.
|
||||
if persona is None:
|
||||
return llm_contextless_chat_answer(messages)
|
||||
|
||||
elif persona.retrieval_enabled is False and persona.tools_text is None:
|
||||
return llm_contextless_chat_answer(
|
||||
messages, tokenizer, system_text=persona.system_text
|
||||
)
|
||||
|
||||
return llm_contextual_chat_answer(
|
||||
messages=messages, persona=persona, user_id=user_id, tokenizer=tokenizer
|
||||
)
|
||||
@@ -1,191 +0,0 @@
|
||||
from langchain.schema.messages import BaseMessage
|
||||
from langchain.schema.messages import HumanMessage
|
||||
from langchain.schema.messages import SystemMessage
|
||||
|
||||
from danswer.chunking.models import InferenceChunk
|
||||
from danswer.configs.constants import CODE_BLOCK_PAT
|
||||
from danswer.db.models import ChatMessage
|
||||
from danswer.llm.utils import translate_danswer_msg_to_langchain
|
||||
|
||||
DANSWER_TOOL_NAME = "Current Search"
|
||||
DANSWER_TOOL_DESCRIPTION = (
|
||||
"A search tool that can find information on any topic "
|
||||
"including up to date and proprietary knowledge."
|
||||
)
|
||||
|
||||
DANSWER_SYSTEM_MSG = (
|
||||
"Given a conversation (between Human and Assistant) and a final message from Human, "
|
||||
"rewrite the last message to be a standalone question that captures required/relevant context from the previous "
|
||||
"conversation messages."
|
||||
)
|
||||
|
||||
TOOL_TEMPLATE = """
|
||||
TOOLS
|
||||
------
|
||||
You can use tools to look up information that may be helpful in answering the user's \
|
||||
original question. The available tools are:
|
||||
|
||||
{tool_overviews}
|
||||
|
||||
RESPONSE FORMAT INSTRUCTIONS
|
||||
----------------------------
|
||||
When responding to me, please output a response in one of two formats:
|
||||
|
||||
**Option 1:**
|
||||
Use this if you want to use a tool. Markdown code snippet formatted in the following schema:
|
||||
|
||||
```json
|
||||
{{
|
||||
"action": string, \\ The action to take. Must be one of {tool_names}
|
||||
"action_input": string \\ The input to the action
|
||||
}}
|
||||
```
|
||||
|
||||
**Option #2:**
|
||||
Use this if you want to respond directly to the user. Markdown code snippet formatted in the following schema:
|
||||
|
||||
```json
|
||||
{{
|
||||
"action": "Final Answer",
|
||||
"action_input": string \\ You should put what you want to return to use here
|
||||
}}
|
||||
```
|
||||
"""
|
||||
|
||||
TOOL_LESS_PROMPT = """
|
||||
Respond with a markdown code snippet in the following schema:
|
||||
|
||||
```json
|
||||
{{
|
||||
"action": "Final Answer",
|
||||
"action_input": string \\ You should put what you want to return to use here
|
||||
}}
|
||||
```
|
||||
"""
|
||||
|
||||
USER_INPUT = """
|
||||
USER'S INPUT
|
||||
--------------------
|
||||
Here is the user's input \
|
||||
(remember to respond with a markdown code snippet of a json blob with a single action, and NOTHING else):
|
||||
|
||||
{user_input}
|
||||
"""
|
||||
|
||||
TOOL_FOLLOWUP = """
|
||||
TOOL RESPONSE:
|
||||
---------------------
|
||||
{tool_output}
|
||||
|
||||
USER'S INPUT
|
||||
--------------------
|
||||
Okay, so what is the response to my last comment? If using information obtained from the tools you must \
|
||||
mention it explicitly without mentioning the tool names - I have forgotten all TOOL RESPONSES!
|
||||
If the tool response is not useful, ignore it completely.
|
||||
{optional_reminder}{hint}
|
||||
IMPORTANT! You MUST respond with a markdown code snippet of a json blob with a single action, and NOTHING else.
|
||||
"""
|
||||
|
||||
|
||||
def form_user_prompt_text(
|
||||
query: str,
|
||||
tool_text: str | None,
|
||||
hint_text: str | None,
|
||||
user_input_prompt: str = USER_INPUT,
|
||||
tool_less_prompt: str = TOOL_LESS_PROMPT,
|
||||
) -> str:
|
||||
user_prompt = tool_text or tool_less_prompt
|
||||
|
||||
user_prompt += user_input_prompt.format(user_input=query)
|
||||
|
||||
if hint_text:
|
||||
if user_prompt[-1] != "\n":
|
||||
user_prompt += "\n"
|
||||
user_prompt += "\nHint: " + hint_text
|
||||
|
||||
return user_prompt.strip()
|
||||
|
||||
|
||||
def form_tool_section_text(
|
||||
tools: list[dict[str, str]], retrieval_enabled: bool, template: str = TOOL_TEMPLATE
|
||||
) -> str | None:
|
||||
if not tools and not retrieval_enabled:
|
||||
return None
|
||||
|
||||
if retrieval_enabled:
|
||||
tools.append(
|
||||
{"name": DANSWER_TOOL_NAME, "description": DANSWER_TOOL_DESCRIPTION}
|
||||
)
|
||||
|
||||
tools_intro = []
|
||||
for tool in tools:
|
||||
description_formatted = tool["description"].replace("\n", " ")
|
||||
tools_intro.append(f"> {tool['name']}: {description_formatted}")
|
||||
|
||||
tools_intro_text = "\n".join(tools_intro)
|
||||
tool_names_text = ", ".join([tool["name"] for tool in tools])
|
||||
|
||||
return template.format(
|
||||
tool_overviews=tools_intro_text, tool_names=tool_names_text
|
||||
).strip()
|
||||
|
||||
|
||||
def format_danswer_chunks_for_chat(chunks: list[InferenceChunk]) -> str:
|
||||
return "\n".join(
|
||||
f"DOCUMENT {ind}:{CODE_BLOCK_PAT.format(chunk.content)}"
|
||||
for ind, chunk in enumerate(chunks, start=1)
|
||||
)
|
||||
|
||||
|
||||
def form_tool_followup_text(
|
||||
tool_output: str,
|
||||
query: str,
|
||||
hint_text: str | None,
|
||||
tool_followup_prompt: str = TOOL_FOLLOWUP,
|
||||
ignore_hint: bool = False,
|
||||
) -> str:
|
||||
# If multi-line query, it likely confuses the model more than helps
|
||||
if "\n" not in query:
|
||||
optional_reminder = f"\nAs a reminder, my query was: {query}\n"
|
||||
else:
|
||||
optional_reminder = ""
|
||||
|
||||
if not ignore_hint and hint_text:
|
||||
hint_text_spaced = f"\nHint: {hint_text}\n"
|
||||
else:
|
||||
hint_text_spaced = ""
|
||||
|
||||
return tool_followup_prompt.format(
|
||||
tool_output=tool_output,
|
||||
optional_reminder=optional_reminder,
|
||||
hint=hint_text_spaced,
|
||||
).strip()
|
||||
|
||||
|
||||
def build_combined_query(
|
||||
query_message: ChatMessage,
|
||||
history: list[ChatMessage],
|
||||
) -> list[BaseMessage]:
|
||||
user_query = query_message.message
|
||||
combined_query_msgs: list[BaseMessage] = []
|
||||
|
||||
if not user_query:
|
||||
raise ValueError("Can't rephrase/search an empty query")
|
||||
|
||||
combined_query_msgs.append(SystemMessage(content=DANSWER_SYSTEM_MSG))
|
||||
|
||||
combined_query_msgs.extend(
|
||||
[translate_danswer_msg_to_langchain(msg) for msg in history]
|
||||
)
|
||||
|
||||
combined_query_msgs.append(
|
||||
HumanMessage(
|
||||
content=(
|
||||
"Help me rewrite this final query into a standalone question that takes into consideration the "
|
||||
f"past messages of the conversation. You must ONLY return the rewritten query and nothing else."
|
||||
f"\n\nQuery:\n{query_message.message}"
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
return combined_query_msgs
|
||||
611
backend/danswer/chat/chat_utils.py
Normal file
611
backend/danswer/chat/chat_utils.py
Normal file
@@ -0,0 +1,611 @@
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from functools import lru_cache
|
||||
from typing import cast
|
||||
|
||||
from langchain.schema.messages import BaseMessage
|
||||
from langchain.schema.messages import HumanMessage
|
||||
from langchain.schema.messages import SystemMessage
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import LlmDoc
|
||||
from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
|
||||
from danswer.configs.chat_configs import STOP_STREAM_PAT
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import IGNORE_FOR_QA
|
||||
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
|
||||
from danswer.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
|
||||
from danswer.db.chat import get_chat_messages_by_session
|
||||
from danswer.db.models import ChatMessage
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.models import Prompt
|
||||
from danswer.indexing.models import InferenceChunk
|
||||
from danswer.llm.utils import check_number_of_tokens
|
||||
from danswer.llm.utils import get_max_input_tokens
|
||||
from danswer.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
|
||||
from danswer.prompts.chat_prompts import CHAT_USER_PROMPT
|
||||
from danswer.prompts.chat_prompts import CITATION_REMINDER
|
||||
from danswer.prompts.chat_prompts import DEFAULT_IGNORE_STATEMENT
|
||||
from danswer.prompts.chat_prompts import NO_CITATION_STATEMENT
|
||||
from danswer.prompts.chat_prompts import REQUIRE_CITATION_STATEMENT
|
||||
from danswer.prompts.constants import CODE_BLOCK_PAT
|
||||
from danswer.prompts.constants import TRIPLE_BACKTICK
|
||||
from danswer.prompts.direct_qa_prompts import LANGUAGE_HINT
|
||||
from danswer.prompts.prompt_utils import get_current_llm_day_time
|
||||
from danswer.prompts.token_counts import (
|
||||
CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT,
|
||||
)
|
||||
from danswer.prompts.token_counts import CITATION_REMINDER_TOKEN_CNT
|
||||
from danswer.prompts.token_counts import CITATION_STATEMENT_TOKEN_CNT
|
||||
from danswer.prompts.token_counts import LANGUAGE_HINT_TOKEN_CNT
|
||||
|
||||
# Maps connector enum string to a more natural language representation for the LLM
|
||||
# If not on the list, uses the original but slightly cleaned up, see below
|
||||
CONNECTOR_NAME_MAP = {
|
||||
"web": "Website",
|
||||
"requesttracker": "Request Tracker",
|
||||
"github": "GitHub",
|
||||
"file": "File Upload",
|
||||
}
|
||||
|
||||
|
||||
def clean_up_source(source_str: str) -> str:
|
||||
if source_str in CONNECTOR_NAME_MAP:
|
||||
return CONNECTOR_NAME_MAP[source_str]
|
||||
return source_str.replace("_", " ").title()
|
||||
|
||||
|
||||
def build_doc_context_str(
|
||||
semantic_identifier: str,
|
||||
source_type: DocumentSource,
|
||||
content: str,
|
||||
metadata_dict: dict[str, str | list[str]],
|
||||
updated_at: datetime | None,
|
||||
ind: int,
|
||||
include_metadata: bool = True,
|
||||
) -> str:
|
||||
context_str = ""
|
||||
if include_metadata:
|
||||
context_str += f"DOCUMENT {ind}: {semantic_identifier}\n"
|
||||
context_str += f"Source: {clean_up_source(source_type)}\n"
|
||||
|
||||
for k, v in metadata_dict.items():
|
||||
if isinstance(v, list):
|
||||
v_str = ", ".join(v)
|
||||
context_str += f"{k.capitalize()}: {v_str}\n"
|
||||
else:
|
||||
context_str += f"{k.capitalize()}: {v}\n"
|
||||
|
||||
if updated_at:
|
||||
update_str = updated_at.strftime("%B %d, %Y %H:%M")
|
||||
context_str += f"Updated: {update_str}\n"
|
||||
context_str += f"{CODE_BLOCK_PAT.format(content.strip())}\n\n\n"
|
||||
return context_str
|
||||
|
||||
|
||||
def build_complete_context_str(
|
||||
context_docs: list[LlmDoc | InferenceChunk],
|
||||
include_metadata: bool = True,
|
||||
) -> str:
|
||||
context_str = ""
|
||||
for ind, doc in enumerate(context_docs, start=1):
|
||||
context_str += build_doc_context_str(
|
||||
semantic_identifier=doc.semantic_identifier,
|
||||
source_type=doc.source_type,
|
||||
content=doc.content,
|
||||
metadata_dict=doc.metadata,
|
||||
updated_at=doc.updated_at,
|
||||
ind=ind,
|
||||
include_metadata=include_metadata,
|
||||
)
|
||||
|
||||
return context_str.strip()
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def build_chat_system_message(
|
||||
prompt: Prompt,
|
||||
context_exists: bool,
|
||||
llm_tokenizer_encode_func: Callable,
|
||||
citation_line: str = REQUIRE_CITATION_STATEMENT,
|
||||
no_citation_line: str = NO_CITATION_STATEMENT,
|
||||
) -> tuple[SystemMessage | None, int]:
|
||||
system_prompt = prompt.system_prompt.strip()
|
||||
if prompt.include_citations:
|
||||
if context_exists:
|
||||
system_prompt += citation_line
|
||||
else:
|
||||
system_prompt += no_citation_line
|
||||
if prompt.datetime_aware:
|
||||
if system_prompt:
|
||||
system_prompt += (
|
||||
f"\n\nAdditional Information:\n\t- {get_current_llm_day_time()}."
|
||||
)
|
||||
else:
|
||||
system_prompt = get_current_llm_day_time()
|
||||
|
||||
if not system_prompt:
|
||||
return None, 0
|
||||
|
||||
token_count = len(llm_tokenizer_encode_func(system_prompt))
|
||||
system_msg = SystemMessage(content=system_prompt)
|
||||
|
||||
return system_msg, token_count
|
||||
|
||||
|
||||
def build_task_prompt_reminders(
|
||||
prompt: Prompt,
|
||||
use_language_hint: bool = bool(MULTILINGUAL_QUERY_EXPANSION),
|
||||
citation_str: str = CITATION_REMINDER,
|
||||
language_hint_str: str = LANGUAGE_HINT,
|
||||
) -> str:
|
||||
base_task = prompt.task_prompt
|
||||
citation_or_nothing = citation_str if prompt.include_citations else ""
|
||||
language_hint_or_nothing = language_hint_str.lstrip() if use_language_hint else ""
|
||||
return base_task + citation_or_nothing + language_hint_or_nothing
|
||||
|
||||
|
||||
def llm_doc_from_inference_chunk(inf_chunk: InferenceChunk) -> LlmDoc:
|
||||
return LlmDoc(
|
||||
document_id=inf_chunk.document_id,
|
||||
content=inf_chunk.content,
|
||||
semantic_identifier=inf_chunk.semantic_identifier,
|
||||
source_type=inf_chunk.source_type,
|
||||
metadata=inf_chunk.metadata,
|
||||
updated_at=inf_chunk.updated_at,
|
||||
link=inf_chunk.source_links[0] if inf_chunk.source_links else None,
|
||||
)
|
||||
|
||||
|
||||
def map_document_id_order(
|
||||
chunks: list[InferenceChunk | LlmDoc], one_indexed: bool = True
|
||||
) -> dict[str, int]:
|
||||
order_mapping = {}
|
||||
current = 1 if one_indexed else 0
|
||||
for chunk in chunks:
|
||||
if chunk.document_id not in order_mapping:
|
||||
order_mapping[chunk.document_id] = current
|
||||
current += 1
|
||||
|
||||
return order_mapping
|
||||
|
||||
|
||||
def build_chat_user_message(
|
||||
chat_message: ChatMessage,
|
||||
prompt: Prompt,
|
||||
context_docs: list[LlmDoc],
|
||||
llm_tokenizer_encode_func: Callable,
|
||||
all_doc_useful: bool,
|
||||
user_prompt_template: str = CHAT_USER_PROMPT,
|
||||
context_free_template: str = CHAT_USER_CONTEXT_FREE_PROMPT,
|
||||
ignore_str: str = DEFAULT_IGNORE_STATEMENT,
|
||||
) -> tuple[HumanMessage, int]:
|
||||
user_query = chat_message.message
|
||||
|
||||
if not context_docs:
|
||||
# Simpler prompt for cases where there is no context
|
||||
user_prompt = (
|
||||
context_free_template.format(
|
||||
task_prompt=prompt.task_prompt, user_query=user_query
|
||||
)
|
||||
if prompt.task_prompt
|
||||
else user_query
|
||||
)
|
||||
user_prompt = user_prompt.strip()
|
||||
token_count = len(llm_tokenizer_encode_func(user_prompt))
|
||||
user_msg = HumanMessage(content=user_prompt)
|
||||
return user_msg, token_count
|
||||
|
||||
context_docs_str = build_complete_context_str(
|
||||
cast(list[LlmDoc | InferenceChunk], context_docs)
|
||||
)
|
||||
optional_ignore = "" if all_doc_useful else ignore_str
|
||||
|
||||
task_prompt_with_reminder = build_task_prompt_reminders(prompt)
|
||||
|
||||
user_prompt = user_prompt_template.format(
|
||||
optional_ignore_statement=optional_ignore,
|
||||
context_docs_str=context_docs_str,
|
||||
task_prompt=task_prompt_with_reminder,
|
||||
user_query=user_query,
|
||||
)
|
||||
|
||||
user_prompt = user_prompt.strip()
|
||||
token_count = len(llm_tokenizer_encode_func(user_prompt))
|
||||
user_msg = HumanMessage(content=user_prompt)
|
||||
|
||||
return user_msg, token_count
|
||||
|
||||
|
||||
def _get_usable_chunks(
|
||||
chunks: list[InferenceChunk], token_limit: int
|
||||
) -> list[InferenceChunk]:
|
||||
total_token_count = 0
|
||||
usable_chunks = []
|
||||
for chunk in chunks:
|
||||
chunk_token_count = check_number_of_tokens(chunk.content)
|
||||
if total_token_count + chunk_token_count > token_limit:
|
||||
break
|
||||
|
||||
total_token_count += chunk_token_count
|
||||
usable_chunks.append(chunk)
|
||||
|
||||
# try and return at least one chunk if possible. This chunk will
|
||||
# get truncated later on in the pipeline. This would only occur if
|
||||
# the first chunk is larger than the token limit (usually due to character
|
||||
# count -> token count mismatches caused by special characters / non-ascii
|
||||
# languages)
|
||||
if not usable_chunks and chunks:
|
||||
usable_chunks = [chunks[0]]
|
||||
|
||||
return usable_chunks
|
||||
|
||||
|
||||
def get_usable_chunks(
|
||||
chunks: list[InferenceChunk],
|
||||
token_limit: int,
|
||||
offset: int = 0,
|
||||
) -> list[InferenceChunk]:
|
||||
offset_into_chunks = 0
|
||||
usable_chunks: list[InferenceChunk] = []
|
||||
for _ in range(min(offset + 1, 1)): # go through this process at least once
|
||||
if offset_into_chunks >= len(chunks) and offset_into_chunks > 0:
|
||||
raise ValueError(
|
||||
"Chunks offset too large, should not retry this many times"
|
||||
)
|
||||
|
||||
usable_chunks = _get_usable_chunks(
|
||||
chunks=chunks[offset_into_chunks:], token_limit=token_limit
|
||||
)
|
||||
offset_into_chunks += len(usable_chunks)
|
||||
|
||||
return usable_chunks
|
||||
|
||||
|
||||
def get_chunks_for_qa(
|
||||
chunks: list[InferenceChunk],
|
||||
llm_chunk_selection: list[bool],
|
||||
token_limit: int | None,
|
||||
batch_offset: int = 0,
|
||||
) -> list[int]:
|
||||
"""
|
||||
Gives back indices of chunks to pass into the LLM for Q&A.
|
||||
|
||||
Only selects chunks viable for Q&A, within the token limit, and prioritize those selected
|
||||
by the LLM in a separate flow (this can be turned off)
|
||||
|
||||
Note, the batch_offset calculation has to count the batches from the beginning each time as
|
||||
there's no way to know which chunks were included in the prior batches without recounting atm,
|
||||
this is somewhat slow as it requires tokenizing all the chunks again
|
||||
"""
|
||||
batch_index = 0
|
||||
latest_batch_indices: list[int] = []
|
||||
token_count = 0
|
||||
|
||||
# First iterate the LLM selected chunks, then iterate the rest if tokens remaining
|
||||
for selection_target in [True, False]:
|
||||
for ind, chunk in enumerate(chunks):
|
||||
if llm_chunk_selection[ind] is not selection_target or chunk.metadata.get(
|
||||
IGNORE_FOR_QA
|
||||
):
|
||||
continue
|
||||
|
||||
# We calculate it live in case the user uses a different LLM + tokenizer
|
||||
chunk_token = check_number_of_tokens(chunk.content)
|
||||
# 50 for an approximate/slight overestimate for # tokens for metadata for the chunk
|
||||
token_count += chunk_token + 50
|
||||
|
||||
# Always use at least 1 chunk
|
||||
if (
|
||||
token_limit is None
|
||||
or token_count <= token_limit
|
||||
or not latest_batch_indices
|
||||
):
|
||||
latest_batch_indices.append(ind)
|
||||
current_chunk_unused = False
|
||||
else:
|
||||
current_chunk_unused = True
|
||||
|
||||
if token_limit is not None and token_count >= token_limit:
|
||||
if batch_index < batch_offset:
|
||||
batch_index += 1
|
||||
if current_chunk_unused:
|
||||
latest_batch_indices = [ind]
|
||||
token_count = chunk_token
|
||||
else:
|
||||
latest_batch_indices = []
|
||||
token_count = 0
|
||||
else:
|
||||
return latest_batch_indices
|
||||
|
||||
return latest_batch_indices
|
||||
|
||||
|
||||
def create_chat_chain(
|
||||
chat_session_id: int,
|
||||
db_session: Session,
|
||||
) -> tuple[ChatMessage, list[ChatMessage]]:
|
||||
"""Build the linear chain of messages without including the root message"""
|
||||
mainline_messages: list[ChatMessage] = []
|
||||
all_chat_messages = get_chat_messages_by_session(
|
||||
chat_session_id=chat_session_id,
|
||||
user_id=None,
|
||||
db_session=db_session,
|
||||
skip_permission_check=True,
|
||||
)
|
||||
id_to_msg = {msg.id: msg for msg in all_chat_messages}
|
||||
|
||||
if not all_chat_messages:
|
||||
raise ValueError("No messages in Chat Session")
|
||||
|
||||
root_message = all_chat_messages[0]
|
||||
if root_message.parent_message is not None:
|
||||
raise RuntimeError(
|
||||
"Invalid root message, unable to fetch valid chat message sequence"
|
||||
)
|
||||
|
||||
current_message: ChatMessage | None = root_message
|
||||
while current_message is not None:
|
||||
child_msg = current_message.latest_child_message
|
||||
if not child_msg:
|
||||
break
|
||||
current_message = id_to_msg.get(child_msg)
|
||||
|
||||
if current_message is None:
|
||||
raise RuntimeError(
|
||||
"Invalid message chain,"
|
||||
"could not find next message in the same session"
|
||||
)
|
||||
|
||||
mainline_messages.append(current_message)
|
||||
|
||||
if not mainline_messages:
|
||||
raise RuntimeError("Could not trace chat message history")
|
||||
|
||||
return mainline_messages[-1], mainline_messages[:-1]
|
||||
|
||||
|
||||
def combine_message_chain(
|
||||
messages: list[ChatMessage],
|
||||
token_limit: int,
|
||||
msg_limit: int | None = None,
|
||||
) -> str:
|
||||
"""Used for secondary LLM flows that require the chat history,"""
|
||||
message_strs: list[str] = []
|
||||
total_token_count = 0
|
||||
|
||||
if msg_limit is not None:
|
||||
messages = messages[-msg_limit:]
|
||||
|
||||
for message in reversed(messages):
|
||||
message_token_count = message.token_count
|
||||
|
||||
if total_token_count + message_token_count > token_limit:
|
||||
break
|
||||
|
||||
role = message.message_type.value.upper()
|
||||
message_strs.insert(0, f"{role}:\n{message.message}")
|
||||
total_token_count += message_token_count
|
||||
|
||||
return "\n\n".join(message_strs)
|
||||
|
||||
|
||||
_PER_MESSAGE_TOKEN_BUFFER = 7
|
||||
|
||||
|
||||
def find_last_index(lst: list[int], max_prompt_tokens: int) -> int:
|
||||
"""From the back, find the index of the last element to include
|
||||
before the list exceeds the maximum"""
|
||||
running_sum = 0
|
||||
|
||||
last_ind = 0
|
||||
for i in range(len(lst) - 1, -1, -1):
|
||||
running_sum += lst[i] + _PER_MESSAGE_TOKEN_BUFFER
|
||||
if running_sum > max_prompt_tokens:
|
||||
last_ind = i + 1
|
||||
break
|
||||
if last_ind >= len(lst):
|
||||
raise ValueError("Last message alone is too large!")
|
||||
return last_ind
|
||||
|
||||
|
||||
def drop_messages_history_overflow(
|
||||
system_msg: BaseMessage | None,
|
||||
system_token_count: int,
|
||||
history_msgs: list[BaseMessage],
|
||||
history_token_counts: list[int],
|
||||
final_msg: BaseMessage,
|
||||
final_msg_token_count: int,
|
||||
max_allowed_tokens: int,
|
||||
) -> list[BaseMessage]:
|
||||
"""As message history grows, messages need to be dropped starting from the furthest in the past.
|
||||
The System message should be kept if at all possible and the latest user input which is inserted in the
|
||||
prompt template must be included"""
|
||||
if len(history_msgs) != len(history_token_counts):
|
||||
# This should never happen
|
||||
raise ValueError("Need exactly 1 token count per message for tracking overflow")
|
||||
|
||||
prompt: list[BaseMessage] = []
|
||||
|
||||
# Start dropping from the history if necessary
|
||||
all_tokens = history_token_counts + [system_token_count, final_msg_token_count]
|
||||
ind_prev_msg_start = find_last_index(
|
||||
all_tokens, max_prompt_tokens=max_allowed_tokens
|
||||
)
|
||||
|
||||
if system_msg and ind_prev_msg_start <= len(history_msgs):
|
||||
prompt.append(system_msg)
|
||||
|
||||
prompt.extend(history_msgs[ind_prev_msg_start:])
|
||||
|
||||
prompt.append(final_msg)
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
def in_code_block(llm_text: str) -> bool:
|
||||
count = llm_text.count(TRIPLE_BACKTICK)
|
||||
return count % 2 != 0
|
||||
|
||||
|
||||
def extract_citations_from_stream(
|
||||
tokens: Iterator[str],
|
||||
context_docs: list[LlmDoc],
|
||||
doc_id_to_rank_map: dict[str, int],
|
||||
stop_stream: str | None = STOP_STREAM_PAT,
|
||||
) -> Iterator[DanswerAnswerPiece | CitationInfo]:
|
||||
llm_out = ""
|
||||
max_citation_num = len(context_docs)
|
||||
curr_segment = ""
|
||||
prepend_bracket = False
|
||||
cited_inds = set()
|
||||
hold = ""
|
||||
for raw_token in tokens:
|
||||
if stop_stream:
|
||||
next_hold = hold + raw_token
|
||||
|
||||
if stop_stream in next_hold:
|
||||
break
|
||||
|
||||
if next_hold == stop_stream[: len(next_hold)]:
|
||||
hold = next_hold
|
||||
continue
|
||||
|
||||
token = next_hold
|
||||
hold = ""
|
||||
else:
|
||||
token = raw_token
|
||||
|
||||
# Special case of [1][ where ][ is a single token
|
||||
# This is where the model attempts to do consecutive citations like [1][2]
|
||||
if prepend_bracket:
|
||||
curr_segment += "[" + curr_segment
|
||||
prepend_bracket = False
|
||||
|
||||
curr_segment += token
|
||||
llm_out += token
|
||||
|
||||
possible_citation_pattern = r"(\[\d*$)" # [1, [, etc
|
||||
possible_citation_found = re.search(possible_citation_pattern, curr_segment)
|
||||
|
||||
citation_pattern = r"\[(\d+)\]" # [1], [2] etc
|
||||
citation_found = re.search(citation_pattern, curr_segment)
|
||||
|
||||
if citation_found and not in_code_block(llm_out):
|
||||
numerical_value = int(citation_found.group(1))
|
||||
if 1 <= numerical_value <= max_citation_num:
|
||||
context_llm_doc = context_docs[
|
||||
numerical_value - 1
|
||||
] # remove 1 index offset
|
||||
|
||||
link = context_llm_doc.link
|
||||
target_citation_num = doc_id_to_rank_map[context_llm_doc.document_id]
|
||||
|
||||
# Use the citation number for the document's rank in
|
||||
# the search (or selected docs) results
|
||||
curr_segment = re.sub(
|
||||
rf"\[{numerical_value}\]", f"[{target_citation_num}]", curr_segment
|
||||
)
|
||||
|
||||
if target_citation_num not in cited_inds:
|
||||
cited_inds.add(target_citation_num)
|
||||
yield CitationInfo(
|
||||
citation_num=target_citation_num,
|
||||
document_id=context_llm_doc.document_id,
|
||||
)
|
||||
|
||||
if link:
|
||||
curr_segment = re.sub(r"\[", "[[", curr_segment, count=1)
|
||||
curr_segment = re.sub("]", f"]]({link})", curr_segment, count=1)
|
||||
|
||||
# In case there's another open bracket like [1][, don't want to match this
|
||||
possible_citation_found = None
|
||||
|
||||
# if we see "[", but haven't seen the right side, hold back - this may be a
|
||||
# citation that needs to be replaced with a link
|
||||
if possible_citation_found:
|
||||
continue
|
||||
|
||||
# Special case with back to back citations [1][2]
|
||||
if curr_segment and curr_segment[-1] == "[":
|
||||
curr_segment = curr_segment[:-1]
|
||||
prepend_bracket = True
|
||||
|
||||
yield DanswerAnswerPiece(answer_piece=curr_segment)
|
||||
curr_segment = ""
|
||||
|
||||
if curr_segment:
|
||||
if prepend_bracket:
|
||||
yield DanswerAnswerPiece(answer_piece="[" + curr_segment)
|
||||
else:
|
||||
yield DanswerAnswerPiece(answer_piece=curr_segment)
|
||||
|
||||
|
||||
def get_prompt_tokens(prompt: Prompt) -> int:
|
||||
return (
|
||||
check_number_of_tokens(prompt.system_prompt)
|
||||
+ check_number_of_tokens(prompt.task_prompt)
|
||||
+ CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT
|
||||
+ CITATION_STATEMENT_TOKEN_CNT
|
||||
+ CITATION_REMINDER_TOKEN_CNT
|
||||
+ (LANGUAGE_HINT_TOKEN_CNT if bool(MULTILINGUAL_QUERY_EXPANSION) else 0)
|
||||
)
|
||||
|
||||
|
||||
# buffer just to be safe so that we don't overflow the token limit due to
|
||||
# a small miscalculation
|
||||
_MISC_BUFFER = 40
|
||||
|
||||
|
||||
def compute_max_document_tokens(
|
||||
persona: Persona,
|
||||
actual_user_input: str | None = None,
|
||||
max_llm_token_override: int | None = None,
|
||||
) -> int:
|
||||
"""Estimates the number of tokens available for context documents. Formula is roughly:
|
||||
|
||||
(
|
||||
model_context_window - reserved_output_tokens - prompt_tokens
|
||||
- (actual_user_input OR reserved_user_message_tokens) - buffer (just to be safe)
|
||||
)
|
||||
|
||||
The actual_user_input is used at query time. If we are calculating this before knowing the exact input (e.g.
|
||||
if we're trying to determine if the user should be able to select another document) then we just set an
|
||||
arbitrary "upper bound".
|
||||
"""
|
||||
llm_name = GEN_AI_MODEL_VERSION
|
||||
if persona.llm_model_version_override:
|
||||
llm_name = persona.llm_model_version_override
|
||||
|
||||
# if we can't find a number of tokens, just assume some common default
|
||||
max_input_tokens = (
|
||||
max_llm_token_override
|
||||
if max_llm_token_override
|
||||
else get_max_input_tokens(model_name=llm_name)
|
||||
)
|
||||
if persona.prompts:
|
||||
# TODO this may not always be the first prompt
|
||||
prompt_tokens = get_prompt_tokens(persona.prompts[0])
|
||||
else:
|
||||
raise RuntimeError("Persona has no prompts - this should never happen")
|
||||
user_input_tokens = (
|
||||
check_number_of_tokens(actual_user_input)
|
||||
if actual_user_input is not None
|
||||
else GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
|
||||
)
|
||||
|
||||
return max_input_tokens - prompt_tokens - user_input_tokens - _MISC_BUFFER
|
||||
|
||||
|
||||
def compute_max_llm_input_tokens(persona: Persona) -> int:
|
||||
"""Maximum tokens allows in the input to the LLM (of any type)."""
|
||||
llm_name = GEN_AI_MODEL_VERSION
|
||||
if persona.llm_model_version_override:
|
||||
llm_name = persona.llm_model_version_override
|
||||
|
||||
input_tokens = get_max_input_tokens(model_name=llm_name)
|
||||
return input_tokens - _MISC_BUFFER
|
||||
106
backend/danswer/chat/load_yamls.py
Normal file
106
backend/danswer/chat/load_yamls.py
Normal file
@@ -0,0 +1,106 @@
|
||||
from typing import cast
|
||||
|
||||
import yaml
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
|
||||
from danswer.configs.chat_configs import PERSONAS_YAML
|
||||
from danswer.configs.chat_configs import PROMPTS_YAML
|
||||
from danswer.db.chat import get_prompt_by_name
|
||||
from danswer.db.chat import upsert_persona
|
||||
from danswer.db.chat import upsert_prompt
|
||||
from danswer.db.document_set import get_or_create_document_set_by_name
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.models import DocumentSet as DocumentSetDBModel
|
||||
from danswer.db.models import Prompt as PromptDBModel
|
||||
from danswer.search.models import RecencyBiasSetting
|
||||
|
||||
|
||||
def load_prompts_from_yaml(prompts_yaml: str = PROMPTS_YAML) -> None:
|
||||
with open(prompts_yaml, "r") as file:
|
||||
data = yaml.safe_load(file)
|
||||
|
||||
all_prompts = data.get("prompts", [])
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
for prompt in all_prompts:
|
||||
upsert_prompt(
|
||||
user_id=None,
|
||||
prompt_id=prompt.get("id"),
|
||||
name=prompt["name"],
|
||||
description=prompt["description"].strip(),
|
||||
system_prompt=prompt["system"].strip(),
|
||||
task_prompt=prompt["task"].strip(),
|
||||
include_citations=prompt["include_citations"],
|
||||
datetime_aware=prompt.get("datetime_aware", True),
|
||||
default_prompt=True,
|
||||
personas=None,
|
||||
shared=True,
|
||||
db_session=db_session,
|
||||
commit=True,
|
||||
)
|
||||
|
||||
|
||||
def load_personas_from_yaml(
|
||||
personas_yaml: str = PERSONAS_YAML,
|
||||
default_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
|
||||
) -> None:
|
||||
with open(personas_yaml, "r") as file:
|
||||
data = yaml.safe_load(file)
|
||||
|
||||
all_personas = data.get("personas", [])
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
for persona in all_personas:
|
||||
doc_set_names = persona["document_sets"]
|
||||
doc_sets: list[DocumentSetDBModel] | None = [
|
||||
get_or_create_document_set_by_name(db_session, name)
|
||||
for name in doc_set_names
|
||||
]
|
||||
|
||||
# Assume if user hasn't set any document sets for the persona, the user may want
|
||||
# to later attach document sets to the persona manually, therefore, don't overwrite/reset
|
||||
# the document sets for the persona
|
||||
if not doc_sets:
|
||||
doc_sets = None
|
||||
|
||||
prompt_set_names = persona["prompts"]
|
||||
if not prompt_set_names:
|
||||
prompts: list[PromptDBModel | None] | None = None
|
||||
else:
|
||||
prompts = [
|
||||
get_prompt_by_name(
|
||||
prompt_name, user_id=None, shared=True, db_session=db_session
|
||||
)
|
||||
for prompt_name in prompt_set_names
|
||||
]
|
||||
if any([prompt is None for prompt in prompts]):
|
||||
raise ValueError("Invalid Persona configs, not all prompts exist")
|
||||
|
||||
if not prompts:
|
||||
prompts = None
|
||||
|
||||
upsert_persona(
|
||||
user_id=None,
|
||||
persona_id=persona.get("id"),
|
||||
name=persona["name"],
|
||||
description=persona["description"],
|
||||
num_chunks=persona.get("num_chunks")
|
||||
if persona.get("num_chunks") is not None
|
||||
else default_chunks,
|
||||
llm_relevance_filter=persona.get("llm_relevance_filter"),
|
||||
llm_filter_extraction=persona.get("llm_filter_extraction"),
|
||||
llm_model_version_override=None,
|
||||
recency_bias=RecencyBiasSetting(persona["recency_bias"]),
|
||||
prompts=cast(list[PromptDBModel] | None, prompts),
|
||||
document_sets=doc_sets,
|
||||
default_persona=True,
|
||||
shared=True,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
|
||||
def load_chat_yamls(
|
||||
prompt_yaml: str = PROMPTS_YAML,
|
||||
personas_yaml: str = PERSONAS_YAML,
|
||||
) -> None:
|
||||
load_prompts_from_yaml(prompt_yaml)
|
||||
load_personas_from_yaml(personas_yaml)
|
||||
110
backend/danswer/chat/models.py
Normal file
110
backend/danswer/chat/models.py
Normal file
@@ -0,0 +1,110 @@
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.search.models import QueryFlow
|
||||
from danswer.search.models import RetrievalDocs
|
||||
from danswer.search.models import SearchResponse
|
||||
from danswer.search.models import SearchType
|
||||
|
||||
|
||||
class LlmDoc(BaseModel):
|
||||
"""This contains the minimal set information for the LLM portion including citations"""
|
||||
|
||||
document_id: str
|
||||
content: str
|
||||
semantic_identifier: str
|
||||
source_type: DocumentSource
|
||||
metadata: dict[str, str | list[str]]
|
||||
updated_at: datetime | None
|
||||
link: str | None
|
||||
|
||||
|
||||
# First chunk of info for streaming QA
|
||||
class QADocsResponse(RetrievalDocs):
|
||||
rephrased_query: str | None = None
|
||||
predicted_flow: QueryFlow | None
|
||||
predicted_search: SearchType | None
|
||||
applied_source_filters: list[DocumentSource] | None
|
||||
applied_time_cutoff: datetime | None
|
||||
recency_bias_multiplier: float
|
||||
|
||||
def dict(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
|
||||
initial_dict = super().dict(*args, **kwargs) # type: ignore
|
||||
initial_dict["applied_time_cutoff"] = (
|
||||
self.applied_time_cutoff.isoformat() if self.applied_time_cutoff else None
|
||||
)
|
||||
return initial_dict
|
||||
|
||||
|
||||
# Second chunk of info for streaming QA
|
||||
class LLMRelevanceFilterResponse(BaseModel):
|
||||
relevant_chunk_indices: list[int]
|
||||
|
||||
|
||||
class DanswerAnswerPiece(BaseModel):
|
||||
# A small piece of a complete answer. Used for streaming back answers.
|
||||
answer_piece: str | None # if None, specifies the end of an Answer
|
||||
|
||||
|
||||
# An intermediate representation of citations, later translated into
|
||||
# a mapping of the citation [n] number to SearchDoc
|
||||
class CitationInfo(BaseModel):
|
||||
citation_num: int
|
||||
document_id: str
|
||||
|
||||
|
||||
class StreamingError(BaseModel):
|
||||
error: str
|
||||
|
||||
|
||||
class DanswerQuote(BaseModel):
|
||||
# This is during inference so everything is a string by this point
|
||||
quote: str
|
||||
document_id: str
|
||||
link: str | None
|
||||
source_type: str
|
||||
semantic_identifier: str
|
||||
blurb: str
|
||||
|
||||
|
||||
class DanswerQuotes(BaseModel):
|
||||
quotes: list[DanswerQuote]
|
||||
|
||||
|
||||
class DanswerContext(BaseModel):
|
||||
content: str
|
||||
document_id: str
|
||||
semantic_identifier: str
|
||||
blurb: str
|
||||
|
||||
|
||||
class DanswerContexts(BaseModel):
|
||||
contexts: list[DanswerContext]
|
||||
|
||||
|
||||
class DanswerAnswer(BaseModel):
|
||||
answer: str | None
|
||||
|
||||
|
||||
class QAResponse(SearchResponse, DanswerAnswer):
|
||||
quotes: list[DanswerQuote] | None
|
||||
contexts: list[DanswerContexts] | None
|
||||
predicted_flow: QueryFlow
|
||||
predicted_search: SearchType
|
||||
eval_res_valid: bool | None = None
|
||||
llm_chunks_indices: list[int] | None = None
|
||||
error_msg: str | None = None
|
||||
|
||||
|
||||
AnswerQuestionStreamReturn = Iterator[
|
||||
DanswerAnswerPiece | DanswerQuotes | DanswerContexts | StreamingError
|
||||
]
|
||||
|
||||
|
||||
class LLMMetricsContainer(BaseModel):
|
||||
prompt_tokens: int
|
||||
response_tokens: int
|
||||
@@ -1,29 +0,0 @@
|
||||
import yaml
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.chat.chat_prompts import form_tool_section_text
|
||||
from danswer.configs.app_configs import PERSONAS_YAML
|
||||
from danswer.db.chat import create_persona
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
|
||||
|
||||
def load_personas_from_yaml(personas_yaml: str = PERSONAS_YAML) -> None:
|
||||
with open(personas_yaml, "r") as file:
|
||||
data = yaml.safe_load(file)
|
||||
|
||||
all_personas = data.get("personas", [])
|
||||
with Session(get_sqlalchemy_engine(), expire_on_commit=False) as db_session:
|
||||
for persona in all_personas:
|
||||
tools = form_tool_section_text(
|
||||
persona["tools"], persona["retrieval_enabled"]
|
||||
)
|
||||
create_persona(
|
||||
persona_id=persona["id"],
|
||||
name=persona["name"],
|
||||
retrieval_enabled=persona["retrieval_enabled"],
|
||||
system_text=persona["system"],
|
||||
tools_text=tools,
|
||||
hint_text=persona["hint"],
|
||||
default_persona=True,
|
||||
db_session=db_session,
|
||||
)
|
||||
@@ -1,13 +1,64 @@
|
||||
# Currently in the UI, each Persona only has one prompt, which is why there are 3 very similar personas defined below.
|
||||
|
||||
personas:
|
||||
- id: 1
|
||||
name: "Danswer"
|
||||
system: |
|
||||
You are a question answering system that is constantly learning and improving.
|
||||
You can process and comprehend vast amounts of text and utilize this knowledge to provide accurate and detailed answers to diverse queries.
|
||||
Your responses are as INFORMATIVE and DETAILED as possible.
|
||||
# Danswer custom tool flow, "Current Search" tool name is reserved if this is enabled.
|
||||
retrieval_enabled: true
|
||||
# Each added tool needs to have a "name" and "description"
|
||||
tools: []
|
||||
# Short tip to pass near the end of the prompt to emphasize some requirement
|
||||
hint: "Try to be as informative as possible!"
|
||||
# This id field can be left blank for other default personas, however an id 0 persona must exist
|
||||
# this is for DanswerBot to use when tagged in a non-configured channel
|
||||
# Careful setting specific IDs, this won't autoincrement the next ID value for postgres
|
||||
- id: 0
|
||||
name: "Default"
|
||||
description: >
|
||||
Default Danswer Question Answering functionality.
|
||||
# Default Prompt objects attached to the persona, see prompts.yaml
|
||||
prompts:
|
||||
- "Answer-Question"
|
||||
# Default number of chunks to include as context, set to 0 to disable retrieval
|
||||
# Remove the field to set to the system default number of chunks/tokens to pass to Gen AI
|
||||
# Each chunk is 512 tokens long
|
||||
num_chunks: 10
|
||||
# Enable/Disable usage of the LLM chunk filter feature whereby each chunk is passed to the LLM to determine
|
||||
# if the chunk is useful or not towards the latest user query
|
||||
# This feature can be overriden for all personas via DISABLE_LLM_CHUNK_FILTER env variable
|
||||
llm_relevance_filter: true
|
||||
# Enable/Disable usage of the LLM to extract query time filters including source type and time range filters
|
||||
llm_filter_extraction: true
|
||||
# Decay documents priority as they age, options are:
|
||||
# - favor_recent (2x base by default, configurable)
|
||||
# - base_decay
|
||||
# - no_decay
|
||||
# - auto (model chooses between favor_recent and base_decay based on user query)
|
||||
recency_bias: "auto"
|
||||
# Default Document Sets for this persona, specified as a list of names here.
|
||||
# If the document set by the name exists, it will be attached to the persona
|
||||
# If the document set by the name does not exist, it will be created as an empty document set with no connectors
|
||||
# The admin can then use the UI to add new connectors to the document set
|
||||
# Example:
|
||||
# document_sets:
|
||||
# - "HR Resources"
|
||||
# - "Engineer Onboarding"
|
||||
# - "Benefits"
|
||||
document_sets: []
|
||||
|
||||
|
||||
- name: "Summarize"
|
||||
description: >
|
||||
A less creative assistant which summarizes relevant documents but does not try to
|
||||
extrapolate any answers for you.
|
||||
prompts:
|
||||
- "Summarize"
|
||||
num_chunks: 10
|
||||
llm_relevance_filter: true
|
||||
llm_filter_extraction: true
|
||||
recency_bias: "auto"
|
||||
document_sets: []
|
||||
|
||||
|
||||
- name: "Paraphrase"
|
||||
description: >
|
||||
The least creative default assistant that only provides quotes from the documents.
|
||||
prompts:
|
||||
- "Paraphrase"
|
||||
num_chunks: 10
|
||||
llm_relevance_filter: true
|
||||
llm_filter_extraction: true
|
||||
recency_bias: "auto"
|
||||
document_sets: []
|
||||
|
||||
574
backend/danswer/chat/process_message.py
Normal file
574
backend/danswer/chat/process_message.py
Normal file
@@ -0,0 +1,574 @@
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from functools import partial
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.chat.chat_utils import build_chat_system_message
|
||||
from danswer.chat.chat_utils import build_chat_user_message
|
||||
from danswer.chat.chat_utils import build_doc_context_str
|
||||
from danswer.chat.chat_utils import compute_max_document_tokens
|
||||
from danswer.chat.chat_utils import compute_max_llm_input_tokens
|
||||
from danswer.chat.chat_utils import create_chat_chain
|
||||
from danswer.chat.chat_utils import drop_messages_history_overflow
|
||||
from danswer.chat.chat_utils import extract_citations_from_stream
|
||||
from danswer.chat.chat_utils import get_chunks_for_qa
|
||||
from danswer.chat.chat_utils import llm_doc_from_inference_chunk
|
||||
from danswer.chat.chat_utils import map_document_id_order
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import LlmDoc
|
||||
from danswer.chat.models import LLMRelevanceFilterResponse
|
||||
from danswer.chat.models import QADocsResponse
|
||||
from danswer.chat.models import StreamingError
|
||||
from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
|
||||
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
|
||||
from danswer.configs.constants import DISABLED_GEN_AI_MSG
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.configs.model_configs import CHUNK_SIZE
|
||||
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
|
||||
from danswer.db.chat import create_db_search_doc
|
||||
from danswer.db.chat import create_new_chat_message
|
||||
from danswer.db.chat import get_chat_message
|
||||
from danswer.db.chat import get_chat_session_by_id
|
||||
from danswer.db.chat import get_db_search_doc_by_id
|
||||
from danswer.db.chat import get_doc_query_identifiers_from_model
|
||||
from danswer.db.chat import get_or_create_root_message
|
||||
from danswer.db.chat import translate_db_message_to_chat_message_detail
|
||||
from danswer.db.chat import translate_db_search_doc_to_server_search_doc
|
||||
from danswer.db.embedding_model import get_current_db_embedding_model
|
||||
from danswer.db.models import ChatMessage
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.models import SearchDoc as DbSearchDoc
|
||||
from danswer.db.models import User
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.indexing.models import InferenceChunk
|
||||
from danswer.llm.exceptions import GenAIDisabledException
|
||||
from danswer.llm.factory import get_default_llm
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.llm.utils import get_default_llm_tokenizer
|
||||
from danswer.llm.utils import get_max_input_tokens
|
||||
from danswer.llm.utils import tokenizer_trim_content
|
||||
from danswer.llm.utils import translate_history_to_basemessages
|
||||
from danswer.search.models import OptionalSearchSetting
|
||||
from danswer.search.models import RetrievalDetails
|
||||
from danswer.search.request_preprocessing import retrieval_preprocessing
|
||||
from danswer.search.search_runner import chunks_to_search_docs
|
||||
from danswer.search.search_runner import full_chunk_search_generator
|
||||
from danswer.search.search_runner import inference_documents_from_ids
|
||||
from danswer.secondary_llm_flows.choose_search import check_if_need_search
|
||||
from danswer.secondary_llm_flows.query_expansion import history_based_query_rephrase
|
||||
from danswer.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from danswer.server.utils import get_json_line
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.timing import log_generator_function_time
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def generate_ai_chat_response(
|
||||
query_message: ChatMessage,
|
||||
history: list[ChatMessage],
|
||||
persona: Persona,
|
||||
context_docs: list[LlmDoc],
|
||||
doc_id_to_rank_map: dict[str, int],
|
||||
llm: LLM | None,
|
||||
llm_tokenizer_encode_func: Callable,
|
||||
all_doc_useful: bool,
|
||||
) -> Iterator[DanswerAnswerPiece | CitationInfo | StreamingError]:
|
||||
if llm is None:
|
||||
try:
|
||||
llm = get_default_llm()
|
||||
except GenAIDisabledException:
|
||||
# Not an error if it's a user configuration
|
||||
yield DanswerAnswerPiece(answer_piece=DISABLED_GEN_AI_MSG)
|
||||
return
|
||||
|
||||
if query_message.prompt is None:
|
||||
raise RuntimeError("No prompt received for generating Gen AI answer.")
|
||||
|
||||
try:
|
||||
context_exists = len(context_docs) > 0
|
||||
|
||||
system_message_or_none, system_tokens = build_chat_system_message(
|
||||
prompt=query_message.prompt,
|
||||
context_exists=context_exists,
|
||||
llm_tokenizer_encode_func=llm_tokenizer_encode_func,
|
||||
)
|
||||
|
||||
history_basemessages, history_token_counts = translate_history_to_basemessages(
|
||||
history
|
||||
)
|
||||
|
||||
# Be sure the context_docs passed to build_chat_user_message
|
||||
# Is the same as passed in later for extracting citations
|
||||
user_message, user_tokens = build_chat_user_message(
|
||||
chat_message=query_message,
|
||||
prompt=query_message.prompt,
|
||||
context_docs=context_docs,
|
||||
llm_tokenizer_encode_func=llm_tokenizer_encode_func,
|
||||
all_doc_useful=all_doc_useful,
|
||||
)
|
||||
|
||||
prompt = drop_messages_history_overflow(
|
||||
system_msg=system_message_or_none,
|
||||
system_token_count=system_tokens,
|
||||
history_msgs=history_basemessages,
|
||||
history_token_counts=history_token_counts,
|
||||
final_msg=user_message,
|
||||
final_msg_token_count=user_tokens,
|
||||
max_allowed_tokens=compute_max_llm_input_tokens(persona),
|
||||
)
|
||||
|
||||
# Good Debug/Breakpoint
|
||||
tokens = llm.stream(prompt)
|
||||
|
||||
yield from extract_citations_from_stream(
|
||||
tokens, context_docs, doc_id_to_rank_map
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"LLM failed to produce valid chat message, error: {e}")
|
||||
yield StreamingError(error=str(e))
|
||||
|
||||
|
||||
def translate_citations(
|
||||
citations_list: list[CitationInfo], db_docs: list[DbSearchDoc]
|
||||
) -> dict[int, int]:
|
||||
"""Always cites the first instance of the document_id, assumes the db_docs
|
||||
are sorted in the order displayed in the UI"""
|
||||
doc_id_to_saved_doc_id_map: dict[str, int] = {}
|
||||
for db_doc in db_docs:
|
||||
if db_doc.document_id not in doc_id_to_saved_doc_id_map:
|
||||
doc_id_to_saved_doc_id_map[db_doc.document_id] = db_doc.id
|
||||
|
||||
citation_to_saved_doc_id_map: dict[int, int] = {}
|
||||
for citation in citations_list:
|
||||
if citation.citation_num not in citation_to_saved_doc_id_map:
|
||||
citation_to_saved_doc_id_map[
|
||||
citation.citation_num
|
||||
] = doc_id_to_saved_doc_id_map[citation.document_id]
|
||||
|
||||
return citation_to_saved_doc_id_map
|
||||
|
||||
|
||||
@log_generator_function_time()
|
||||
def stream_chat_message(
|
||||
new_msg_req: CreateChatMessageRequest,
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
# Needed to translate persona num_chunks to tokens to the LLM
|
||||
default_num_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
|
||||
default_chunk_size: int = CHUNK_SIZE,
|
||||
# For flow with search, don't include as many chunks as possible since we need to leave space
|
||||
# for the chat history, for smaller models, we likely won't get MAX_CHUNKS_FED_TO_CHAT chunks
|
||||
max_document_percentage: float = CHAT_TARGET_CHUNK_PERCENTAGE,
|
||||
) -> Iterator[str]:
|
||||
"""Streams in order:
|
||||
1. [conditional] Retrieved documents if a search needs to be run
|
||||
2. [conditional] LLM selected chunk indices if LLM chunk filtering is turned on
|
||||
3. [always] A set of streamed LLM tokens or an error anywhere along the line if something fails
|
||||
4. [always] Details on the final AI response message that is created
|
||||
|
||||
"""
|
||||
try:
|
||||
user_id = user.id if user is not None else None
|
||||
|
||||
chat_session = get_chat_session_by_id(
|
||||
chat_session_id=new_msg_req.chat_session_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
message_text = new_msg_req.message
|
||||
chat_session_id = new_msg_req.chat_session_id
|
||||
parent_id = new_msg_req.parent_message_id
|
||||
prompt_id = new_msg_req.prompt_id
|
||||
reference_doc_ids = new_msg_req.search_doc_ids
|
||||
retrieval_options = new_msg_req.retrieval_options
|
||||
persona = chat_session.persona
|
||||
query_override = new_msg_req.query_override
|
||||
|
||||
if reference_doc_ids is None and retrieval_options is None:
|
||||
raise RuntimeError(
|
||||
"Must specify a set of documents for chat or specify search options"
|
||||
)
|
||||
|
||||
try:
|
||||
llm = get_default_llm(
|
||||
gen_ai_model_version_override=persona.llm_model_version_override
|
||||
)
|
||||
except GenAIDisabledException:
|
||||
llm = None
|
||||
|
||||
llm_tokenizer = get_default_llm_tokenizer()
|
||||
llm_tokenizer_encode_func = cast(
|
||||
Callable[[str], list[int]], llm_tokenizer.encode
|
||||
)
|
||||
|
||||
embedding_model = get_current_db_embedding_model(db_session)
|
||||
document_index = get_default_document_index(
|
||||
primary_index_name=embedding_model.index_name, secondary_index_name=None
|
||||
)
|
||||
|
||||
# Every chat Session begins with an empty root message
|
||||
root_message = get_or_create_root_message(
|
||||
chat_session_id=chat_session_id, db_session=db_session
|
||||
)
|
||||
|
||||
if parent_id is not None:
|
||||
parent_message = get_chat_message(
|
||||
chat_message_id=parent_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
else:
|
||||
parent_message = root_message
|
||||
|
||||
# Create new message at the right place in the tree and update the parent's child pointer
|
||||
# Don't commit yet until we verify the chat message chain
|
||||
new_user_message = create_new_chat_message(
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message=parent_message,
|
||||
prompt_id=prompt_id,
|
||||
message=message_text,
|
||||
token_count=len(llm_tokenizer_encode_func(message_text)),
|
||||
message_type=MessageType.USER,
|
||||
db_session=db_session,
|
||||
commit=False,
|
||||
)
|
||||
|
||||
# Create linear history of messages
|
||||
final_msg, history_msgs = create_chat_chain(
|
||||
chat_session_id=chat_session_id, db_session=db_session
|
||||
)
|
||||
|
||||
if final_msg.id != new_user_message.id:
|
||||
db_session.rollback()
|
||||
raise RuntimeError(
|
||||
"The new message was not on the mainline. "
|
||||
"Be sure to update the chat pointers before calling this."
|
||||
)
|
||||
|
||||
# Save now to save the latest chat message
|
||||
db_session.commit()
|
||||
|
||||
run_search = False
|
||||
# Retrieval options are only None if reference_doc_ids are provided
|
||||
if retrieval_options is not None and persona.num_chunks != 0:
|
||||
if retrieval_options.run_search == OptionalSearchSetting.ALWAYS:
|
||||
run_search = True
|
||||
elif retrieval_options.run_search == OptionalSearchSetting.NEVER:
|
||||
run_search = False
|
||||
else:
|
||||
run_search = check_if_need_search(
|
||||
query_message=final_msg, history=history_msgs, llm=llm
|
||||
)
|
||||
|
||||
max_document_tokens = compute_max_document_tokens(
|
||||
persona=persona, actual_user_input=message_text
|
||||
)
|
||||
|
||||
rephrased_query = None
|
||||
if reference_doc_ids:
|
||||
identifier_tuples = get_doc_query_identifiers_from_model(
|
||||
search_doc_ids=reference_doc_ids,
|
||||
chat_session=chat_session,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Generates full documents currently
|
||||
# May extend to include chunk ranges
|
||||
llm_docs: list[LlmDoc] = inference_documents_from_ids(
|
||||
doc_identifiers=identifier_tuples,
|
||||
document_index=document_index,
|
||||
)
|
||||
|
||||
# truncate the last document if it exceeds the token limit
|
||||
tokens_per_doc = [
|
||||
len(
|
||||
llm_tokenizer_encode_func(
|
||||
build_doc_context_str(
|
||||
semantic_identifier=llm_doc.semantic_identifier,
|
||||
source_type=llm_doc.source_type,
|
||||
content=llm_doc.content,
|
||||
metadata_dict=llm_doc.metadata,
|
||||
updated_at=llm_doc.updated_at,
|
||||
ind=ind,
|
||||
)
|
||||
)
|
||||
)
|
||||
for ind, llm_doc in enumerate(llm_docs)
|
||||
]
|
||||
final_doc_ind = None
|
||||
total_tokens = 0
|
||||
for ind, tokens in enumerate(tokens_per_doc):
|
||||
total_tokens += tokens
|
||||
if total_tokens > max_document_tokens:
|
||||
final_doc_ind = ind
|
||||
break
|
||||
if final_doc_ind is not None:
|
||||
# only allow the final document to get truncated
|
||||
# if more than that, then the user message is too long
|
||||
if final_doc_ind != len(tokens_per_doc) - 1:
|
||||
yield get_json_line(
|
||||
StreamingError(
|
||||
error="LLM context window exceeded. Please de-select some documents or shorten your query."
|
||||
).dict()
|
||||
)
|
||||
return
|
||||
|
||||
final_doc_desired_length = tokens_per_doc[final_doc_ind] - (
|
||||
total_tokens - max_document_tokens
|
||||
)
|
||||
# 75 tokens is a reasonable over-estimate of the metadata and title
|
||||
final_doc_content_length = final_doc_desired_length - 75
|
||||
# this could occur if we only have space for the title / metadata
|
||||
# not ideal, but it's the most reasonable thing to do
|
||||
# NOTE: the frontend prevents documents from being selected if
|
||||
# less than 75 tokens are available to try and avoid this situation
|
||||
# from occuring in the first place
|
||||
if final_doc_content_length <= 0:
|
||||
logger.error(
|
||||
f"Final doc ({llm_docs[final_doc_ind].semantic_identifier}) content "
|
||||
"length is less than 0. Removing this doc from the final prompt."
|
||||
)
|
||||
llm_docs.pop()
|
||||
else:
|
||||
llm_docs[final_doc_ind].content = tokenizer_trim_content(
|
||||
content=llm_docs[final_doc_ind].content,
|
||||
desired_length=final_doc_content_length,
|
||||
tokenizer=llm_tokenizer,
|
||||
)
|
||||
|
||||
doc_id_to_rank_map = map_document_id_order(
|
||||
cast(list[InferenceChunk | LlmDoc], llm_docs)
|
||||
)
|
||||
|
||||
# In case the search doc is deleted, just don't include it
|
||||
# though this should never happen
|
||||
db_search_docs_or_none = [
|
||||
get_db_search_doc_by_id(doc_id=doc_id, db_session=db_session)
|
||||
for doc_id in reference_doc_ids
|
||||
]
|
||||
|
||||
reference_db_search_docs = [
|
||||
db_sd for db_sd in db_search_docs_or_none if db_sd
|
||||
]
|
||||
|
||||
elif run_search:
|
||||
rephrased_query = (
|
||||
history_based_query_rephrase(
|
||||
query_message=final_msg, history=history_msgs, llm=llm
|
||||
)
|
||||
if query_override is None
|
||||
else query_override
|
||||
)
|
||||
|
||||
(
|
||||
retrieval_request,
|
||||
predicted_search_type,
|
||||
predicted_flow,
|
||||
) = retrieval_preprocessing(
|
||||
query=rephrased_query,
|
||||
retrieval_details=cast(RetrievalDetails, retrieval_options),
|
||||
persona=persona,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
documents_generator = full_chunk_search_generator(
|
||||
search_query=retrieval_request,
|
||||
document_index=document_index,
|
||||
db_session=db_session,
|
||||
)
|
||||
time_cutoff = retrieval_request.filters.time_cutoff
|
||||
recency_bias_multiplier = retrieval_request.recency_bias_multiplier
|
||||
run_llm_chunk_filter = not retrieval_request.skip_llm_chunk_filter
|
||||
|
||||
# First fetch and return the top chunks to the UI so the user can
|
||||
# immediately see some results
|
||||
top_chunks = cast(list[InferenceChunk], next(documents_generator))
|
||||
|
||||
# Get ranking of the documents for citation purposes later
|
||||
doc_id_to_rank_map = map_document_id_order(
|
||||
cast(list[InferenceChunk | LlmDoc], top_chunks)
|
||||
)
|
||||
|
||||
top_docs = chunks_to_search_docs(top_chunks)
|
||||
|
||||
reference_db_search_docs = [
|
||||
create_db_search_doc(server_search_doc=top_doc, db_session=db_session)
|
||||
for top_doc in top_docs
|
||||
]
|
||||
|
||||
response_docs = [
|
||||
translate_db_search_doc_to_server_search_doc(db_search_doc)
|
||||
for db_search_doc in reference_db_search_docs
|
||||
]
|
||||
|
||||
initial_response = QADocsResponse(
|
||||
rephrased_query=rephrased_query,
|
||||
top_documents=response_docs,
|
||||
predicted_flow=predicted_flow,
|
||||
predicted_search=predicted_search_type,
|
||||
applied_source_filters=retrieval_request.filters.source_type,
|
||||
applied_time_cutoff=time_cutoff,
|
||||
recency_bias_multiplier=recency_bias_multiplier,
|
||||
).dict()
|
||||
yield get_json_line(initial_response)
|
||||
|
||||
# Get the final ordering of chunks for the LLM call
|
||||
llm_chunk_selection = cast(list[bool], next(documents_generator))
|
||||
|
||||
# Yield the list of LLM selected chunks for showing the LLM selected icons in the UI
|
||||
llm_relevance_filtering_response = LLMRelevanceFilterResponse(
|
||||
relevant_chunk_indices=[
|
||||
index for index, value in enumerate(llm_chunk_selection) if value
|
||||
]
|
||||
if run_llm_chunk_filter
|
||||
else []
|
||||
).dict()
|
||||
yield get_json_line(llm_relevance_filtering_response)
|
||||
|
||||
# Prep chunks to pass to LLM
|
||||
num_llm_chunks = (
|
||||
persona.num_chunks
|
||||
if persona.num_chunks is not None
|
||||
else default_num_chunks
|
||||
)
|
||||
|
||||
llm_name = GEN_AI_MODEL_VERSION
|
||||
if persona.llm_model_version_override:
|
||||
llm_name = persona.llm_model_version_override
|
||||
|
||||
llm_max_input_tokens = get_max_input_tokens(model_name=llm_name)
|
||||
|
||||
llm_token_based_chunk_lim = max_document_percentage * llm_max_input_tokens
|
||||
|
||||
chunk_token_limit = int(
|
||||
min(
|
||||
num_llm_chunks * default_chunk_size,
|
||||
max_document_tokens,
|
||||
llm_token_based_chunk_lim,
|
||||
)
|
||||
)
|
||||
llm_chunks_indices = get_chunks_for_qa(
|
||||
chunks=top_chunks,
|
||||
llm_chunk_selection=llm_chunk_selection,
|
||||
token_limit=chunk_token_limit,
|
||||
)
|
||||
llm_chunks = [top_chunks[i] for i in llm_chunks_indices]
|
||||
llm_docs = [llm_doc_from_inference_chunk(chunk) for chunk in llm_chunks]
|
||||
|
||||
else:
|
||||
llm_docs = []
|
||||
doc_id_to_rank_map = {}
|
||||
reference_db_search_docs = None
|
||||
|
||||
# Cannot determine these without the LLM step or breaking out early
|
||||
partial_response = partial(
|
||||
create_new_chat_message,
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message=new_user_message,
|
||||
prompt_id=prompt_id,
|
||||
# message=,
|
||||
rephrased_query=rephrased_query,
|
||||
# token_count=,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
# error=,
|
||||
reference_docs=reference_db_search_docs,
|
||||
db_session=db_session,
|
||||
commit=True,
|
||||
)
|
||||
|
||||
# If no prompt is provided, this is interpreted as not wanting an AI Answer
|
||||
# Simply provide/save the retrieval results
|
||||
if final_msg.prompt is None:
|
||||
gen_ai_response_message = partial_response(
|
||||
message="",
|
||||
token_count=0,
|
||||
citations=None,
|
||||
error=None,
|
||||
)
|
||||
msg_detail_response = translate_db_message_to_chat_message_detail(
|
||||
gen_ai_response_message
|
||||
)
|
||||
|
||||
yield get_json_line(msg_detail_response.dict())
|
||||
|
||||
# Stop here after saving message details, the above still needs to be sent for the
|
||||
# message id to send the next follow-up message
|
||||
return
|
||||
|
||||
# LLM prompt building, response capturing, etc.
|
||||
response_packets = generate_ai_chat_response(
|
||||
query_message=final_msg,
|
||||
history=history_msgs,
|
||||
persona=persona,
|
||||
context_docs=llm_docs,
|
||||
doc_id_to_rank_map=doc_id_to_rank_map,
|
||||
llm=llm,
|
||||
llm_tokenizer_encode_func=llm_tokenizer_encode_func,
|
||||
all_doc_useful=reference_doc_ids is not None,
|
||||
)
|
||||
|
||||
# Capture outputs and errors
|
||||
llm_output = ""
|
||||
error: str | None = None
|
||||
citations: list[CitationInfo] = []
|
||||
for packet in response_packets:
|
||||
if isinstance(packet, DanswerAnswerPiece):
|
||||
token = packet.answer_piece
|
||||
if token:
|
||||
llm_output += token
|
||||
elif isinstance(packet, StreamingError):
|
||||
error = packet.error
|
||||
elif isinstance(packet, CitationInfo):
|
||||
citations.append(packet)
|
||||
continue
|
||||
|
||||
yield get_json_line(packet.dict())
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
|
||||
# Frontend will erase whatever answer and show this instead
|
||||
# This will be the issue 99% of the time
|
||||
error_packet = StreamingError(
|
||||
error="LLM failed to respond, have you set your API key?"
|
||||
)
|
||||
|
||||
yield get_json_line(error_packet.dict())
|
||||
return
|
||||
|
||||
# Post-LLM answer processing
|
||||
try:
|
||||
db_citations = None
|
||||
if reference_db_search_docs:
|
||||
db_citations = translate_citations(
|
||||
citations_list=citations,
|
||||
db_docs=reference_db_search_docs,
|
||||
)
|
||||
|
||||
# Saving Gen AI answer and responding with message info
|
||||
gen_ai_response_message = partial_response(
|
||||
message=llm_output,
|
||||
token_count=len(llm_tokenizer_encode_func(llm_output)),
|
||||
citations=db_citations,
|
||||
error=error,
|
||||
)
|
||||
|
||||
msg_detail_response = translate_db_message_to_chat_message_detail(
|
||||
gen_ai_response_message
|
||||
)
|
||||
|
||||
yield get_json_line(msg_detail_response.dict())
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
|
||||
# Frontend will erase whatever answer and show this instead
|
||||
error_packet = StreamingError(error="Failed to parse LLM output")
|
||||
|
||||
yield get_json_line(error_packet.dict())
|
||||
68
backend/danswer/chat/prompts.yaml
Normal file
68
backend/danswer/chat/prompts.yaml
Normal file
@@ -0,0 +1,68 @@
|
||||
prompts:
|
||||
# This id field can be left blank for other default prompts, however an id 0 prompt must exist
|
||||
# This is to act as a default
|
||||
# Careful setting specific IDs, this won't autoincrement the next ID value for postgres
|
||||
- id: 0
|
||||
name: "Answer-Question"
|
||||
description: "Answers user questions using retrieved context!"
|
||||
# System Prompt (as shown in UI)
|
||||
system: >
|
||||
You are a question answering system that is constantly learning and improving.
|
||||
|
||||
You can process and comprehend vast amounts of text and utilize this knowledge to provide
|
||||
grounded, accurate, and concise answers to diverse queries.
|
||||
|
||||
You always clearly communicate ANY UNCERTAINTY in your answer.
|
||||
# Task Prompt (as shown in UI)
|
||||
task: >
|
||||
Answer my query based on the documents provided.
|
||||
The documents may not all be relevant, ignore any documents that are not directly relevant
|
||||
to the most recent user query.
|
||||
|
||||
I have not read or seen any of the documents and do not want to read them.
|
||||
|
||||
If there are no relevant documents, refer to the chat history and existing knowledge.
|
||||
# Inject a statement at the end of system prompt to inform the LLM of the current date/time
|
||||
# Format looks like: "October 16, 2023 14:30"
|
||||
datetime_aware: true
|
||||
# Prompts the LLM to include citations in the for [1], [2] etc.
|
||||
# which get parsed to match the passed in sources
|
||||
include_citations: true
|
||||
|
||||
|
||||
- name: "Summarize"
|
||||
description: "Summarize relevant information from retrieved context!"
|
||||
system: >
|
||||
You are a text summarizing assistant that highlights the most important knowledge from the
|
||||
context provided, prioritizing the information that relates to the user query.
|
||||
|
||||
You ARE NOT creative and always stick to the provided documents.
|
||||
If there are no documents, refer to the conversation history.
|
||||
|
||||
IMPORTANT: YOU ONLY SUMMARIZE THE IMPORTANT INFORMATION FROM THE PROVIDED DOCUMENTS,
|
||||
NEVER USE YOUR OWN KNOWLEDGE.
|
||||
task: >
|
||||
Summarize the documents provided in relation to the query below.
|
||||
NEVER refer to the documents by number, I do not have them in the same order as you.
|
||||
Do not make up any facts, only use what is in the documents.
|
||||
datetime_aware: true
|
||||
include_citations: true
|
||||
|
||||
|
||||
- name: "Paraphrase"
|
||||
description: "Recites information from retrieved context! Least creative but most safe!"
|
||||
system: >
|
||||
Quote and cite relevant information from provided context based on the user query.
|
||||
|
||||
You only provide quotes that are EXACT substrings from provided documents!
|
||||
|
||||
If there are no documents provided,
|
||||
simply tell the user that there are no documents to reference.
|
||||
|
||||
You NEVER generate new text or phrases outside of the citation.
|
||||
DO NOT explain your responses, only provide the quotes and NOTHING ELSE.
|
||||
task: >
|
||||
Provide EXACT quotes from the provided documents above. Do not generate any new text that is not
|
||||
directly from the documents.
|
||||
datetime_aware: true
|
||||
include_citations: true
|
||||
@@ -1,10 +1,115 @@
|
||||
from uuid import UUID
|
||||
from typing import TypedDict
|
||||
|
||||
from danswer.direct_qa.interfaces import DanswerChatModelOut
|
||||
from pydantic import BaseModel
|
||||
|
||||
from danswer.prompts.chat_tools import DANSWER_TOOL_DESCRIPTION
|
||||
from danswer.prompts.chat_tools import DANSWER_TOOL_NAME
|
||||
from danswer.prompts.chat_tools import TOOL_FOLLOWUP
|
||||
from danswer.prompts.chat_tools import TOOL_LESS_FOLLOWUP
|
||||
from danswer.prompts.chat_tools import TOOL_LESS_PROMPT
|
||||
from danswer.prompts.chat_tools import TOOL_TEMPLATE
|
||||
from danswer.prompts.chat_tools import USER_INPUT
|
||||
|
||||
|
||||
class ToolInfo(TypedDict):
|
||||
name: str
|
||||
description: str
|
||||
|
||||
|
||||
class DanswerChatModelOut(BaseModel):
|
||||
model_raw: str
|
||||
action: str
|
||||
action_input: str
|
||||
|
||||
|
||||
def call_tool(
|
||||
model_actions: DanswerChatModelOut,
|
||||
user_id: UUID | None,
|
||||
) -> str:
|
||||
raise NotImplementedError("There are no additional tool integrations right now")
|
||||
|
||||
|
||||
def form_user_prompt_text(
|
||||
query: str,
|
||||
tool_text: str | None,
|
||||
hint_text: str | None,
|
||||
user_input_prompt: str = USER_INPUT,
|
||||
tool_less_prompt: str = TOOL_LESS_PROMPT,
|
||||
) -> str:
|
||||
user_prompt = tool_text or tool_less_prompt
|
||||
|
||||
user_prompt += user_input_prompt.format(user_input=query)
|
||||
|
||||
if hint_text:
|
||||
if user_prompt[-1] != "\n":
|
||||
user_prompt += "\n"
|
||||
user_prompt += "\nHint: " + hint_text
|
||||
|
||||
return user_prompt.strip()
|
||||
|
||||
|
||||
def form_tool_section_text(
|
||||
tools: list[ToolInfo] | None, retrieval_enabled: bool, template: str = TOOL_TEMPLATE
|
||||
) -> str | None:
|
||||
if not tools and not retrieval_enabled:
|
||||
return None
|
||||
|
||||
if retrieval_enabled and tools:
|
||||
tools.append(
|
||||
{"name": DANSWER_TOOL_NAME, "description": DANSWER_TOOL_DESCRIPTION}
|
||||
)
|
||||
|
||||
tools_intro = []
|
||||
if tools:
|
||||
num_tools = len(tools)
|
||||
for tool in tools:
|
||||
description_formatted = tool["description"].replace("\n", " ")
|
||||
tools_intro.append(f"> {tool['name']}: {description_formatted}")
|
||||
|
||||
prefix = "Must be one of " if num_tools > 1 else "Must be "
|
||||
|
||||
tools_intro_text = "\n".join(tools_intro)
|
||||
tool_names_text = prefix + ", ".join([tool["name"] for tool in tools])
|
||||
|
||||
else:
|
||||
return None
|
||||
|
||||
return template.format(
|
||||
tool_overviews=tools_intro_text, tool_names=tool_names_text
|
||||
).strip()
|
||||
|
||||
|
||||
def form_tool_followup_text(
|
||||
tool_output: str,
|
||||
query: str,
|
||||
hint_text: str | None,
|
||||
tool_followup_prompt: str = TOOL_FOLLOWUP,
|
||||
ignore_hint: bool = False,
|
||||
) -> str:
|
||||
# If multi-line query, it likely confuses the model more than helps
|
||||
if "\n" not in query:
|
||||
optional_reminder = f"\nAs a reminder, my query was: {query}\n"
|
||||
else:
|
||||
optional_reminder = ""
|
||||
|
||||
if not ignore_hint and hint_text:
|
||||
hint_text_spaced = f"\nHint: {hint_text}\n"
|
||||
else:
|
||||
hint_text_spaced = ""
|
||||
|
||||
return tool_followup_prompt.format(
|
||||
tool_output=tool_output,
|
||||
optional_reminder=optional_reminder,
|
||||
hint=hint_text_spaced,
|
||||
).strip()
|
||||
|
||||
|
||||
def form_tool_less_followup_text(
|
||||
tool_output: str,
|
||||
query: str,
|
||||
hint_text: str | None,
|
||||
tool_followup_prompt: str = TOOL_LESS_FOLLOWUP,
|
||||
) -> str:
|
||||
hint = f"Hint: {hint_text}" if hint_text else ""
|
||||
return tool_followup_prompt.format(
|
||||
context_str=tool_output, user_query=query, hint_text=hint
|
||||
).strip()
|
||||
|
||||
@@ -1,12 +1,18 @@
|
||||
import os
|
||||
|
||||
from danswer.configs.constants import AuthType
|
||||
from danswer.configs.constants import DocumentIndexType
|
||||
|
||||
|
||||
#####
|
||||
# App Configs
|
||||
#####
|
||||
APP_HOST = "0.0.0.0"
|
||||
APP_PORT = 8080
|
||||
# API_PREFIX is used to prepend a base path for all API routes
|
||||
# generally used if using a reverse proxy which doesn't support stripping the `/api`
|
||||
# prefix from requests directed towards the API server. In these cases, set this to `/api`
|
||||
APP_API_PREFIX = os.environ.get("API_PREFIX", "")
|
||||
|
||||
|
||||
#####
|
||||
@@ -14,37 +20,34 @@ APP_PORT = 8080
|
||||
#####
|
||||
BLURB_SIZE = 128 # Number Encoder Tokens included in the chunk blurb
|
||||
GENERATIVE_MODEL_ACCESS_CHECK_FREQ = 86400 # 1 day
|
||||
# DISABLE_GENERATIVE_AI will turn of the question answering part of Danswer. Use this
|
||||
# if you want to use Danswer as a search engine only and/or you are not comfortable sending
|
||||
# anything to OpenAI. TODO: update this message once we support Azure / open source generative models.
|
||||
DISABLE_GENERATIVE_AI = os.environ.get("DISABLE_GENERATIVE_AI", "").lower() == "true"
|
||||
|
||||
|
||||
#####
|
||||
# Web Configs
|
||||
#####
|
||||
# WEB_DOMAIN is used to set the redirect_uri when doing OAuth with Google
|
||||
# TODO: investigate if this can be done cleaner by overwriting the redirect_uri
|
||||
# on the frontend and just sending a dummy value (or completely generating the URL)
|
||||
# on the frontend
|
||||
WEB_DOMAIN = os.environ.get("WEB_DOMAIN", "http://localhost:3000")
|
||||
# WEB_DOMAIN is used to set the redirect_uri after login flows
|
||||
# NOTE: if you are having problems accessing the Danswer web UI locally (especially
|
||||
# on Windows, try setting this to `http://127.0.0.1:3000` instead and see if that
|
||||
# fixes it)
|
||||
WEB_DOMAIN = os.environ.get("WEB_DOMAIN") or "http://localhost:3000"
|
||||
|
||||
|
||||
#####
|
||||
# Auth Configs
|
||||
#####
|
||||
DISABLE_AUTH = os.environ.get("DISABLE_AUTH", "").lower() == "true"
|
||||
REQUIRE_EMAIL_VERIFICATION = (
|
||||
os.environ.get("REQUIRE_EMAIL_VERIFICATION", "").lower() == "true"
|
||||
AUTH_TYPE = AuthType((os.environ.get("AUTH_TYPE") or AuthType.DISABLED.value).lower())
|
||||
DISABLE_AUTH = AUTH_TYPE == AuthType.DISABLED
|
||||
|
||||
# Turn off mask if admin users should see full credentials for data connectors.
|
||||
MASK_CREDENTIAL_PREFIX = (
|
||||
os.environ.get("MASK_CREDENTIAL_PREFIX", "True").lower() != "false"
|
||||
)
|
||||
SMTP_SERVER = os.environ.get("SMTP_SERVER", "smtp.gmail.com")
|
||||
SMTP_PORT = int(os.environ.get("SMTP_PORT", "587"))
|
||||
SMTP_USER = os.environ.get("SMTP_USER", "your-email@gmail.com")
|
||||
SMTP_PASS = os.environ.get("SMTP_PASS", "your-gmail-password")
|
||||
|
||||
SECRET = os.environ.get("SECRET", "")
|
||||
SESSION_EXPIRE_TIME_SECONDS = int(
|
||||
os.environ.get("SESSION_EXPIRE_TIME_SECONDS", 86400)
|
||||
) # 1 day
|
||||
os.environ.get("SESSION_EXPIRE_TIME_SECONDS") or 86400 * 7
|
||||
) # 7 days
|
||||
|
||||
# set `VALID_EMAIL_DOMAINS` to a comma seperated list of domains in order to
|
||||
# restrict access to Danswer to only users with emails from those domains.
|
||||
@@ -60,26 +63,31 @@ VALID_EMAIL_DOMAINS = (
|
||||
if _VALID_EMAIL_DOMAINS_STR
|
||||
else []
|
||||
)
|
||||
|
||||
# OAuth Login Flow
|
||||
ENABLE_OAUTH = os.environ.get("ENABLE_OAUTH", "").lower() != "false"
|
||||
OAUTH_TYPE = os.environ.get("OAUTH_TYPE", "google").lower()
|
||||
OAUTH_CLIENT_ID = os.environ.get(
|
||||
"OAUTH_CLIENT_ID", os.environ.get("GOOGLE_OAUTH_CLIENT_ID", "")
|
||||
# Used for both Google OAuth2 and OIDC flows
|
||||
OAUTH_CLIENT_ID = (
|
||||
os.environ.get("OAUTH_CLIENT_ID", os.environ.get("GOOGLE_OAUTH_CLIENT_ID")) or ""
|
||||
)
|
||||
OAUTH_CLIENT_SECRET = os.environ.get(
|
||||
"OAUTH_CLIENT_SECRET", os.environ.get("GOOGLE_OAUTH_CLIENT_SECRET", "")
|
||||
OAUTH_CLIENT_SECRET = (
|
||||
os.environ.get("OAUTH_CLIENT_SECRET", os.environ.get("GOOGLE_OAUTH_CLIENT_SECRET"))
|
||||
or ""
|
||||
)
|
||||
OPENID_CONFIG_URL = os.environ.get("OPENID_CONFIG_URL", "")
|
||||
MASK_CREDENTIAL_PREFIX = (
|
||||
os.environ.get("MASK_CREDENTIAL_PREFIX", "True").lower() != "false"
|
||||
|
||||
# for basic auth
|
||||
REQUIRE_EMAIL_VERIFICATION = (
|
||||
os.environ.get("REQUIRE_EMAIL_VERIFICATION", "").lower() == "true"
|
||||
)
|
||||
SMTP_SERVER = os.environ.get("SMTP_SERVER") or "smtp.gmail.com"
|
||||
SMTP_PORT = int(os.environ.get("SMTP_PORT") or "587")
|
||||
SMTP_USER = os.environ.get("SMTP_USER", "your-email@gmail.com")
|
||||
SMTP_PASS = os.environ.get("SMTP_PASS", "your-gmail-password")
|
||||
EMAIL_FROM = os.environ.get("EMAIL_FROM") or SMTP_USER
|
||||
|
||||
|
||||
#####
|
||||
# DB Configs
|
||||
#####
|
||||
DOCUMENT_INDEX_NAME = "danswer_index" # Shared by vector/keyword indices
|
||||
DOCUMENT_INDEX_NAME = "danswer_index"
|
||||
# Vespa is now the default document index store for both keyword and vector
|
||||
DOCUMENT_INDEX_TYPE = os.environ.get(
|
||||
"DOCUMENT_INDEX_TYPE", DocumentIndexType.COMBINED.value
|
||||
@@ -91,21 +99,13 @@ VESPA_TENANT_PORT = os.environ.get("VESPA_TENANT_PORT") or "19071"
|
||||
VESPA_DEPLOYMENT_ZIP = (
|
||||
os.environ.get("VESPA_DEPLOYMENT_ZIP") or "/app/danswer/vespa-app.zip"
|
||||
)
|
||||
# Qdrant is Semantic Search Vector DB
|
||||
# Url / Key are used to connect to a remote Qdrant instance
|
||||
QDRANT_URL = os.environ.get("QDRANT_URL", "")
|
||||
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY", "")
|
||||
# Host / Port are used for connecting to local Qdrant instance
|
||||
QDRANT_HOST = os.environ.get("QDRANT_HOST") or "localhost"
|
||||
QDRANT_PORT = 6333
|
||||
# Typesense is the Keyword Search Engine
|
||||
TYPESENSE_HOST = os.environ.get("TYPESENSE_HOST") or "localhost"
|
||||
TYPESENSE_PORT = 8108
|
||||
TYPESENSE_API_KEY = os.environ.get("TYPESENSE_API_KEY", "")
|
||||
# Number of documents in a batch during indexing (further batching done by chunks before passing to bi-encoder)
|
||||
INDEX_BATCH_SIZE = 16
|
||||
try:
|
||||
INDEX_BATCH_SIZE = int(os.environ.get("INDEX_BATCH_SIZE", 16))
|
||||
except ValueError:
|
||||
INDEX_BATCH_SIZE = 16
|
||||
|
||||
# below are intended to match the env variables names used by the official postgres docker image
|
||||
# Below are intended to match the env variables names used by the official postgres docker image
|
||||
# https://hub.docker.com/_/postgres
|
||||
POSTGRES_USER = os.environ.get("POSTGRES_USER") or "postgres"
|
||||
POSTGRES_PASSWORD = os.environ.get("POSTGRES_PASSWORD") or "password"
|
||||
@@ -117,11 +117,21 @@ POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"
|
||||
#####
|
||||
# Connector Configs
|
||||
#####
|
||||
POLL_CONNECTOR_OFFSET = 30 # Minutes overlap between poll windows
|
||||
|
||||
# Some calls to get information on expert users are quite costly especially with rate limiting
|
||||
# Since experts are not used in the actual user experience, currently it is turned off
|
||||
# for some connectors
|
||||
ENABLE_EXPENSIVE_EXPERT_CALLS = False
|
||||
|
||||
GOOGLE_DRIVE_INCLUDE_SHARED = False
|
||||
GOOGLE_DRIVE_FOLLOW_SHORTCUTS = False
|
||||
GOOGLE_DRIVE_ONLY_ORG_PUBLIC = False
|
||||
|
||||
FILE_CONNECTOR_TMP_STORAGE_PATH = os.environ.get(
|
||||
"FILE_CONNECTOR_TMP_STORAGE_PATH", "/home/file_connector_storage"
|
||||
)
|
||||
|
||||
# TODO these should be available for frontend configuration, via advanced options expandable
|
||||
WEB_CONNECTOR_IGNORED_CLASSES = os.environ.get(
|
||||
"WEB_CONNECTOR_IGNORED_CLASSES", "sidebar,footer"
|
||||
@@ -138,60 +148,29 @@ NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP = (
|
||||
== "true"
|
||||
)
|
||||
|
||||
#####
|
||||
# Query Configs
|
||||
#####
|
||||
NUM_RETURNED_HITS = 50
|
||||
NUM_RERANKED_RESULTS = 15
|
||||
# We feed in document chunks until we reach this token limit.
|
||||
# Default is ~5 full chunks (max chunk size is 2000 chars), although some chunks
|
||||
# may be smaller which could result in passing in more total chunks
|
||||
NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL = int(
|
||||
os.environ.get("NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL") or (512 * 5)
|
||||
CONFLUENCE_CONNECTOR_LABELS_TO_SKIP = [
|
||||
ignored_tag
|
||||
for ignored_tag in os.environ.get("CONFLUENCE_CONNECTOR_LABELS_TO_SKIP", "").split(
|
||||
","
|
||||
)
|
||||
if ignored_tag
|
||||
]
|
||||
|
||||
GONG_CONNECTOR_START_TIME = os.environ.get("GONG_CONNECTOR_START_TIME")
|
||||
|
||||
GITHUB_CONNECTOR_BASE_URL = os.environ.get("GITHUB_CONNECTOR_BASE_URL") or None
|
||||
|
||||
DASK_JOB_CLIENT_ENABLED = (
|
||||
os.environ.get("DASK_JOB_CLIENT_ENABLED", "").lower() == "true"
|
||||
)
|
||||
NUM_DOCUMENT_TOKENS_FED_TO_CHAT = int(
|
||||
os.environ.get("NUM_DOCUMENT_TOKENS_FED_TO_CHAT") or (512 * 3)
|
||||
EXPERIMENTAL_CHECKPOINTING_ENABLED = (
|
||||
os.environ.get("EXPERIMENTAL_CHECKPOINTING_ENABLED", "").lower() == "true"
|
||||
)
|
||||
# 1 edit per 2 characters, currently unused due to fuzzy match being too slow
|
||||
QUOTE_ALLOWED_ERROR_PERCENT = 0.05
|
||||
QA_TIMEOUT = int(os.environ.get("QA_TIMEOUT") or "60") # 60 seconds
|
||||
# Include additional document/chunk metadata in prompt to GenerativeAI
|
||||
INCLUDE_METADATA = False
|
||||
HARD_DELETE_CHATS = os.environ.get("HARD_DELETE_CHATS", "True").lower() != "false"
|
||||
|
||||
|
||||
#####
|
||||
# Text Processing Configs
|
||||
# Indexing Configs
|
||||
#####
|
||||
CHUNK_SIZE = 512 # Tokens by embedding model
|
||||
CHUNK_OVERLAP = int(CHUNK_SIZE * 0.05) # 5% overlap
|
||||
# More accurate results at the expense of indexing speed and index size (stores additional 4 MINI_CHUNK vectors)
|
||||
ENABLE_MINI_CHUNK = os.environ.get("ENABLE_MINI_CHUNK", "").lower() == "true"
|
||||
# Finer grained chunking for more detail retention
|
||||
# Slightly larger since the sentence aware split is a max cutoff so most minichunks will be under MINI_CHUNK_SIZE
|
||||
# tokens. But we need it to be at least as big as 1/4th chunk size to avoid having a tiny mini-chunk at the end
|
||||
MINI_CHUNK_SIZE = 150
|
||||
|
||||
|
||||
#####
|
||||
# Encoder Model Endpoint Configs (Currently unused, running the models in memory)
|
||||
#####
|
||||
BI_ENCODER_HOST = "localhost"
|
||||
BI_ENCODER_PORT = 9000
|
||||
CROSS_ENCODER_HOST = "localhost"
|
||||
CROSS_ENCODER_PORT = 9000
|
||||
|
||||
|
||||
#####
|
||||
# Miscellaneous
|
||||
#####
|
||||
PERSONAS_YAML = "./danswer/chat/personas.yaml"
|
||||
DYNAMIC_CONFIG_STORE = os.environ.get(
|
||||
"DYNAMIC_CONFIG_STORE", "FileSystemBackedDynamicConfigStore"
|
||||
)
|
||||
DYNAMIC_CONFIG_DIR_PATH = os.environ.get("DYNAMIC_CONFIG_DIR_PATH", "/home/storage")
|
||||
# notset, debug, info, warning, error, or critical
|
||||
LOG_LEVEL = os.environ.get("LOG_LEVEL", "info")
|
||||
# NOTE: Currently only supported in the Confluence and Google Drive connectors +
|
||||
# only handles some failures (Confluence = handles API call failures, Google
|
||||
# Drive = handles failures pulling files / parsing them)
|
||||
@@ -203,32 +182,57 @@ CONTINUE_ON_CONNECTOR_FAILURE = os.environ.get(
|
||||
# fairly large amount of memory in order to increase substantially, since
|
||||
# each worker loads the embedding models into memory.
|
||||
NUM_INDEXING_WORKERS = int(os.environ.get("NUM_INDEXING_WORKERS") or 1)
|
||||
CHUNK_OVERLAP = 0
|
||||
# More accurate results at the expense of indexing speed and index size (stores additional 4 MINI_CHUNK vectors)
|
||||
ENABLE_MINI_CHUNK = os.environ.get("ENABLE_MINI_CHUNK", "").lower() == "true"
|
||||
# Finer grained chunking for more detail retention
|
||||
# Slightly larger since the sentence aware split is a max cutoff so most minichunks will be under MINI_CHUNK_SIZE
|
||||
# tokens. But we need it to be at least as big as 1/4th chunk size to avoid having a tiny mini-chunk at the end
|
||||
MINI_CHUNK_SIZE = 150
|
||||
# Timeout to wait for job's last update before killing it, in hours
|
||||
CLEANUP_INDEXING_JOBS_TIMEOUT = int(os.environ.get("CLEANUP_INDEXING_JOBS_TIMEOUT", 3))
|
||||
|
||||
|
||||
#####
|
||||
# Danswer Slack Bot Configs
|
||||
# Model Server Configs
|
||||
#####
|
||||
DANSWER_BOT_NUM_DOCS_TO_DISPLAY = int(
|
||||
os.environ.get("DANSWER_BOT_NUM_DOCS_TO_DISPLAY", "5")
|
||||
# If MODEL_SERVER_HOST is set, the NLP models required for Danswer are offloaded to the server via
|
||||
# requests. Be sure to include the scheme in the MODEL_SERVER_HOST value.
|
||||
MODEL_SERVER_HOST = os.environ.get("MODEL_SERVER_HOST") or None
|
||||
MODEL_SERVER_ALLOWED_HOST = os.environ.get("MODEL_SERVER_HOST") or "0.0.0.0"
|
||||
MODEL_SERVER_PORT = int(os.environ.get("MODEL_SERVER_PORT") or "9000")
|
||||
|
||||
# specify this env variable directly to have a different model server for the background
|
||||
# indexing job vs the api server so that background indexing does not effect query-time
|
||||
# performance
|
||||
INDEXING_MODEL_SERVER_HOST = (
|
||||
os.environ.get("INDEXING_MODEL_SERVER_HOST") or MODEL_SERVER_HOST
|
||||
)
|
||||
DANSWER_BOT_NUM_RETRIES = int(os.environ.get("DANSWER_BOT_NUM_RETRIES", "5"))
|
||||
DANSWER_BOT_ANSWER_GENERATION_TIMEOUT = int(
|
||||
os.environ.get("DANSWER_BOT_ANSWER_GENERATION_TIMEOUT", "90")
|
||||
|
||||
|
||||
#####
|
||||
# Miscellaneous
|
||||
#####
|
||||
DYNAMIC_CONFIG_STORE = os.environ.get(
|
||||
"DYNAMIC_CONFIG_STORE", "FileSystemBackedDynamicConfigStore"
|
||||
)
|
||||
DANSWER_BOT_DISPLAY_ERROR_MSGS = os.environ.get(
|
||||
"DANSWER_BOT_DISPLAY_ERROR_MSGS", ""
|
||||
).lower() not in [
|
||||
"false",
|
||||
"",
|
||||
]
|
||||
DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER = os.environ.get(
|
||||
"DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER", ""
|
||||
).lower() not in ["false", ""]
|
||||
# Add a second LLM call post Answer to verify if the Answer is valid
|
||||
# Throws out answers that don't directly or fully answer the user query
|
||||
ENABLE_DANSWERBOT_REFLEXION = (
|
||||
os.environ.get("ENABLE_DANSWERBOT_REFLEXION", "").lower() == "true"
|
||||
DYNAMIC_CONFIG_DIR_PATH = os.environ.get("DYNAMIC_CONFIG_DIR_PATH", "/home/storage")
|
||||
JOB_TIMEOUT = 60 * 60 * 6 # 6 hours default
|
||||
# used to allow the background indexing jobs to use a different embedding
|
||||
# model server than the API server
|
||||
CURRENT_PROCESS_IS_AN_INDEXING_JOB = (
|
||||
os.environ.get("CURRENT_PROCESS_IS_AN_INDEXING_JOB", "").lower() == "true"
|
||||
)
|
||||
ENABLE_SLACK_DOC_FEEDBACK = (
|
||||
os.environ.get("ENABLE_SLACK_DOC_FEEDBACK", "").lower() == "true"
|
||||
# Logs every model prompt and output, mostly used for development or exploration purposes
|
||||
LOG_ALL_MODEL_INTERACTIONS = (
|
||||
os.environ.get("LOG_ALL_MODEL_INTERACTIONS", "").lower() == "true"
|
||||
)
|
||||
# If set to `true` will enable additional logs about Vespa query performance
|
||||
# (time spent on finding the right docs + time spent fetching summaries from disk)
|
||||
LOG_VESPA_TIMING_INFORMATION = (
|
||||
os.environ.get("LOG_VESPA_TIMING_INFORMATION", "").lower() == "true"
|
||||
)
|
||||
# Anonymous usage telemetry
|
||||
DISABLE_TELEMETRY = os.environ.get("DISABLE_TELEMETRY", "").lower() == "true"
|
||||
# notset, debug, info, warning, error, or critical
|
||||
LOG_LEVEL = os.environ.get("LOG_LEVEL", "info")
|
||||
|
||||
69
backend/danswer/configs/chat_configs.py
Normal file
69
backend/danswer/configs/chat_configs.py
Normal file
@@ -0,0 +1,69 @@
|
||||
import os
|
||||
|
||||
|
||||
PROMPTS_YAML = "./danswer/chat/prompts.yaml"
|
||||
PERSONAS_YAML = "./danswer/chat/personas.yaml"
|
||||
|
||||
NUM_RETURNED_HITS = 50
|
||||
NUM_RERANKED_RESULTS = 15
|
||||
|
||||
# May be less depending on model
|
||||
MAX_CHUNKS_FED_TO_CHAT = float(os.environ.get("MAX_CHUNKS_FED_TO_CHAT") or 10.0)
|
||||
# For Chat, need to keep enough space for history and other prompt pieces
|
||||
# ~3k input, half for docs, half for chat history + prompts
|
||||
CHAT_TARGET_CHUNK_PERCENTAGE = 512 * 3 / 3072
|
||||
|
||||
# For selecting a different LLM question-answering prompt format
|
||||
# Valid values: default, cot, weak
|
||||
QA_PROMPT_OVERRIDE = os.environ.get("QA_PROMPT_OVERRIDE") or None
|
||||
# 1 / (1 + DOC_TIME_DECAY * doc-age-in-years), set to 0 to have no decay
|
||||
# Capped in Vespa at 0.5
|
||||
DOC_TIME_DECAY = float(
|
||||
os.environ.get("DOC_TIME_DECAY") or 0.5 # Hits limit at 2 years by default
|
||||
)
|
||||
BASE_RECENCY_DECAY = 0.5
|
||||
FAVOR_RECENT_DECAY_MULTIPLIER = 2.0
|
||||
# Currently this next one is not configurable via env
|
||||
DISABLE_LLM_QUERY_ANSWERABILITY = QA_PROMPT_OVERRIDE == "weak"
|
||||
DISABLE_LLM_FILTER_EXTRACTION = (
|
||||
os.environ.get("DISABLE_LLM_FILTER_EXTRACTION", "").lower() == "true"
|
||||
)
|
||||
# Whether the LLM should evaluate all of the document chunks passed in for usefulness
|
||||
# in relation to the user query
|
||||
DISABLE_LLM_CHUNK_FILTER = (
|
||||
os.environ.get("DISABLE_LLM_CHUNK_FILTER", "").lower() == "true"
|
||||
)
|
||||
# Whether the LLM should be used to decide if a search would help given the chat history
|
||||
DISABLE_LLM_CHOOSE_SEARCH = (
|
||||
os.environ.get("DISABLE_LLM_CHOOSE_SEARCH", "").lower() == "true"
|
||||
)
|
||||
# 1 edit per 20 characters, currently unused due to fuzzy match being too slow
|
||||
QUOTE_ALLOWED_ERROR_PERCENT = 0.05
|
||||
QA_TIMEOUT = int(os.environ.get("QA_TIMEOUT") or "60") # 60 seconds
|
||||
# Include additional document/chunk metadata in prompt to GenerativeAI
|
||||
INCLUDE_METADATA = False
|
||||
# Keyword Search Drop Stopwords
|
||||
# If user has changed the default model, would most likely be to use a multilingual
|
||||
# model, the stopwords are NLTK english stopwords so then we would want to not drop the keywords
|
||||
if os.environ.get("EDIT_KEYWORD_QUERY"):
|
||||
EDIT_KEYWORD_QUERY = os.environ.get("EDIT_KEYWORD_QUERY", "").lower() == "true"
|
||||
else:
|
||||
EDIT_KEYWORD_QUERY = not os.environ.get("DOCUMENT_ENCODER_MODEL")
|
||||
# Weighting factor between Vector and Keyword Search, 1 for completely vector search
|
||||
HYBRID_ALPHA = max(0, min(1, float(os.environ.get("HYBRID_ALPHA") or 0.62)))
|
||||
# Weighting factor between Title and Content of documents during search, 1 for completely
|
||||
# Title based. Default heavily favors Content because Title is also included at the top of
|
||||
# Content. This is to avoid cases where the Content is very relevant but it may not be clear
|
||||
# if the title is separated out. Title is most of a "boost" than a separate field.
|
||||
TITLE_CONTENT_RATIO = max(
|
||||
0, min(1, float(os.environ.get("TITLE_CONTENT_RATIO") or 0.20))
|
||||
)
|
||||
# A list of languages passed to the LLM to rephase the query
|
||||
# For example "English,French,Spanish", be sure to use the "," separator
|
||||
MULTILINGUAL_QUERY_EXPANSION = os.environ.get("MULTILINGUAL_QUERY_EXPANSION") or None
|
||||
|
||||
# Stops streaming answers back to the UI if this pattern is seen:
|
||||
STOP_STREAM_PAT = os.environ.get("STOP_STREAM_PAT") or None
|
||||
|
||||
# The backend logic for this being True isn't fully supported yet
|
||||
HARD_DELETE_CHATS = False
|
||||
@@ -8,12 +8,17 @@ SOURCE_TYPE = "source_type"
|
||||
SOURCE_LINKS = "source_links"
|
||||
SOURCE_LINK = "link"
|
||||
SEMANTIC_IDENTIFIER = "semantic_identifier"
|
||||
TITLE = "title"
|
||||
SKIP_TITLE_EMBEDDING = "skip_title"
|
||||
SECTION_CONTINUATION = "section_continuation"
|
||||
EMBEDDINGS = "embeddings"
|
||||
TITLE_EMBEDDING = "title_embedding"
|
||||
ALLOWED_USERS = "allowed_users"
|
||||
ACCESS_CONTROL_LIST = "access_control_list"
|
||||
DOCUMENT_SETS = "document_sets"
|
||||
TIME_FILTER = "time_filter"
|
||||
METADATA = "metadata"
|
||||
METADATA_LIST = "metadata_list"
|
||||
MATCH_HIGHLIGHTS = "match_highlights"
|
||||
# stored in the `metadata` of a chunk. Used to signify that this chunk should
|
||||
# not be used for QA. For example, Google Drive file types which can't be parsed
|
||||
@@ -24,30 +29,43 @@ PUBLIC_DOC_PAT = "PUBLIC"
|
||||
PUBLIC_DOCUMENT_SET = "__PUBLIC"
|
||||
QUOTE = "quote"
|
||||
BOOST = "boost"
|
||||
DOC_UPDATED_AT = "doc_updated_at" # Indexed as seconds since epoch
|
||||
PRIMARY_OWNERS = "primary_owners"
|
||||
SECONDARY_OWNERS = "secondary_owners"
|
||||
RECENCY_BIAS = "recency_bias"
|
||||
HIDDEN = "hidden"
|
||||
SCORE = "score"
|
||||
ID_SEPARATOR = ":;:"
|
||||
DEFAULT_BOOST = 0
|
||||
SESSION_KEY = "session"
|
||||
QUERY_EVENT_ID = "query_event_id"
|
||||
LLM_CHUNKS = "llm_chunks"
|
||||
|
||||
# Prompt building constants:
|
||||
GENERAL_SEP_PAT = "\n-----\n"
|
||||
CODE_BLOCK_PAT = "\n```\n{}\n```\n"
|
||||
DOC_SEP_PAT = "---NEW DOCUMENT---"
|
||||
DOC_CONTENT_START_PAT = "DOCUMENT CONTENTS:\n"
|
||||
QUESTION_PAT = "Query:"
|
||||
THOUGHT_PAT = "Thought:"
|
||||
ANSWER_PAT = "Answer:"
|
||||
FINAL_ANSWER_PAT = "Final Answer:"
|
||||
UNCERTAINTY_PAT = "?"
|
||||
QUOTE_PAT = "Quote:"
|
||||
QUOTES_PAT_PLURAL = "Quotes:"
|
||||
INVALID_PAT = "Invalid:"
|
||||
# For chunking/processing chunks
|
||||
TITLE_SEPARATOR = "\n\r\n"
|
||||
SECTION_SEPARATOR = "\n\n"
|
||||
# For combining attributes, doesn't have to be unique/perfect to work
|
||||
INDEX_SEPARATOR = "==="
|
||||
|
||||
|
||||
# Messages
|
||||
DISABLED_GEN_AI_MSG = (
|
||||
"Your System Admin has disabled the Generative AI functionalities of Danswer.\n"
|
||||
"Please contact them if you wish to have this enabled.\n"
|
||||
"You can still use Danswer as a search engine."
|
||||
)
|
||||
|
||||
|
||||
class DocumentSource(str, Enum):
|
||||
# Special case, document passed in via Danswer APIs without specifying a source type
|
||||
INGESTION_API = "ingestion_api"
|
||||
SLACK = "slack"
|
||||
WEB = "web"
|
||||
GOOGLE_DRIVE = "google_drive"
|
||||
GMAIL = "gmail"
|
||||
REQUESTTRACKER = "requesttracker"
|
||||
GITHUB = "github"
|
||||
GITLAB = "gitlab"
|
||||
GURU = "guru"
|
||||
BOOKSTACK = "bookstack"
|
||||
CONFLUENCE = "confluence"
|
||||
@@ -58,6 +76,13 @@ class DocumentSource(str, Enum):
|
||||
NOTION = "notion"
|
||||
ZULIP = "zulip"
|
||||
LINEAR = "linear"
|
||||
HUBSPOT = "hubspot"
|
||||
DOCUMENT360 = "document360"
|
||||
GONG = "gong"
|
||||
GOOGLE_SITES = "google_sites"
|
||||
ZENDESK = "zendesk"
|
||||
LOOPIO = "loopio"
|
||||
SHAREPOINT = "sharepoint"
|
||||
|
||||
|
||||
class DocumentIndexType(str, Enum):
|
||||
@@ -65,35 +90,12 @@ class DocumentIndexType(str, Enum):
|
||||
SPLIT = "split" # Typesense + Qdrant
|
||||
|
||||
|
||||
class DanswerGenAIModel(str, Enum):
|
||||
"""This represents the internal Danswer GenAI model which determines the class that is used
|
||||
to generate responses to the user query. Different models/services require different internal
|
||||
handling, this allows for modularity of implementation within Danswer"""
|
||||
|
||||
OPENAI = "openai-completion"
|
||||
OPENAI_CHAT = "openai-chat-completion"
|
||||
GPT4ALL = "gpt4all-completion"
|
||||
GPT4ALL_CHAT = "gpt4all-chat-completion"
|
||||
HUGGINGFACE = "huggingface-client-completion"
|
||||
HUGGINGFACE_CHAT = "huggingface-client-chat-completion"
|
||||
REQUEST = "request-completion"
|
||||
TRANSFORMERS = "transformers"
|
||||
|
||||
|
||||
class ModelHostType(str, Enum):
|
||||
"""For GenAI models interfaced via requests, different services have different
|
||||
expectations for what fields are included in the request"""
|
||||
|
||||
# https://huggingface.co/docs/api-inference/detailed_parameters#text-generation-task
|
||||
HUGGINGFACE = "huggingface" # HuggingFace test-generation Inference API
|
||||
# https://medium.com/@yuhongsun96/host-a-llama-2-api-on-gpu-for-free-a5311463c183
|
||||
COLAB_DEMO = "colab-demo"
|
||||
# TODO support for Azure, AWS, GCP GenAI model hosting
|
||||
|
||||
|
||||
class QAFeedbackType(str, Enum):
|
||||
LIKE = "like" # User likes the answer, used for metrics
|
||||
DISLIKE = "dislike" # User dislikes the answer, used for metrics
|
||||
class AuthType(str, Enum):
|
||||
DISABLED = "disabled"
|
||||
BASIC = "basic"
|
||||
GOOGLE_OAUTH = "google_oauth"
|
||||
OIDC = "oidc"
|
||||
SAML = "saml"
|
||||
|
||||
|
||||
class SearchFeedbackType(str, Enum):
|
||||
@@ -105,7 +107,7 @@ class SearchFeedbackType(str, Enum):
|
||||
|
||||
class MessageType(str, Enum):
|
||||
# Using OpenAI standards, Langchain equivalent shown in comment
|
||||
# System message is always constructed on the fly, not saved
|
||||
SYSTEM = "system" # SystemMessage
|
||||
USER = "user" # HumanMessage
|
||||
ASSISTANT = "assistant" # AIMessage
|
||||
DANSWER = "danswer" # FunctionMessage
|
||||
|
||||
59
backend/danswer/configs/danswerbot_configs.py
Normal file
59
backend/danswer/configs/danswerbot_configs.py
Normal file
@@ -0,0 +1,59 @@
|
||||
import os
|
||||
|
||||
#####
|
||||
# Danswer Slack Bot Configs
|
||||
#####
|
||||
DANSWER_BOT_NUM_RETRIES = int(os.environ.get("DANSWER_BOT_NUM_RETRIES", "5"))
|
||||
DANSWER_BOT_ANSWER_GENERATION_TIMEOUT = int(
|
||||
os.environ.get("DANSWER_BOT_ANSWER_GENERATION_TIMEOUT", "90")
|
||||
)
|
||||
# How much of the available input context can be used for thread context
|
||||
DANSWER_BOT_TARGET_CHUNK_PERCENTAGE = 512 * 2 / 3072
|
||||
# Number of docs to display in "Reference Documents"
|
||||
DANSWER_BOT_NUM_DOCS_TO_DISPLAY = int(
|
||||
os.environ.get("DANSWER_BOT_NUM_DOCS_TO_DISPLAY", "5")
|
||||
)
|
||||
# If the LLM fails to answer, Danswer can still show the "Reference Documents"
|
||||
DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER = os.environ.get(
|
||||
"DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER", ""
|
||||
).lower() not in ["false", ""]
|
||||
# When Danswer is considering a message, what emoji does it react with
|
||||
DANSWER_REACT_EMOJI = os.environ.get("DANSWER_REACT_EMOJI") or "eyes"
|
||||
# When User needs more help, what should the emoji be
|
||||
DANSWER_FOLLOWUP_EMOJI = os.environ.get("DANSWER_FOLLOWUP_EMOJI") or "sos"
|
||||
# Should DanswerBot send an apology message if it's not able to find an answer
|
||||
# That way the user isn't confused as to why DanswerBot reacted but then said nothing
|
||||
# Off by default to be less intrusive (don't want to give a notif that just says we couldnt help)
|
||||
NOTIFY_SLACKBOT_NO_ANSWER = (
|
||||
os.environ.get("NOTIFY_SLACKBOT_NO_ANSWER", "").lower() == "true"
|
||||
)
|
||||
# Mostly for debugging purposes but it's for explaining what went wrong
|
||||
# if DanswerBot couldn't find an answer
|
||||
DANSWER_BOT_DISPLAY_ERROR_MSGS = os.environ.get(
|
||||
"DANSWER_BOT_DISPLAY_ERROR_MSGS", ""
|
||||
).lower() not in [
|
||||
"false",
|
||||
"",
|
||||
]
|
||||
# Default is only respond in channels that are included by a slack config set in the UI
|
||||
DANSWER_BOT_RESPOND_EVERY_CHANNEL = (
|
||||
os.environ.get("DANSWER_BOT_RESPOND_EVERY_CHANNEL", "").lower() == "true"
|
||||
)
|
||||
# Auto detect query options like time cutoff or heavily favor recently updated docs
|
||||
DISABLE_DANSWER_BOT_FILTER_DETECT = (
|
||||
os.environ.get("DISABLE_DANSWER_BOT_FILTER_DETECT", "").lower() == "true"
|
||||
)
|
||||
# Add a second LLM call post Answer to verify if the Answer is valid
|
||||
# Throws out answers that don't directly or fully answer the user query
|
||||
# This is the default for all DanswerBot channels unless the channel is configured individually
|
||||
# Set/unset by "Hide Non Answers"
|
||||
ENABLE_DANSWERBOT_REFLEXION = (
|
||||
os.environ.get("ENABLE_DANSWERBOT_REFLEXION", "").lower() == "true"
|
||||
)
|
||||
# Currently not support chain of thought, probably will add back later
|
||||
DANSWER_BOT_DISABLE_COT = True
|
||||
|
||||
# Maximum Questions Per Minute, Default Uncapped
|
||||
DANSWER_BOT_MAX_QPM = int(os.environ.get("DANSWER_BOT_MAX_QPM") or 0) or None
|
||||
# Maximum time to wait when a question is queued
|
||||
DANSWER_BOT_MAX_WAIT_TIME = int(os.environ.get("DANSWER_BOT_MAX_WAIT_TIME") or 180)
|
||||
@@ -1,28 +1,30 @@
|
||||
import os
|
||||
|
||||
from danswer.configs.constants import DanswerGenAIModel
|
||||
from danswer.configs.constants import ModelHostType
|
||||
|
||||
|
||||
#####
|
||||
# Embedding/Reranking Model Configs
|
||||
#####
|
||||
CHUNK_SIZE = 512
|
||||
# Important considerations when choosing models
|
||||
# Max tokens count needs to be high considering use case (at least 512)
|
||||
# Models used must be MIT or Apache license
|
||||
# Inference/Indexing speed
|
||||
|
||||
# https://huggingface.co/DOCUMENT_ENCODER_MODEL
|
||||
# The useable models configured as below must be SentenceTransformer compatible
|
||||
DOCUMENT_ENCODER_MODEL = (
|
||||
os.environ.get("DOCUMENT_ENCODER_MODEL") or "thenlper/gte-small"
|
||||
# This is not a good model anymore, but this default needs to be kept for not breaking existing
|
||||
# deployments, will eventually be retired/swapped for a different default model
|
||||
os.environ.get("DOCUMENT_ENCODER_MODEL")
|
||||
or "thenlper/gte-small"
|
||||
)
|
||||
# If the below is changed, Vespa deployment must also be changed
|
||||
DOC_EMBEDDING_DIM = 384
|
||||
DOC_EMBEDDING_DIM = int(os.environ.get("DOC_EMBEDDING_DIM") or 384)
|
||||
# Model should be chosen with 512 context size, ideally don't change this
|
||||
DOC_EMBEDDING_CONTEXT_SIZE = 512
|
||||
NORMALIZE_EMBEDDINGS = (os.environ.get("SKIP_RERANKING") or "False").lower() == "true"
|
||||
NORMALIZE_EMBEDDINGS = (
|
||||
os.environ.get("NORMALIZE_EMBEDDINGS") or "False"
|
||||
).lower() == "true"
|
||||
# These are only used if reranking is turned off, to normalize the direct retrieval scores for display
|
||||
# Currently unused
|
||||
SIM_SCORE_RANGE_LOW = float(os.environ.get("SIM_SCORE_RANGE_LOW") or 0.0)
|
||||
SIM_SCORE_RANGE_HIGH = float(os.environ.get("SIM_SCORE_RANGE_HIGH") or 1.0)
|
||||
# Certain models like e5, BGE, etc use a prefix for asymmetric retrievals (query generally shorter than docs)
|
||||
@@ -30,80 +32,82 @@ ASYM_QUERY_PREFIX = os.environ.get("ASYM_QUERY_PREFIX", "")
|
||||
ASYM_PASSAGE_PREFIX = os.environ.get("ASYM_PASSAGE_PREFIX", "")
|
||||
# Purely an optimization, memory limitation consideration
|
||||
BATCH_SIZE_ENCODE_CHUNKS = 8
|
||||
# This controls the minimum number of pytorch "threads" to allocate to the embedding
|
||||
# model. If torch finds more threads on its own, this value is not used.
|
||||
MIN_THREADS_ML_MODELS = int(os.environ.get("MIN_THREADS_ML_MODELS") or 1)
|
||||
|
||||
|
||||
# Cross Encoder Settings
|
||||
SKIP_RERANKING = os.environ.get("SKIP_RERANKING", "").lower() == "true"
|
||||
ENABLE_RERANKING_ASYNC_FLOW = (
|
||||
os.environ.get("ENABLE_RERANKING_ASYNC_FLOW", "").lower() == "true"
|
||||
)
|
||||
ENABLE_RERANKING_REAL_TIME_FLOW = (
|
||||
os.environ.get("ENABLE_RERANKING_REAL_TIME_FLOW", "").lower() == "true"
|
||||
)
|
||||
# https://www.sbert.net/docs/pretrained-models/ce-msmarco.html
|
||||
CROSS_ENCODER_MODEL_ENSEMBLE = [
|
||||
"cross-encoder/ms-marco-MiniLM-L-4-v2",
|
||||
"cross-encoder/ms-marco-TinyBERT-L-2-v2",
|
||||
]
|
||||
# For score normalizing purposes, only way is to know the expected ranges
|
||||
CROSS_ENCODER_RANGE_MAX = 12
|
||||
CROSS_ENCODER_RANGE_MIN = -12
|
||||
CROSS_EMBED_CONTEXT_SIZE = 512
|
||||
|
||||
|
||||
# Better to keep it loose, surfacing more results better than missing results
|
||||
# Currently unused by Vespa
|
||||
SEARCH_DISTANCE_CUTOFF = 0.1 # Cosine similarity (currently), range of -1 to 1 with -1 being completely opposite
|
||||
# Unused currently, can't be used with the current default encoder model due to its output range
|
||||
SEARCH_DISTANCE_CUTOFF = 0
|
||||
|
||||
# Intent model max context size
|
||||
QUERY_MAX_CONTEXT_SIZE = 256
|
||||
|
||||
# Danswer custom Deep Learning Models
|
||||
INTENT_MODEL_VERSION = "danswer/intent-model"
|
||||
|
||||
|
||||
#####
|
||||
# Generative AI Model Configs
|
||||
#####
|
||||
# Other models should work as well, check the library/API compatibility.
|
||||
# But these are the models that have been verified to work with the existing prompts.
|
||||
# Using a different model may require some prompt tuning. See qa_prompts.py
|
||||
VERIFIED_MODELS = {
|
||||
DanswerGenAIModel.OPENAI: ["text-davinci-003"],
|
||||
DanswerGenAIModel.OPENAI_CHAT: ["gpt-3.5-turbo", "gpt-4"],
|
||||
DanswerGenAIModel.GPT4ALL: ["ggml-model-gpt4all-falcon-q4_0.bin"],
|
||||
DanswerGenAIModel.GPT4ALL_CHAT: ["ggml-model-gpt4all-falcon-q4_0.bin"],
|
||||
# The "chat" model below is actually "instruction finetuned" and does not support conversational
|
||||
DanswerGenAIModel.HUGGINGFACE.value: ["meta-llama/Llama-2-70b-chat-hf"],
|
||||
DanswerGenAIModel.HUGGINGFACE_CHAT.value: ["meta-llama/Llama-2-70b-hf"],
|
||||
# Created by Deepset.ai
|
||||
# https://huggingface.co/deepset/deberta-v3-large-squad2
|
||||
# Model provided with no modifications
|
||||
DanswerGenAIModel.TRANSFORMERS.value: ["deepset/deberta-v3-large-squad2"],
|
||||
}
|
||||
|
||||
# Sets the internal Danswer model class to use
|
||||
INTERNAL_MODEL_VERSION = os.environ.get(
|
||||
"INTERNAL_MODEL_VERSION", DanswerGenAIModel.OPENAI_CHAT.value
|
||||
# If changing GEN_AI_MODEL_PROVIDER or GEN_AI_MODEL_VERSION from the default,
|
||||
# be sure to use one that is LiteLLM compatible:
|
||||
# https://litellm.vercel.app/docs/providers/azure#completion---using-env-variables
|
||||
# The provider is the prefix before / in the model argument
|
||||
|
||||
# Additionally Danswer supports GPT4All and custom request library based models
|
||||
# Set GEN_AI_MODEL_PROVIDER to "custom" to use the custom requests approach
|
||||
# Set GEN_AI_MODEL_PROVIDER to "gpt4all" to use gpt4all models running locally
|
||||
GEN_AI_MODEL_PROVIDER = os.environ.get("GEN_AI_MODEL_PROVIDER") or "openai"
|
||||
# If using Azure, it's the engine name, for example: Danswer
|
||||
GEN_AI_MODEL_VERSION = os.environ.get("GEN_AI_MODEL_VERSION") or "gpt-3.5-turbo-0125"
|
||||
# For secondary flows like extracting filters or deciding if a chunk is useful, we don't need
|
||||
# as powerful of a model as say GPT-4 so we can use an alternative that is faster and cheaper
|
||||
FAST_GEN_AI_MODEL_VERSION = (
|
||||
os.environ.get("FAST_GEN_AI_MODEL_VERSION") or GEN_AI_MODEL_VERSION
|
||||
)
|
||||
|
||||
# If the Generative AI model requires an API key for access, otherwise can leave blank
|
||||
GEN_AI_API_KEY = os.environ.get("GEN_AI_API_KEY", "")
|
||||
|
||||
# If using GPT4All, HuggingFace Inference API, or OpenAI - specify the model version
|
||||
GEN_AI_MODEL_VERSION = os.environ.get(
|
||||
"GEN_AI_MODEL_VERSION",
|
||||
VERIFIED_MODELS.get(DanswerGenAIModel(INTERNAL_MODEL_VERSION), [""])[0],
|
||||
GEN_AI_API_KEY = (
|
||||
os.environ.get("GEN_AI_API_KEY", os.environ.get("OPENAI_API_KEY")) or None
|
||||
)
|
||||
|
||||
# If the Generative Model is hosted to accept requests (DanswerGenAIModel.REQUEST) then
|
||||
# set the two below to specify
|
||||
# - Where to hit the endpoint
|
||||
# - How should the request be formed
|
||||
GEN_AI_ENDPOINT = os.environ.get("GEN_AI_ENDPOINT", "")
|
||||
GEN_AI_HOST_TYPE = os.environ.get("GEN_AI_HOST_TYPE", ModelHostType.HUGGINGFACE.value)
|
||||
|
||||
# API Base, such as (for Azure): https://danswer.openai.azure.com/
|
||||
GEN_AI_API_ENDPOINT = os.environ.get("GEN_AI_API_ENDPOINT") or None
|
||||
# API Version, such as (for Azure): 2023-09-15-preview
|
||||
GEN_AI_API_VERSION = os.environ.get("GEN_AI_API_VERSION") or None
|
||||
# LiteLLM custom_llm_provider
|
||||
GEN_AI_LLM_PROVIDER_TYPE = os.environ.get("GEN_AI_LLM_PROVIDER_TYPE") or None
|
||||
# Override the auto-detection of LLM max context length
|
||||
GEN_AI_MAX_TOKENS = int(os.environ.get("GEN_AI_MAX_TOKENS") or 0) or None
|
||||
# Set this to be enough for an answer + quotes. Also used for Chat
|
||||
GEN_AI_MAX_OUTPUT_TOKENS = int(os.environ.get("GEN_AI_MAX_OUTPUT_TOKENS") or 1024)
|
||||
# This next restriction is only used for chat ATM, used to expire old messages as needed
|
||||
GEN_AI_MAX_INPUT_TOKENS = int(os.environ.get("GEN_AI_MAX_INPUT_TOKENS") or 3000)
|
||||
# Number of tokens from chat history to include at maximum
|
||||
# 3000 should be enough context regardless of use, no need to include as much as possible
|
||||
# as this drives up the cost unnecessarily
|
||||
GEN_AI_HISTORY_CUTOFF = 3000
|
||||
# This is used when computing how much context space is available for documents
|
||||
# ahead of time in order to let the user know if they can "select" more documents
|
||||
# It represents a maximum "expected" number of input tokens from the latest user
|
||||
# message. At query time, we don't actually enforce this - we will only throw an
|
||||
# error if the total # of tokens exceeds the max input tokens.
|
||||
GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS = 512
|
||||
GEN_AI_TEMPERATURE = float(os.environ.get("GEN_AI_TEMPERATURE") or 0)
|
||||
|
||||
# Danswer custom Deep Learning Models
|
||||
INTENT_MODEL_VERSION = "danswer/intent-model"
|
||||
|
||||
#####
|
||||
# OpenAI Azure
|
||||
#####
|
||||
API_BASE_OPENAI = os.environ.get("API_BASE_OPENAI", "")
|
||||
API_TYPE_OPENAI = os.environ.get("API_TYPE_OPENAI", "").lower()
|
||||
API_VERSION_OPENAI = os.environ.get("API_VERSION_OPENAI", "")
|
||||
# Deployment ID used interchangeably with "engine" parameter
|
||||
AZURE_DEPLOYMENT_ID = os.environ.get("AZURE_DEPLOYMENT_ID", "")
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
<!-- DANSWER_METADATA={"link": "https://github.com/danswer-ai/danswer/blob/main/backend/danswer/connectors/README.md"} -->
|
||||
|
||||
# Writing a new Danswer Connector
|
||||
This README covers how to contribute a new Connector for Danswer. It includes an overview of the design, interfaces,
|
||||
and required changes.
|
||||
@@ -61,9 +63,9 @@ if __name__ == "__main__":
|
||||
### Additional Required Changes:
|
||||
#### Backend Changes
|
||||
- Add a new type to
|
||||
[DocumentSource](https://github.com/danswer-ai/danswer/blob/main/backend/danswer/configs/constants.py#L20)
|
||||
[DocumentSource](https://github.com/danswer-ai/danswer/blob/main/backend/danswer/configs/constants.py)
|
||||
- Add a mapping from DocumentSource (and optionally connector type) to the right connector class
|
||||
[here](https://github.com/danswer-ai/danswer/blob/main/backend/danswer/connectors/factory.py#L32)
|
||||
[here](https://github.com/danswer-ai/danswer/blob/main/backend/danswer/connectors/factory.py#L33)
|
||||
|
||||
#### Frontend Changes
|
||||
- Create the new connector directory and admin page under `danswer/web/src/app/admin/connectors/`
|
||||
|
||||
@@ -7,6 +7,8 @@ from typing import Any
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.bookstack.client import BookStackApiClient
|
||||
from danswer.connectors.cross_connector_utils.html_utils import parse_html_page_basic
|
||||
from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
@@ -14,7 +16,6 @@ from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from danswer.connectors.models import ConnectorMissingCredentialError
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.utils.text_processing import parse_html_page_basic
|
||||
|
||||
|
||||
class BookstackConnector(LoadConnector, PollConnector):
|
||||
@@ -72,13 +73,21 @@ class BookstackConnector(LoadConnector, PollConnector):
|
||||
bookstack_client: BookStackApiClient, book: dict[str, Any]
|
||||
) -> Document:
|
||||
url = bookstack_client.build_app_url("/books/" + str(book.get("slug")))
|
||||
title = str(book.get("name", ""))
|
||||
text = book.get("name", "") + "\n" + book.get("description", "")
|
||||
updated_at_str = (
|
||||
str(book.get("updated_at")) if book.get("updated_at") is not None else None
|
||||
)
|
||||
return Document(
|
||||
id="book:" + str(book.get("id")),
|
||||
id="book__" + str(book.get("id")),
|
||||
sections=[Section(link=url, text=text)],
|
||||
source=DocumentSource.BOOKSTACK,
|
||||
semantic_identifier="Book: " + str(book.get("name")),
|
||||
metadata={"type": "book", "updated_at": str(book.get("updated_at"))},
|
||||
semantic_identifier="Book: " + title,
|
||||
title=title,
|
||||
doc_updated_at=time_str_to_utc(updated_at_str)
|
||||
if updated_at_str is not None
|
||||
else None,
|
||||
metadata={"type": "book"},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -91,13 +100,23 @@ class BookstackConnector(LoadConnector, PollConnector):
|
||||
+ "/chapter/"
|
||||
+ str(chapter.get("slug"))
|
||||
)
|
||||
title = str(chapter.get("name", ""))
|
||||
text = chapter.get("name", "") + "\n" + chapter.get("description", "")
|
||||
updated_at_str = (
|
||||
str(chapter.get("updated_at"))
|
||||
if chapter.get("updated_at") is not None
|
||||
else None
|
||||
)
|
||||
return Document(
|
||||
id="chapter:" + str(chapter.get("id")),
|
||||
id="chapter__" + str(chapter.get("id")),
|
||||
sections=[Section(link=url, text=text)],
|
||||
source=DocumentSource.BOOKSTACK,
|
||||
semantic_identifier="Chapter: " + str(chapter.get("name")),
|
||||
metadata={"type": "chapter", "updated_at": str(chapter.get("updated_at"))},
|
||||
semantic_identifier="Chapter: " + title,
|
||||
title=title,
|
||||
doc_updated_at=time_str_to_utc(updated_at_str)
|
||||
if updated_at_str is not None
|
||||
else None,
|
||||
metadata={"type": "chapter"},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -105,13 +124,23 @@ class BookstackConnector(LoadConnector, PollConnector):
|
||||
bookstack_client: BookStackApiClient, shelf: dict[str, Any]
|
||||
) -> Document:
|
||||
url = bookstack_client.build_app_url("/shelves/" + str(shelf.get("slug")))
|
||||
title = str(shelf.get("name", ""))
|
||||
text = shelf.get("name", "") + "\n" + shelf.get("description", "")
|
||||
updated_at_str = (
|
||||
str(shelf.get("updated_at"))
|
||||
if shelf.get("updated_at") is not None
|
||||
else None
|
||||
)
|
||||
return Document(
|
||||
id="shelf:" + str(shelf.get("id")),
|
||||
sections=[Section(link=url, text=text)],
|
||||
source=DocumentSource.BOOKSTACK,
|
||||
semantic_identifier="Shelf: " + str(shelf.get("name")),
|
||||
metadata={"type": "shelf", "updated_at": shelf.get("updated_at")},
|
||||
semantic_identifier="Shelf: " + title,
|
||||
title=title,
|
||||
doc_updated_at=time_str_to_utc(updated_at_str)
|
||||
if updated_at_str is not None
|
||||
else None,
|
||||
metadata={"type": "shelf"},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -119,7 +148,7 @@ class BookstackConnector(LoadConnector, PollConnector):
|
||||
bookstack_client: BookStackApiClient, page: dict[str, Any]
|
||||
) -> Document:
|
||||
page_id = str(page.get("id"))
|
||||
page_name = str(page.get("name"))
|
||||
title = str(page.get("name", ""))
|
||||
page_data = bookstack_client.get("/pages/" + page_id, {})
|
||||
url = bookstack_client.build_app_url(
|
||||
"/books/"
|
||||
@@ -127,17 +156,24 @@ class BookstackConnector(LoadConnector, PollConnector):
|
||||
+ "/page/"
|
||||
+ str(page_data.get("slug"))
|
||||
)
|
||||
page_html = (
|
||||
"<h1>" + html.escape(page_name) + "</h1>" + str(page_data.get("html"))
|
||||
)
|
||||
page_html = "<h1>" + html.escape(title) + "</h1>" + str(page_data.get("html"))
|
||||
text = parse_html_page_basic(page_html)
|
||||
updated_at_str = (
|
||||
str(page_data.get("updated_at"))
|
||||
if page_data.get("updated_at") is not None
|
||||
else None
|
||||
)
|
||||
time.sleep(0.1)
|
||||
return Document(
|
||||
id="page:" + page_id,
|
||||
sections=[Section(link=url, text=text)],
|
||||
source=DocumentSource.BOOKSTACK,
|
||||
semantic_identifier="Page: " + str(page_name),
|
||||
metadata={"type": "page", "updated_at": page_data.get("updated_at")},
|
||||
semantic_identifier="Page: " + str(title),
|
||||
title=str(title),
|
||||
doc_updated_at=time_str_to_utc(updated_at_str)
|
||||
if updated_at_str is not None
|
||||
else None,
|
||||
metadata={"type": "page"},
|
||||
)
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
|
||||
@@ -1,27 +1,30 @@
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Collection
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from functools import lru_cache
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import bs4
|
||||
from atlassian import Confluence # type:ignore
|
||||
from requests import HTTPError
|
||||
|
||||
from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_LABELS_TO_SKIP
|
||||
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.cross_connector_utils.html_utils import format_document_soup
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from danswer.connectors.models import BasicExpertInfo
|
||||
from danswer.connectors.models import ConnectorMissingCredentialError
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.text_processing import parse_html_page_basic
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -31,17 +34,12 @@ logger = setup_logger()
|
||||
# 3. Segment into Sections for more accurate linking, can split by headers but make sure no text/ordering is lost
|
||||
|
||||
|
||||
def extract_confluence_keys_from_url(wiki_url: str) -> tuple[str, str]:
|
||||
def _extract_confluence_keys_from_cloud_url(wiki_url: str) -> tuple[str, str]:
|
||||
"""Sample
|
||||
https://danswer.atlassian.net/wiki/spaces/1234abcd/overview
|
||||
wiki_base is danswer.atlassian.net/wiki
|
||||
wiki_base is https://danswer.atlassian.net/wiki
|
||||
space is 1234abcd
|
||||
"""
|
||||
if ".atlassian.net/wiki/spaces/" not in wiki_url:
|
||||
raise ValueError(
|
||||
"Not a valid Confluence Wiki Link, unable to extract wiki base and space names"
|
||||
)
|
||||
|
||||
parsed_url = urlparse(wiki_url)
|
||||
wiki_base = (
|
||||
parsed_url.scheme
|
||||
@@ -53,6 +51,89 @@ def extract_confluence_keys_from_url(wiki_url: str) -> tuple[str, str]:
|
||||
return wiki_base, space
|
||||
|
||||
|
||||
def _extract_confluence_keys_from_datacenter_url(wiki_url: str) -> tuple[str, str]:
|
||||
"""Sample
|
||||
https://danswer.ai/confluence/display/1234abcd/overview
|
||||
wiki_base is https://danswer.ai/confluence
|
||||
space is 1234abcd
|
||||
"""
|
||||
# /display/ is always right before the space and at the end of the base url
|
||||
DISPLAY = "/display/"
|
||||
|
||||
parsed_url = urlparse(wiki_url)
|
||||
wiki_base = (
|
||||
parsed_url.scheme
|
||||
+ "://"
|
||||
+ parsed_url.netloc
|
||||
+ parsed_url.path.split(DISPLAY)[0]
|
||||
)
|
||||
space = DISPLAY.join(parsed_url.path.split(DISPLAY)[1:]).split("/")[0]
|
||||
return wiki_base, space
|
||||
|
||||
|
||||
def extract_confluence_keys_from_url(wiki_url: str) -> tuple[str, str, bool]:
|
||||
is_confluence_cloud = ".atlassian.net/wiki/spaces/" in wiki_url
|
||||
|
||||
try:
|
||||
if is_confluence_cloud:
|
||||
wiki_base, space = _extract_confluence_keys_from_cloud_url(wiki_url)
|
||||
else:
|
||||
wiki_base, space = _extract_confluence_keys_from_datacenter_url(wiki_url)
|
||||
except Exception as e:
|
||||
error_msg = f"Not a valid Confluence Wiki Link, unable to extract wiki base and space names. Exception: {e}"
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
return wiki_base, space, is_confluence_cloud
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def _get_user(user_id: str, confluence_client: Confluence) -> str:
|
||||
"""Get Confluence Display Name based on the account-id or userkey value
|
||||
|
||||
Args:
|
||||
user_id (str): The user id (i.e: the account-id or userkey)
|
||||
confluence_client (Confluence): The Confluence Client
|
||||
|
||||
Returns:
|
||||
str: The User Display Name. 'Unknown User' if the user is deactivated or not found
|
||||
"""
|
||||
user_not_found = "Unknown User"
|
||||
|
||||
try:
|
||||
return confluence_client.get_user_details_by_accountid(user_id).get(
|
||||
"displayName", user_not_found
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Unable to get the User Display Name with the id: '{user_id}' - {e}"
|
||||
)
|
||||
return user_not_found
|
||||
|
||||
|
||||
def parse_html_page(text: str, confluence_client: Confluence) -> str:
|
||||
"""Parse a Confluence html page and replace the 'user Id' by the real
|
||||
User Display Name
|
||||
|
||||
Args:
|
||||
text (str): The page content
|
||||
confluence_client (Confluence): Confluence client
|
||||
|
||||
Returns:
|
||||
str: loaded and formated Confluence page
|
||||
"""
|
||||
soup = bs4.BeautifulSoup(text, "html.parser")
|
||||
for user in soup.findAll("ri:user"):
|
||||
user_id = (
|
||||
user.attrs["ri:account-id"]
|
||||
if "ri:account-id" in user.attrs
|
||||
else user.attrs["ri:userkey"]
|
||||
)
|
||||
# Include @ sign for tagging, more clear for LLM
|
||||
user.replaceWith("@" + _get_user(user_id, confluence_client))
|
||||
return format_document_soup(soup)
|
||||
|
||||
|
||||
def _comment_dfs(
|
||||
comments_str: str,
|
||||
comment_pages: Collection[dict[str, Any]],
|
||||
@@ -60,7 +141,9 @@ def _comment_dfs(
|
||||
) -> str:
|
||||
for comment_page in comment_pages:
|
||||
comment_html = comment_page["body"]["storage"]["value"]
|
||||
comments_str += "\nComment:\n" + parse_html_page_basic(comment_html)
|
||||
comments_str += "\nComment:\n" + parse_html_page(
|
||||
comment_html, confluence_client
|
||||
)
|
||||
child_comment_pages = confluence_client.get_page_child_by_type(
|
||||
comment_page["id"],
|
||||
type="comment",
|
||||
@@ -80,10 +163,17 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
wiki_page_url: str,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE,
|
||||
# if a page has one of the labels specified in this list, we will just
|
||||
# skip it. This is generally used to avoid indexing extra sensitive
|
||||
# pages.
|
||||
labels_to_skip: list[str] = CONFLUENCE_CONNECTOR_LABELS_TO_SKIP,
|
||||
) -> None:
|
||||
self.batch_size = batch_size
|
||||
self.continue_on_failure = continue_on_failure
|
||||
self.wiki_base, self.space = extract_confluence_keys_from_url(wiki_page_url)
|
||||
self.labels_to_skip = set(labels_to_skip)
|
||||
self.wiki_base, self.space, self.is_cloud = extract_confluence_keys_from_url(
|
||||
wiki_page_url
|
||||
)
|
||||
self.confluence_client: Confluence | None = None
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
@@ -91,9 +181,11 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
access_token = credentials["confluence_access_token"]
|
||||
self.confluence_client = Confluence(
|
||||
url=self.wiki_base,
|
||||
username=username,
|
||||
password=access_token,
|
||||
cloud=True,
|
||||
# passing in username causes issues for Confluence data center
|
||||
username=username if self.is_cloud else None,
|
||||
password=access_token if self.is_cloud else None,
|
||||
token=access_token if not self.is_cloud else None,
|
||||
cloud=self.is_cloud,
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -186,6 +278,17 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
)
|
||||
return ""
|
||||
|
||||
def _fetch_labels(self, confluence_client: Confluence, page_id: str) -> list[str]:
|
||||
try:
|
||||
labels_response = confluence_client.get_page_labels(page_id)
|
||||
return [label["name"] for label in labels_response["results"]]
|
||||
except Exception as e:
|
||||
if not self.continue_on_failure:
|
||||
raise e
|
||||
|
||||
logger.exception("Ran into exception when fetching labels from Confluence")
|
||||
return []
|
||||
|
||||
def _get_doc_batch(
|
||||
self, start_ind: int, time_filter: Callable[[datetime], bool] | None = None
|
||||
) -> tuple[list[Document], int]:
|
||||
@@ -197,9 +300,30 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
batch = self._fetch_pages(self.confluence_client, start_ind)
|
||||
for page in batch:
|
||||
last_modified_str = page["version"]["when"]
|
||||
author = cast(str | None, page["version"].get("by", {}).get("email"))
|
||||
last_modified = datetime.fromisoformat(last_modified_str)
|
||||
|
||||
if last_modified.tzinfo is None:
|
||||
# If no timezone info, assume it is UTC
|
||||
last_modified = last_modified.replace(tzinfo=timezone.utc)
|
||||
else:
|
||||
# If not in UTC, translate it
|
||||
last_modified = last_modified.astimezone(timezone.utc)
|
||||
|
||||
if time_filter is None or time_filter(last_modified):
|
||||
page_id = page["id"]
|
||||
|
||||
# check disallowed labels
|
||||
if self.labels_to_skip:
|
||||
page_labels = self._fetch_labels(self.confluence_client, page_id)
|
||||
label_intersection = self.labels_to_skip.intersection(page_labels)
|
||||
if label_intersection:
|
||||
logger.info(
|
||||
f"Page with ID '{page_id}' has a label which has been "
|
||||
f"designated as disallowed: {label_intersection}. Skipping."
|
||||
)
|
||||
continue
|
||||
|
||||
page_html = (
|
||||
page["body"]
|
||||
.get("storage", page["body"].get("view", {}))
|
||||
@@ -209,10 +333,8 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
if not page_html:
|
||||
logger.debug("Page is empty, skipping: %s", page_url)
|
||||
continue
|
||||
page_text = (
|
||||
page.get("title", "") + "\n" + parse_html_page_basic(page_html)
|
||||
)
|
||||
comments_text = self._fetch_comments(self.confluence_client, page["id"])
|
||||
page_text = parse_html_page(page_html, self.confluence_client)
|
||||
comments_text = self._fetch_comments(self.confluence_client, page_id)
|
||||
page_text += comments_text
|
||||
|
||||
doc_batch.append(
|
||||
@@ -221,9 +343,12 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
sections=[Section(link=page_url, text=page_text)],
|
||||
source=DocumentSource.CONFLUENCE,
|
||||
semantic_identifier=page["title"],
|
||||
doc_updated_at=last_modified,
|
||||
primary_owners=[BasicExpertInfo(email=author)]
|
||||
if author
|
||||
else None,
|
||||
metadata={
|
||||
"Wiki Space Name": self.space,
|
||||
"Updated At": page["version"]["friendlyWhen"],
|
||||
},
|
||||
)
|
||||
)
|
||||
@@ -266,6 +391,8 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
|
||||
connector = ConfluenceConnector(os.environ["CONFLUENCE_TEST_SPACE_URL"])
|
||||
connector.load_credentials(
|
||||
{
|
||||
|
||||
158
backend/danswer/connectors/cross_connector_utils/file_utils.py
Normal file
158
backend/danswer/connectors/cross_connector_utils/file_utils.py
Normal file
@@ -0,0 +1,158 @@
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import zipfile
|
||||
from collections.abc import Generator
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import IO
|
||||
|
||||
import chardet
|
||||
from pypdf import PdfReader
|
||||
from pypdf.errors import PdfStreamError
|
||||
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def extract_metadata(line: str) -> dict | None:
|
||||
html_comment_pattern = r"<!--\s*DANSWER_METADATA=\{(.*?)\}\s*-->"
|
||||
hashtag_pattern = r"#DANSWER_METADATA=\{(.*?)\}"
|
||||
|
||||
html_comment_match = re.search(html_comment_pattern, line)
|
||||
hashtag_match = re.search(hashtag_pattern, line)
|
||||
|
||||
if html_comment_match:
|
||||
json_str = html_comment_match.group(1)
|
||||
elif hashtag_match:
|
||||
json_str = hashtag_match.group(1)
|
||||
else:
|
||||
return None
|
||||
|
||||
try:
|
||||
return json.loads("{" + json_str + "}")
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
|
||||
|
||||
def read_pdf_file(file: IO[Any], file_name: str, pdf_pass: str | None = None) -> str:
|
||||
try:
|
||||
pdf_reader = PdfReader(file)
|
||||
|
||||
# If marked as encrypted and a password is provided, try to decrypt
|
||||
if pdf_reader.is_encrypted and pdf_pass is not None:
|
||||
decrypt_success = False
|
||||
if pdf_pass is not None:
|
||||
try:
|
||||
decrypt_success = pdf_reader.decrypt(pdf_pass) != 0
|
||||
except Exception:
|
||||
logger.error(f"Unable to decrypt pdf {file_name}")
|
||||
else:
|
||||
logger.info(f"No Password available to to decrypt pdf {file_name}")
|
||||
|
||||
if not decrypt_success:
|
||||
# By user request, keep files that are unreadable just so they
|
||||
# can be discoverable by title.
|
||||
return ""
|
||||
|
||||
return "\n".join(page.extract_text() for page in pdf_reader.pages)
|
||||
except PdfStreamError:
|
||||
logger.exception(f"PDF file {file_name} is not a valid PDF")
|
||||
except Exception:
|
||||
logger.exception(f"Failed to read PDF {file_name}")
|
||||
|
||||
# File is still discoverable by title
|
||||
# but the contents are not included as they cannot be parsed
|
||||
return ""
|
||||
|
||||
|
||||
def is_macos_resource_fork_file(file_name: str) -> bool:
|
||||
return os.path.basename(file_name).startswith("._") and file_name.startswith(
|
||||
"__MACOSX"
|
||||
)
|
||||
|
||||
|
||||
# To include additional metadata in the search index, add a .danswer_metadata.json file
|
||||
# to the zip file. This file should contain a list of objects with the following format:
|
||||
# [{ "filename": "file1.txt", "link": "https://example.com/file1.txt" }]
|
||||
def load_files_from_zip(
|
||||
zip_location: str | Path,
|
||||
ignore_macos_resource_fork_files: bool = True,
|
||||
ignore_dirs: bool = True,
|
||||
) -> Generator[tuple[zipfile.ZipInfo, IO[Any], dict[str, Any]], None, None]:
|
||||
with zipfile.ZipFile(zip_location, "r") as zip_file:
|
||||
zip_metadata = {}
|
||||
try:
|
||||
metadata_file_info = zip_file.getinfo(".danswer_metadata.json")
|
||||
with zip_file.open(metadata_file_info, "r") as metadata_file:
|
||||
try:
|
||||
zip_metadata = json.load(metadata_file)
|
||||
if isinstance(zip_metadata, list):
|
||||
# convert list of dicts to dict of dicts
|
||||
zip_metadata = {d["filename"]: d for d in zip_metadata}
|
||||
except json.JSONDecodeError:
|
||||
logger.warn("Unable to load .danswer_metadata.json")
|
||||
except KeyError:
|
||||
logger.info("No .danswer_metadata.json file")
|
||||
|
||||
for file_info in zip_file.infolist():
|
||||
with zip_file.open(file_info.filename, "r") as file:
|
||||
if ignore_dirs and file_info.is_dir():
|
||||
continue
|
||||
|
||||
if ignore_macos_resource_fork_files and is_macos_resource_fork_file(
|
||||
file_info.filename
|
||||
):
|
||||
continue
|
||||
yield file_info, file, zip_metadata.get(file_info.filename, {})
|
||||
|
||||
|
||||
def detect_encoding(file_path: str | Path) -> str:
|
||||
with open(file_path, "rb") as file:
|
||||
raw_data = file.read(50000) # Read a portion of the file to guess encoding
|
||||
return chardet.detect(raw_data)["encoding"] or "utf-8"
|
||||
|
||||
|
||||
def read_file(
|
||||
file_reader: IO[Any], encoding: str = "utf-8", errors: str = "replace"
|
||||
) -> tuple[str, dict]:
|
||||
metadata = {}
|
||||
file_content_raw = ""
|
||||
for ind, line in enumerate(file_reader):
|
||||
try:
|
||||
line = line.decode(encoding) if isinstance(line, bytes) else line
|
||||
except UnicodeDecodeError:
|
||||
line = (
|
||||
line.decode(encoding, errors=errors)
|
||||
if isinstance(line, bytes)
|
||||
else line
|
||||
)
|
||||
|
||||
if ind == 0:
|
||||
metadata_or_none = extract_metadata(line)
|
||||
if metadata_or_none is not None:
|
||||
metadata = metadata_or_none
|
||||
else:
|
||||
file_content_raw += line
|
||||
else:
|
||||
file_content_raw += line
|
||||
|
||||
return file_content_raw, metadata
|
||||
|
||||
|
||||
def is_text_file_extension(file_name: str) -> bool:
|
||||
extensions = (
|
||||
".txt",
|
||||
".mdx",
|
||||
".md",
|
||||
".conf",
|
||||
".log",
|
||||
".json",
|
||||
".xml",
|
||||
".yaml",
|
||||
".yml",
|
||||
".json",
|
||||
)
|
||||
return any(file_name.endswith(ext) for ext in extensions)
|
||||
164
backend/danswer/connectors/cross_connector_utils/html_utils.py
Normal file
164
backend/danswer/connectors/cross_connector_utils/html_utils.py
Normal file
@@ -0,0 +1,164 @@
|
||||
import re
|
||||
from copy import copy
|
||||
from dataclasses import dataclass
|
||||
|
||||
import bs4
|
||||
|
||||
from danswer.configs.app_configs import WEB_CONNECTOR_IGNORED_CLASSES
|
||||
from danswer.configs.app_configs import WEB_CONNECTOR_IGNORED_ELEMENTS
|
||||
|
||||
MINTLIFY_UNWANTED = ["sticky", "hidden"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParsedHTML:
|
||||
title: str | None
|
||||
cleaned_text: str
|
||||
|
||||
|
||||
def strip_excessive_newlines_and_spaces(document: str) -> str:
|
||||
# collapse repeated spaces into one
|
||||
document = re.sub(r" +", " ", document)
|
||||
# remove trailing spaces
|
||||
document = re.sub(r" +[\n\r]", "\n", document)
|
||||
# remove repeated newlines
|
||||
document = re.sub(r"[\n\r]+", "\n", document)
|
||||
return document.strip()
|
||||
|
||||
|
||||
def strip_newlines(document: str) -> str:
|
||||
# HTML might contain newlines which are just whitespaces to a browser
|
||||
return re.sub(r"[\n\r]+", " ", document)
|
||||
|
||||
|
||||
def format_document_soup(
|
||||
document: bs4.BeautifulSoup, table_cell_separator: str = "\t"
|
||||
) -> str:
|
||||
"""Format html to a flat text document.
|
||||
|
||||
The following goals:
|
||||
- Newlines from within the HTML are removed (as browser would ignore them as well).
|
||||
- Repeated newlines/spaces are removed (as browsers would ignore them).
|
||||
- Newlines only before and after headlines and paragraphs or when explicit (br or pre tag)
|
||||
- Table columns/rows are separated by newline
|
||||
- List elements are separated by newline and start with a hyphen
|
||||
"""
|
||||
text = ""
|
||||
list_element_start = False
|
||||
verbatim_output = 0
|
||||
in_table = False
|
||||
last_added_newline = False
|
||||
for e in document.descendants:
|
||||
verbatim_output -= 1
|
||||
if isinstance(e, bs4.element.NavigableString):
|
||||
if isinstance(e, (bs4.element.Comment, bs4.element.Doctype)):
|
||||
continue
|
||||
element_text = e.text
|
||||
if in_table:
|
||||
# Tables are represented in natural language with rows separated by newlines
|
||||
# Can't have newlines then in the table elements
|
||||
element_text = element_text.replace("\n", " ").strip()
|
||||
|
||||
# Some tags are translated to spaces but in the logic underneath this section, we
|
||||
# translate them to newlines as a browser should render them such as with br
|
||||
# This logic here avoids a space after newline when it shouldn't be there.
|
||||
if last_added_newline and element_text.startswith(" "):
|
||||
element_text = element_text[1:]
|
||||
last_added_newline = False
|
||||
|
||||
if element_text:
|
||||
content_to_add = (
|
||||
element_text
|
||||
if verbatim_output > 0
|
||||
else strip_newlines(element_text)
|
||||
)
|
||||
|
||||
# Don't join separate elements without any spacing
|
||||
if (text and not text[-1].isspace()) and (
|
||||
content_to_add and not content_to_add[0].isspace()
|
||||
):
|
||||
text += " "
|
||||
|
||||
text += content_to_add
|
||||
|
||||
list_element_start = False
|
||||
elif isinstance(e, bs4.element.Tag):
|
||||
# table is standard HTML element
|
||||
if e.name == "table":
|
||||
in_table = True
|
||||
# tr is for rows
|
||||
elif e.name == "tr" and in_table:
|
||||
text += "\n"
|
||||
# td for data cell, th for header
|
||||
elif e.name in ["td", "th"] and in_table:
|
||||
text += table_cell_separator
|
||||
elif e.name == "/table":
|
||||
in_table = False
|
||||
elif in_table:
|
||||
# don't handle other cases while in table
|
||||
pass
|
||||
|
||||
elif e.name in ["p", "div"]:
|
||||
if not list_element_start:
|
||||
text += "\n"
|
||||
elif e.name in ["h1", "h2", "h3", "h4"]:
|
||||
text += "\n"
|
||||
list_element_start = False
|
||||
last_added_newline = True
|
||||
elif e.name == "br":
|
||||
text += "\n"
|
||||
list_element_start = False
|
||||
last_added_newline = True
|
||||
elif e.name == "li":
|
||||
text += "\n- "
|
||||
list_element_start = True
|
||||
elif e.name == "pre":
|
||||
if verbatim_output <= 0:
|
||||
verbatim_output = len(list(e.childGenerator()))
|
||||
return strip_excessive_newlines_and_spaces(text)
|
||||
|
||||
|
||||
def parse_html_page_basic(text: str) -> str:
|
||||
soup = bs4.BeautifulSoup(text, "html.parser")
|
||||
return format_document_soup(soup)
|
||||
|
||||
|
||||
def web_html_cleanup(
|
||||
page_content: str | bs4.BeautifulSoup,
|
||||
mintlify_cleanup_enabled: bool = True,
|
||||
additional_element_types_to_discard: list[str] | None = None,
|
||||
) -> ParsedHTML:
|
||||
if isinstance(page_content, str):
|
||||
soup = bs4.BeautifulSoup(page_content, "html.parser")
|
||||
else:
|
||||
soup = page_content
|
||||
|
||||
title_tag = soup.find("title")
|
||||
title = None
|
||||
if title_tag and title_tag.text:
|
||||
title = title_tag.text
|
||||
title_tag.extract()
|
||||
|
||||
# Heuristics based cleaning of elements based on css classes
|
||||
unwanted_classes = copy(WEB_CONNECTOR_IGNORED_CLASSES)
|
||||
if mintlify_cleanup_enabled:
|
||||
unwanted_classes.extend(MINTLIFY_UNWANTED)
|
||||
for undesired_element in unwanted_classes:
|
||||
[
|
||||
tag.extract()
|
||||
for tag in soup.find_all(
|
||||
class_=lambda x: x and undesired_element in x.split()
|
||||
)
|
||||
]
|
||||
|
||||
for undesired_tag in WEB_CONNECTOR_IGNORED_ELEMENTS:
|
||||
[tag.extract() for tag in soup.find_all(undesired_tag)]
|
||||
|
||||
if additional_element_types_to_discard:
|
||||
for undesired_tag in additional_element_types_to_discard:
|
||||
[tag.extract() for tag in soup.find_all(undesired_tag)]
|
||||
|
||||
# 200B is ZeroWidthSpace which we don't care for
|
||||
page_text = format_document_soup(soup).replace("\u200B", "")
|
||||
|
||||
return ParsedHTML(title=title, cleaned_text=page_text)
|
||||
@@ -0,0 +1,45 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
from dateutil.parser import parse
|
||||
|
||||
from danswer.connectors.models import BasicExpertInfo
|
||||
from danswer.utils.text_processing import is_valid_email
|
||||
|
||||
|
||||
def datetime_to_utc(dt: datetime) -> datetime:
|
||||
if dt.tzinfo is None or dt.tzinfo.utcoffset(dt) is None:
|
||||
dt = dt.replace(tzinfo=timezone.utc)
|
||||
|
||||
return dt.astimezone(timezone.utc)
|
||||
|
||||
|
||||
def time_str_to_utc(datetime_str: str) -> datetime:
|
||||
dt = parse(datetime_str)
|
||||
return datetime_to_utc(dt)
|
||||
|
||||
|
||||
def basic_expert_info_representation(info: BasicExpertInfo) -> str | None:
|
||||
if info.first_name and info.last_name:
|
||||
return f"{info.first_name} {info.middle_initial} {info.last_name}"
|
||||
|
||||
if info.display_name:
|
||||
return info.display_name
|
||||
|
||||
if info.email and is_valid_email(info.email):
|
||||
return info.email
|
||||
|
||||
if info.first_name:
|
||||
return info.first_name
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_experts_stores_representations(
|
||||
experts: list[BasicExpertInfo] | None,
|
||||
) -> list[str] | None:
|
||||
if not experts:
|
||||
return None
|
||||
|
||||
reps = [basic_expert_info_representation(owner) for owner in experts]
|
||||
return [owner for owner in reps if owner is not None]
|
||||
@@ -0,0 +1,86 @@
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import TypeVar
|
||||
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
|
||||
class RateLimitTriedTooManyTimesError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class _RateLimitDecorator:
|
||||
"""Builds a generic wrapper/decorator for calls to external APIs that
|
||||
prevents making more than `max_calls` requests per `period`
|
||||
|
||||
Implementation inspired by the `ratelimit` library:
|
||||
https://github.com/tomasbasham/ratelimit.
|
||||
|
||||
NOTE: is not thread safe.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_calls: int,
|
||||
period: float, # in seconds
|
||||
sleep_time: float = 2, # in seconds
|
||||
sleep_backoff: float = 2, # applies exponential backoff
|
||||
max_num_sleep: int = 0,
|
||||
):
|
||||
self.max_calls = max_calls
|
||||
self.period = period
|
||||
self.sleep_time = sleep_time
|
||||
self.sleep_backoff = sleep_backoff
|
||||
self.max_num_sleep = max_num_sleep
|
||||
|
||||
self.call_history: list[float] = []
|
||||
self.curr_calls = 0
|
||||
|
||||
def __call__(self, func: F) -> F:
|
||||
@wraps(func)
|
||||
def wrapped_func(*args: list, **kwargs: dict[str, Any]) -> Any:
|
||||
# cleanup calls which are no longer relevant
|
||||
self._cleanup()
|
||||
|
||||
# check if we've exceeded the rate limit
|
||||
sleep_cnt = 0
|
||||
while len(self.call_history) == self.max_calls:
|
||||
sleep_time = self.sleep_time * (self.sleep_backoff**sleep_cnt)
|
||||
logger.info(
|
||||
f"Rate limit exceeded for function {func.__name__}. "
|
||||
f"Waiting {sleep_time} seconds before retrying."
|
||||
)
|
||||
time.sleep(sleep_time)
|
||||
sleep_cnt += 1
|
||||
if self.max_num_sleep != 0 and sleep_cnt >= self.max_num_sleep:
|
||||
raise RateLimitTriedTooManyTimesError(
|
||||
f"Exceeded '{self.max_num_sleep}' retries for function '{func.__name__}'"
|
||||
)
|
||||
|
||||
self._cleanup()
|
||||
|
||||
# add the current call to the call history
|
||||
self.call_history.append(time.monotonic())
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return cast(F, wrapped_func)
|
||||
|
||||
def _cleanup(self) -> None:
|
||||
curr_time = time.monotonic()
|
||||
time_to_expire_before = curr_time - self.period
|
||||
self.call_history = [
|
||||
call_time
|
||||
for call_time in self.call_history
|
||||
if call_time > time_to_expire_before
|
||||
]
|
||||
|
||||
|
||||
rate_limit_builder = _RateLimitDecorator
|
||||
@@ -0,0 +1,42 @@
|
||||
from collections.abc import Callable
|
||||
from logging import Logger
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import TypeVar
|
||||
|
||||
from retry import retry
|
||||
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
|
||||
def retry_builder(
|
||||
tries: int = 10,
|
||||
delay: float = 0.1,
|
||||
max_delay: float | None = None,
|
||||
backoff: float = 2,
|
||||
jitter: tuple[float, float] | float = 1,
|
||||
) -> Callable[[F], F]:
|
||||
"""Builds a generic wrapper/decorator for calls to external APIs that
|
||||
may fail due to rate limiting, flakes, or other reasons. Applies expontential
|
||||
backoff with jitter to retry the call."""
|
||||
|
||||
@retry(
|
||||
tries=tries,
|
||||
delay=delay,
|
||||
max_delay=max_delay,
|
||||
backoff=backoff,
|
||||
jitter=jitter,
|
||||
logger=cast(Logger, logger),
|
||||
)
|
||||
def retry_with_default(func: F) -> F:
|
||||
def wrapped_func(*args: list, **kwargs: dict[str, Any]) -> Any:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return cast(F, wrapped_func)
|
||||
|
||||
return retry_with_default
|
||||
@@ -8,10 +8,12 @@ from jira.resources import Issue
|
||||
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from danswer.connectors.models import BasicExpertInfo
|
||||
from danswer.connectors.models import ConnectorMissingCredentialError
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
@@ -59,23 +61,32 @@ def fetch_jira_issues_batch(
|
||||
logger.warning(f"Found Jira object not of type Issue {jira}")
|
||||
continue
|
||||
|
||||
semantic_rep = (
|
||||
f"Jira Ticket Summary: {jira.fields.summary}\n"
|
||||
f"Description: {jira.fields.description}\n"
|
||||
+ "\n".join(
|
||||
[f"Comment: {comment.body}" for comment in jira.fields.comment.comments]
|
||||
)
|
||||
semantic_rep = f"{jira.fields.description}\n" + "\n".join(
|
||||
[f"Comment: {comment.body}" for comment in jira.fields.comment.comments]
|
||||
)
|
||||
|
||||
page_url = f"{jira_client.client_info()}/browse/{jira.key}"
|
||||
|
||||
author = None
|
||||
try:
|
||||
author = BasicExpertInfo(
|
||||
display_name=jira.fields.creator.displayName,
|
||||
email=jira.fields.creator.emailAddress,
|
||||
)
|
||||
except Exception:
|
||||
# Author should exist but if not, doesn't matter
|
||||
pass
|
||||
|
||||
doc_batch.append(
|
||||
Document(
|
||||
id=page_url,
|
||||
sections=[Section(link=page_url, text=semantic_rep)],
|
||||
source=DocumentSource.JIRA,
|
||||
semantic_identifier=jira.fields.summary,
|
||||
metadata={},
|
||||
doc_updated_at=time_str_to_utc(jira.fields.updated),
|
||||
primary_owners=[author] if author is not None else None,
|
||||
# TODO add secondary_owners if needed
|
||||
metadata={"label": jira.fields.labels} if jira.fields.labels else {},
|
||||
)
|
||||
)
|
||||
return doc_batch, len(batch)
|
||||
@@ -151,3 +162,17 @@ class JiraConnector(LoadConnector, PollConnector):
|
||||
start_ind += fetched_batch_size
|
||||
if fetched_batch_size < self.batch_size:
|
||||
break
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
|
||||
connector = JiraConnector(os.environ["JIRA_PROJECT_URL"])
|
||||
connector.load_credentials(
|
||||
{
|
||||
"jira_user_email": os.environ["JIRA_USER_EMAIL"],
|
||||
"jira_api_token": os.environ["JIRA_API_TOKEN"],
|
||||
}
|
||||
)
|
||||
document_batches = connector.load_from_state()
|
||||
print(next(document_batches))
|
||||
|
||||
196
backend/danswer/connectors/document360/connector.py
Normal file
196
backend/danswer/connectors/document360/connector.py
Normal file
@@ -0,0 +1,196 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.cross_connector_utils.html_utils import parse_html_page_basic
|
||||
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||
rate_limit_builder,
|
||||
)
|
||||
from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from danswer.connectors.models import ConnectorMissingCredentialError
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
|
||||
# Limitations and Potential Improvements
|
||||
# 1. The "Categories themselves contain potentially relevant information" but they're not pulled in
|
||||
# 2. Only the HTML Articles are supported, Document360 also has a Markdown and "Block" format
|
||||
# 3. The contents are not as cleaned up as other HTML connectors
|
||||
|
||||
DOCUMENT360_BASE_URL = "https://preview.portal.document360.io/"
|
||||
DOCUMENT360_API_BASE_URL = "https://apihub.document360.io/v2"
|
||||
|
||||
|
||||
class Document360Connector(LoadConnector, PollConnector):
|
||||
def __init__(
|
||||
self,
|
||||
workspace: str,
|
||||
categories: List[str] | None = None,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
portal_id: Optional[str] = None,
|
||||
api_token: Optional[str] = None,
|
||||
) -> None:
|
||||
self.portal_id = portal_id
|
||||
self.workspace = workspace
|
||||
self.categories = categories
|
||||
self.batch_size = batch_size
|
||||
self.api_token = api_token
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> Optional[dict[str, Any]]:
|
||||
self.api_token = credentials.get("document360_api_token")
|
||||
self.portal_id = credentials.get("portal_id")
|
||||
return None
|
||||
|
||||
# rate limiting set based on the enterprise plan: https://apidocs.document360.com/apidocs/rate-limiting
|
||||
# NOTE: retry will handle cases where user is not on enterprise plan - we will just hit the rate limit
|
||||
# and then retry after a period
|
||||
@retry_builder()
|
||||
@rate_limit_builder(max_calls=100, period=60)
|
||||
def _make_request(self, endpoint: str, params: Optional[dict] = None) -> Any:
|
||||
if not self.api_token:
|
||||
raise ConnectorMissingCredentialError("Document360")
|
||||
|
||||
headers = {"accept": "application/json", "api_token": self.api_token}
|
||||
|
||||
response = requests.get(
|
||||
f"{DOCUMENT360_API_BASE_URL}/{endpoint}", headers=headers, params=params
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
return response.json()["data"]
|
||||
|
||||
def _get_workspace_id_by_name(self) -> str:
|
||||
projects = self._make_request("ProjectVersions")
|
||||
workspace_id = next(
|
||||
(
|
||||
project["id"]
|
||||
for project in projects
|
||||
if project["version_code_name"] == self.workspace
|
||||
),
|
||||
None,
|
||||
)
|
||||
if workspace_id is None:
|
||||
raise ValueError("Not able to find Workspace ID by the user provided name")
|
||||
|
||||
return workspace_id
|
||||
|
||||
def _get_articles_with_category(self, workspace_id: str) -> Any:
|
||||
all_categories = self._make_request(
|
||||
f"ProjectVersions/{workspace_id}/categories"
|
||||
)
|
||||
articles_with_category = []
|
||||
|
||||
for category in all_categories:
|
||||
if not self.categories or category["name"] in self.categories:
|
||||
for article in category["articles"]:
|
||||
articles_with_category.append(
|
||||
{"id": article["id"], "category_name": category["name"]}
|
||||
)
|
||||
for child_category in category["child_categories"]:
|
||||
for article in child_category["articles"]:
|
||||
articles_with_category.append(
|
||||
{
|
||||
"id": article["id"],
|
||||
"category_name": child_category["name"],
|
||||
}
|
||||
)
|
||||
return articles_with_category
|
||||
|
||||
def _process_articles(
|
||||
self, start: datetime | None = None, end: datetime | None = None
|
||||
) -> GenerateDocumentsOutput:
|
||||
if self.api_token is None:
|
||||
raise ConnectorMissingCredentialError("Document360")
|
||||
|
||||
workspace_id = self._get_workspace_id_by_name()
|
||||
articles = self._get_articles_with_category(workspace_id)
|
||||
|
||||
doc_batch: List[Document] = []
|
||||
|
||||
for article in articles:
|
||||
article_details = self._make_request(
|
||||
f"Articles/{article['id']}", {"langCode": "en"}
|
||||
)
|
||||
|
||||
updated_at = datetime.strptime(
|
||||
article_details["modified_at"], "%Y-%m-%dT%H:%M:%S.%fZ"
|
||||
).replace(tzinfo=timezone.utc)
|
||||
if start is not None and updated_at < start:
|
||||
continue
|
||||
if end is not None and updated_at > end:
|
||||
continue
|
||||
|
||||
authors = [
|
||||
author["email_id"]
|
||||
for author in article_details.get("authors", [])
|
||||
if author["email_id"]
|
||||
]
|
||||
|
||||
doc_link = f"{DOCUMENT360_BASE_URL}/{self.portal_id}/document/v1/view/{article['id']}"
|
||||
|
||||
html_content = article_details["html_content"]
|
||||
article_content = parse_html_page_basic(html_content)
|
||||
doc_text = (
|
||||
f"{article_details.get('description', '')}\n{article_content}".strip()
|
||||
)
|
||||
|
||||
document = Document(
|
||||
id=article_details["id"],
|
||||
sections=[Section(link=doc_link, text=doc_text)],
|
||||
source=DocumentSource.DOCUMENT360,
|
||||
semantic_identifier=article_details["title"],
|
||||
doc_updated_at=updated_at,
|
||||
primary_owners=authors,
|
||||
metadata={
|
||||
"workspace": self.workspace,
|
||||
"category": article["category_name"],
|
||||
},
|
||||
)
|
||||
|
||||
doc_batch.append(document)
|
||||
|
||||
if len(doc_batch) >= self.batch_size:
|
||||
yield doc_batch
|
||||
doc_batch = []
|
||||
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
return self._process_articles()
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
start_datetime = datetime.fromtimestamp(start, tz=timezone.utc)
|
||||
end_datetime = datetime.fromtimestamp(end, tz=timezone.utc)
|
||||
return self._process_articles(start_datetime, end_datetime)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import time
|
||||
import os
|
||||
|
||||
document360_connector = Document360Connector(os.environ["DOCUMENT360_WORKSPACE"])
|
||||
document360_connector.load_credentials(
|
||||
{
|
||||
"portal_id": os.environ["DOCUMENT360_PORTAL_ID"],
|
||||
"document360_api_token": os.environ["DOCUMENT360_API_TOKEN"],
|
||||
}
|
||||
)
|
||||
|
||||
current = time.time()
|
||||
one_year_ago = current - 24 * 60 * 60 * 360
|
||||
latest_docs = document360_connector.poll_source(one_year_ago, current)
|
||||
|
||||
for doc in latest_docs:
|
||||
print(doc)
|
||||
@@ -5,22 +5,32 @@ from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.bookstack.connector import BookstackConnector
|
||||
from danswer.connectors.confluence.connector import ConfluenceConnector
|
||||
from danswer.connectors.danswer_jira.connector import JiraConnector
|
||||
from danswer.connectors.document360.connector import Document360Connector
|
||||
from danswer.connectors.file.connector import LocalFileConnector
|
||||
from danswer.connectors.github.connector import GithubConnector
|
||||
from danswer.connectors.gitlab.connector import GitlabConnector
|
||||
from danswer.connectors.gmail.connector import GmailConnector
|
||||
from danswer.connectors.gong.connector import GongConnector
|
||||
from danswer.connectors.google_drive.connector import GoogleDriveConnector
|
||||
from danswer.connectors.google_site.connector import GoogleSitesConnector
|
||||
from danswer.connectors.guru.connector import GuruConnector
|
||||
from danswer.connectors.hubspot.connector import HubSpotConnector
|
||||
from danswer.connectors.interfaces import BaseConnector
|
||||
from danswer.connectors.interfaces import EventConnector
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.linear.connector import LinearConnector
|
||||
from danswer.connectors.loopio.connector import LoopioConnector
|
||||
from danswer.connectors.models import InputType
|
||||
from danswer.connectors.notion.connector import NotionConnector
|
||||
from danswer.connectors.productboard.connector import ProductboardConnector
|
||||
from danswer.connectors.requesttracker.connector import RequestTrackerConnector
|
||||
from danswer.connectors.sharepoint.connector import SharepointConnector
|
||||
from danswer.connectors.slab.connector import SlabConnector
|
||||
from danswer.connectors.slack.connector import SlackLoadConnector
|
||||
from danswer.connectors.slack.connector import SlackPollConnector
|
||||
from danswer.connectors.slack.load_connector import SlackLoadConnector
|
||||
from danswer.connectors.web.connector import WebConnector
|
||||
from danswer.connectors.zendesk.connector import ZendeskConnector
|
||||
from danswer.connectors.zulip.connector import ZulipConnector
|
||||
|
||||
|
||||
@@ -40,6 +50,8 @@ def identify_connector_class(
|
||||
InputType.POLL: SlackPollConnector,
|
||||
},
|
||||
DocumentSource.GITHUB: GithubConnector,
|
||||
DocumentSource.GMAIL: GmailConnector,
|
||||
DocumentSource.GITLAB: GitlabConnector,
|
||||
DocumentSource.GOOGLE_DRIVE: GoogleDriveConnector,
|
||||
DocumentSource.BOOKSTACK: BookstackConnector,
|
||||
DocumentSource.CONFLUENCE: ConfluenceConnector,
|
||||
@@ -48,8 +60,16 @@ def identify_connector_class(
|
||||
DocumentSource.SLAB: SlabConnector,
|
||||
DocumentSource.NOTION: NotionConnector,
|
||||
DocumentSource.ZULIP: ZulipConnector,
|
||||
DocumentSource.REQUESTTRACKER: RequestTrackerConnector,
|
||||
DocumentSource.GURU: GuruConnector,
|
||||
DocumentSource.LINEAR: LinearConnector,
|
||||
DocumentSource.HUBSPOT: HubSpotConnector,
|
||||
DocumentSource.DOCUMENT360: Document360Connector,
|
||||
DocumentSource.GONG: GongConnector,
|
||||
DocumentSource.GOOGLE_SITES: GoogleSitesConnector,
|
||||
DocumentSource.ZENDESK: ZendeskConnector,
|
||||
DocumentSource.LOOPIO: LoopioConnector,
|
||||
DocumentSource.SHAREPOINT: SharepointConnector,
|
||||
}
|
||||
connector_by_source = connector_map.get(source, {})
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user