Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-02-28 21:25:44 +00:00)

Compare commits: web-docker ... v0.4.15 (482 commits)
.github/pull_request_template.md (vendored, new file, 25 lines)
@@ -0,0 +1,25 @@
## Description
[Provide a brief description of the changes in this PR]

## How Has This Been Tested?
[Describe the tests you ran to verify your changes]

## Accepted Risk
[Any known risks or failure modes to point out to reviewers]

## Related Issue(s)
[If applicable, link to the issue(s) this PR addresses]

## Checklist:
- [ ] All of the automated tests pass
- [ ] All PR comments are addressed and marked resolved
- [ ] If there are migrations, they have been rebased to latest main
- [ ] If there are new dependencies, they are added to the requirements
- [ ] If there are new environment variables, they are added to all of the deployment methods
- [ ] If there are new APIs that don't require auth, they are added to PUBLIC_ENDPOINT_SPECS
- [ ] Docker images build and basic functionalities work
- [ ] Author has done a final read through of the PR right before merge
.github/workflows/docker-build-backend-container-on-merge-group.yml (vendored, new file, 33 lines)
@@ -0,0 +1,33 @@
name: Build Backend Image on Merge Group

on:
  merge_group:
    types: [checks_requested]

env:
  REGISTRY_IMAGE: danswer/danswer-backend

jobs:
  build:
    # TODO: make this a matrix build like the web containers
    runs-on:
      group: amd64-image-builders

    steps:
      - name: Checkout code
        uses: actions/checkout@v2

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3

      - name: Backend Image Docker Build
        uses: docker/build-push-action@v5
        with:
          context: ./backend
          file: ./backend/Dockerfile
          platforms: linux/amd64,linux/arm64
          push: false
          tags: |
            ${{ env.REGISTRY_IMAGE }}:latest
          build-args: |
            DANSWER_VERSION=v0.0.1
@@ -5,33 +5,38 @@ on:
  tags:
    - '*'

env:
  REGISTRY_IMAGE: danswer/danswer-backend

jobs:
  build-and-push:
    runs-on: ubuntu-latest
    # TODO: make this a matrix build like the web containers
    runs-on:
      group: amd64-image-builders

    steps:
      - name: Checkout code
        uses: actions/checkout@v2

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v1
        uses: docker/setup-buildx-action@v3

      - name: Login to Docker Hub
        uses: docker/login-action@v1
        uses: docker/login-action@v3
        with:
          username: ${{ secrets.DOCKER_USERNAME }}
          password: ${{ secrets.DOCKER_TOKEN }}

      - name: Backend Image Docker Build and Push
        uses: docker/build-push-action@v2
        uses: docker/build-push-action@v5
        with:
          context: ./backend
          file: ./backend/Dockerfile
          platforms: linux/amd64,linux/arm64
          push: true
          tags: |
            danswer/danswer-backend:${{ github.ref_name }}
            danswer/danswer-backend:latest
            ${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
            ${{ env.REGISTRY_IMAGE }}:latest
          build-args: |
            DANSWER_VERSION=${{ github.ref_name }}

@@ -39,6 +44,6 @@ jobs:
        uses: aquasecurity/trivy-action@master
        with:
          # To run locally: trivy image --severity HIGH,CRITICAL danswer/danswer-backend
          image-ref: docker.io/danswer/danswer-backend:${{ github.ref_name }}
          image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
          severity: 'CRITICAL,HIGH'
          trivyignores: ./backend/.trivyignore
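The Trivy step above carries a comment on reproducing the scan locally. A minimal sketch of doing the equivalent build and scan by hand, assuming Docker and Trivy are installed on the machine and using a hypothetical `local` tag:

```bash
# Build the backend image from the repo root with a pinned version build-arg
docker build \
  --build-arg DANSWER_VERSION=v0.0.1 \
  -f ./backend/Dockerfile \
  -t danswer/danswer-backend:local \
  ./backend

# Scan it the way the workflow does; the workflow's trivyignores input
# corresponds to Trivy's --ignorefile flag when run by hand
trivy image --severity HIGH,CRITICAL \
  --ignorefile ./backend/.trivyignore \
  danswer/danswer-backend:local
```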
@@ -7,23 +7,24 @@ on:

jobs:
  build-and-push:
    runs-on: ubuntu-latest
    runs-on:
      group: amd64-image-builders

    steps:
      - name: Checkout code
        uses: actions/checkout@v2

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v1
        uses: docker/setup-buildx-action@v3

      - name: Login to Docker Hub
        uses: docker/login-action@v1
        uses: docker/login-action@v3
        with:
          username: ${{ secrets.DOCKER_USERNAME }}
          password: ${{ secrets.DOCKER_TOKEN }}

      - name: Model Server Image Docker Build and Push
        uses: docker/build-push-action@v2
        uses: docker/build-push-action@v5
        with:
          context: ./backend
          file: ./backend/Dockerfile.model_server
@@ -5,38 +5,115 @@ on:
  tags:
    - '*'

env:
  REGISTRY_IMAGE: danswer/danswer-web-server

jobs:
  build-and-push:
    runs-on: ubuntu-latest
  build:
    runs-on:
      group: ${{ matrix.platform == 'linux/amd64' && 'amd64-image-builders' || 'arm64-image-builders' }}
    strategy:
      fail-fast: false
      matrix:
        platform:
          - linux/amd64
          - linux/arm64

    steps:
      - name: Checkout code
        uses: actions/checkout@v2
      - name: Prepare
        run: |
          platform=${{ matrix.platform }}
          echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV

      - name: Checkout
        uses: actions/checkout@v4

      - name: Docker meta
        id: meta
        uses: docker/metadata-action@v5
        with:
          images: ${{ env.REGISTRY_IMAGE }}
          tags: |
            type=raw,value=${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
            type=raw,value=${{ env.REGISTRY_IMAGE }}:latest

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3

      - name: Login to Docker Hub
        uses: docker/login-action@v3
        with:
          username: ${{ secrets.DOCKER_USERNAME }}
          password: ${{ secrets.DOCKER_TOKEN }}

      - name: Build and push by digest
        id: build
        uses: docker/build-push-action@v5
        with:
          context: ./web
          file: ./web/Dockerfile
          platforms: ${{ matrix.platform }}
          push: true
          build-args: |
            DANSWER_VERSION=${{ github.ref_name }}
          # needed due to weird interactions with the builds for different platforms
          no-cache: true
          labels: ${{ steps.meta.outputs.labels }}
          outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true

      - name: Export digest
        run: |
          mkdir -p /tmp/digests
          digest="${{ steps.build.outputs.digest }}"
          touch "/tmp/digests/${digest#sha256:}"

      - name: Upload digest
        uses: actions/upload-artifact@v4
        with:
          name: digests-${{ env.PLATFORM_PAIR }}
          path: /tmp/digests/*
          if-no-files-found: error
          retention-days: 1

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v1
  merge:
    runs-on: ubuntu-latest
    needs:
      - build
    steps:
      - name: Download digests
        uses: actions/download-artifact@v4
        with:
          path: /tmp/digests
          pattern: digests-*
          merge-multiple: true

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3

      - name: Docker meta
        id: meta
        uses: docker/metadata-action@v5
        with:
          images: ${{ env.REGISTRY_IMAGE }}

      - name: Login to Docker Hub
        uses: docker/login-action@v3
        with:
          username: ${{ secrets.DOCKER_USERNAME }}
          password: ${{ secrets.DOCKER_TOKEN }}

      - name: Create manifest list and push
        working-directory: /tmp/digests
        run: |
          docker buildx imagetools create $(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \
            $(printf '${{ env.REGISTRY_IMAGE }}@sha256:%s ' *)

      - name: Inspect image
        run: |
          docker buildx imagetools inspect ${{ env.REGISTRY_IMAGE }}:${{ steps.meta.outputs.version }}

      - name: Login to Docker Hub
        uses: docker/login-action@v1
        with:
          username: ${{ secrets.DOCKER_USERNAME }}
          password: ${{ secrets.DOCKER_TOKEN }}

      - name: Web Image Docker Build and Push
        uses: docker/build-push-action@v2
        with:
          context: ./web
          file: ./web/Dockerfile
          platforms: linux/amd64,linux/arm64
          push: true
          tags: |
            danswer/danswer-web-server:${{ github.ref_name }}
            danswer/danswer-web-server:latest
          build-args: |
            DANSWER_VERSION=${{ github.ref_name }}

      - name: Run Trivy vulnerability scanner
        uses: aquasecurity/trivy-action@master
        with:
          image-ref: docker.io/danswer/danswer-web-server:${{ github.ref_name }}
          severity: 'CRITICAL,HIGH'
      - name: Run Trivy vulnerability scanner
        uses: aquasecurity/trivy-action@master
        with:
          image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
          severity: 'CRITICAL,HIGH'
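In the reworked flow above, each platform job pushes the web image by digest only and uploads that digest as an artifact; the merge job then assembles the digests into one multi-arch tag with `docker buildx imagetools create`. A minimal sketch of that final step, with placeholder tag and digests standing in for the values the build jobs export:

```bash
# Combine the per-platform image digests into a single manifest list
# (<release-tag> and the sha256 values are placeholders, not real outputs)
docker buildx imagetools create \
  -t danswer/danswer-web-server:<release-tag> \
  -t danswer/danswer-web-server:latest \
  danswer/danswer-web-server@sha256:<amd64-digest> \
  danswer/danswer-web-server@sha256:<arm64-digest>

# Confirm both linux/amd64 and linux/arm64 entries are present
docker buildx imagetools inspect danswer/danswer-web-server:latest
```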
.github/workflows/docker-build-web-container-on-merge-group.yml (vendored, new file, 53 lines)
@@ -0,0 +1,53 @@
name: Build Web Image on Merge Group

on:
  merge_group:
    types: [checks_requested]

env:
  REGISTRY_IMAGE: danswer/danswer-web-server

jobs:
  build:
    runs-on:
      group: ${{ matrix.platform == 'linux/amd64' && 'amd64-image-builders' || 'arm64-image-builders' }}
    strategy:
      fail-fast: false
      matrix:
        platform:
          - linux/amd64
          - linux/arm64

    steps:
      - name: Prepare
        run: |
          platform=${{ matrix.platform }}
          echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV

      - name: Checkout
        uses: actions/checkout@v4

      - name: Docker meta
        id: meta
        uses: docker/metadata-action@v5
        with:
          images: ${{ env.REGISTRY_IMAGE }}
          tags: |
            type=raw,value=${{ env.REGISTRY_IMAGE }}:latest

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3

      - name: Build by digest
        id: build
        uses: docker/build-push-action@v5
        with:
          context: ./web
          file: ./web/Dockerfile
          platforms: ${{ matrix.platform }}
          push: false
          build-args: |
            DANSWER_VERSION=v0.0.1
          # needed due to weird interactions with the builds for different platforms
          no-cache: true
          labels: ${{ steps.meta.outputs.labels }}
.gitignore (vendored, 2 changed lines)
@@ -5,3 +5,5 @@
.idea
/deployment/data/nginx/app.conf
.vscode/launch.json
*.sw?
/backend/tests/regression/answer_quality/search_test_config.yaml
.vscode/env_template.txt (vendored, new file, 52 lines)
@@ -0,0 +1,52 @@
# Copy this file to .env at the base of the repo and fill in the <REPLACE THIS> values
# This will help with development iteration speed and reduce repeat tasks for dev
# Also check out danswer/backend/scripts/restart_containers.sh for a script to restart the containers which Danswer relies on outside of VSCode/Cursor processes

# For local dev, often user Authentication is not needed
AUTH_TYPE=disabled


# Always keep these on for Dev
# Logs all model prompts to stdout
LOG_DANSWER_MODEL_INTERACTIONS=True
# More verbose logging
LOG_LEVEL=debug


# This passes top N results to LLM an additional time for reranking prior to answer generation
# This step is quite heavy on token usage so we disable it for dev generally
DISABLE_LLM_CHUNK_FILTER=True


# Useful if you want to toggle auth on/off (google_oauth/OIDC specifically)
OAUTH_CLIENT_ID=<REPLACE THIS>
OAUTH_CLIENT_SECRET=<REPLACE THIS>
# Generally not useful for dev, we don't generally want to set up an SMTP server for dev
REQUIRE_EMAIL_VERIFICATION=False


# Set these so if you wipe the DB, you don't end up having to go through the UI every time
GEN_AI_API_KEY=<REPLACE THIS>
# If answer quality isn't important for dev, use 3.5 turbo due to it being cheaper
GEN_AI_MODEL_VERSION=gpt-3.5-turbo
FAST_GEN_AI_MODEL_VERSION=gpt-3.5-turbo

# For Danswer Slack Bot, overrides the UI values so no need to set this up via UI every time
# Only needed if using DanswerBot
#DANSWER_BOT_SLACK_APP_TOKEN=<REPLACE THIS>
#DANSWER_BOT_SLACK_BOT_TOKEN=<REPLACE THIS>


# Python stuff
PYTHONPATH=./backend
PYTHONUNBUFFERED=1


# Internet Search
BING_API_KEY=<REPLACE THIS>


# Enable the full set of Danswer Enterprise Edition features
# NOTE: DO NOT ENABLE THIS UNLESS YOU HAVE A PAID ENTERPRISE LICENSE (or if you are using this for local testing/development)
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=False
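Per the first comment in the template, the file is meant to be copied to a `.env` at the repository root and the placeholders filled in. A minimal sketch of that one-time setup, assuming it is run from the repo root:

```bash
# Copy the dev env template to the repo root and fill in the <REPLACE THIS> values
cp .vscode/env_template.txt .env
${EDITOR:-vi} .env
```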
28
.vscode/launch.template.jsonc
vendored
28
.vscode/launch.template.jsonc
vendored
@@ -17,6 +17,7 @@
|
||||
"request": "launch",
|
||||
"cwd": "${workspaceRoot}/web",
|
||||
"runtimeExecutable": "npm",
|
||||
"envFile": "${workspaceFolder}/.env",
|
||||
"runtimeArgs": [
|
||||
"run", "dev"
|
||||
],
|
||||
@@ -28,6 +29,7 @@
|
||||
"request": "launch",
|
||||
"module": "uvicorn",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1"
|
||||
@@ -45,8 +47,9 @@
|
||||
"request": "launch",
|
||||
"module": "uvicorn",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.env",
|
||||
"env": {
|
||||
"LOG_ALL_MODEL_INTERACTIONS": "True",
|
||||
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1"
|
||||
},
|
||||
@@ -63,6 +66,7 @@
|
||||
"request": "launch",
|
||||
"program": "danswer/background/update.py",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.env",
|
||||
"env": {
|
||||
"ENABLE_MINI_CHUNK": "false",
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
@@ -77,7 +81,9 @@
|
||||
"request": "launch",
|
||||
"program": "scripts/dev_run_background_jobs.py",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.env",
|
||||
"env": {
|
||||
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
@@ -100,6 +106,24 @@
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Pytest",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"module": "pytest",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-v"
|
||||
// Specify a sepcific module/test to run or provide nothing to run all tests
|
||||
//"tests/unit/danswer/llm/answering/test_prune_and_merge.py"
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -72,6 +72,10 @@ For convenience here's a command for it:
python -m venv .venv
source .venv/bin/activate
```

--> Note that this virtual environment MUST NOT be set up WITHIN the danswer directory

_For Windows, activate the virtual environment using Command Prompt:_
```bash
.venv\Scripts\activate
LICENSE (8 changed lines)
@@ -1,6 +1,10 @@
MIT License
Copyright (c) 2023-present DanswerAI, Inc.

Copyright (c) 2023 Yuhong Sun, Chris Weaver
Portions of this software are licensed as follows:

* All content that resides under "ee" directories of this repository, if that directory exists, is licensed under the license defined in "backend/ee/LICENSE". Specifically all content under "backend/ee" and "web/src/app/ee" is licensed under the license defined in "backend/ee/LICENSE".
* All third party components incorporated into the Danswer Software are licensed under the original license provided by the owner of the applicable component.
* Content outside of the above mentioned directories or restrictions above is available under the "MIT Expat" license as defined below.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
README.md (22 changed lines)
@@ -11,7 +11,7 @@
<a href="https://docs.danswer.dev/" target="_blank">
    <img src="https://img.shields.io/badge/docs-view-blue" alt="Documentation">
</a>
<a href="https://join.slack.com/t/danswer/shared_invite/zt-2afut44lv-Rw3kSWu6_OmdAXRpCv80DQ" target="_blank">
<a href="https://join.slack.com/t/danswer/shared_invite/zt-2lcmqw703-071hBuZBfNEOGUsLa5PXvQ" target="_blank">
    <img src="https://img.shields.io/badge/slack-join-blue.svg?logo=slack" alt="Slack">
</a>
<a href="https://discord.gg/TDJ59cGV2X" target="_blank">

@@ -105,5 +105,25 @@ Efficiently pulls the latest changes from:
* Websites
* And more ...

## 📚 Editions

There are two editions of Danswer:

* Danswer Community Edition (CE) is available freely under the MIT Expat license. This version has ALL the core features discussed above. This is the version of Danswer you will get if you follow the Deployment guide above.
* Danswer Enterprise Edition (EE) includes extra features that are primarily useful for larger organizations. Specifically, this includes:
  * Single Sign-On (SSO), with support for both SAML and OIDC
  * Role-based access control
  * Document permission inheritance from connected sources
  * Usage analytics and query history accessible to admins
  * Whitelabeling
  * API key authentication
  * Encryption of secrets
  * And many more! Check out [our website](https://www.danswer.ai/) for the latest.

To try the Danswer Enterprise Edition:

1. Check out our [Cloud product](https://app.danswer.ai/signup).
2. For self-hosting, contact us at [founders@danswer.ai](mailto:founders@danswer.ai) or book a call with us on our [Cal](https://cal.com/team/danswer/founders).

## 💡 Contributing
Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md) for more details.
backend/.gitignore (vendored, 2 changed lines)
@@ -5,7 +5,7 @@ site_crawls/
.ipynb_checkpoints/
api_keys.py
*ipynb
.env
.env*
vespa-app.zip
dynamic_config_storage/
celerybeat-schedule*
@@ -1,15 +1,17 @@
|
||||
FROM python:3.11.7-slim-bookworm
|
||||
|
||||
LABEL com.danswer.maintainer="founders@danswer.ai"
|
||||
LABEL com.danswer.description="This image is for the backend of Danswer. It is MIT Licensed and \
|
||||
free for all to use. You can find it at https://hub.docker.com/r/danswer/danswer-backend. For \
|
||||
more details, visit https://github.com/danswer-ai/danswer."
|
||||
LABEL com.danswer.description="This image is the web/frontend container of Danswer which \
|
||||
contains code for both the Community and Enterprise editions of Danswer. If you do not \
|
||||
have a contract or agreement with DanswerAI, you are not permitted to use the Enterprise \
|
||||
Edition features outside of personal development or testing purposes. Please reach out to \
|
||||
founders@danswer.ai for more information. Please visit https://github.com/danswer-ai/danswer"
|
||||
|
||||
# Default DANSWER_VERSION, typically overriden during builds by GitHub Actions.
|
||||
ARG DANSWER_VERSION=0.3-dev
|
||||
ENV DANSWER_VERSION=${DANSWER_VERSION}
|
||||
RUN echo "DANSWER_VERSION: ${DANSWER_VERSION}"
|
||||
|
||||
RUN echo "DANSWER_VERSION: ${DANSWER_VERSION}"
|
||||
# Install system dependencies
|
||||
# cmake needed for psycopg (postgres)
|
||||
# libpq-dev needed for psycopg (postgres)
|
||||
@@ -17,18 +19,32 @@ RUN echo "DANSWER_VERSION: ${DANSWER_VERSION}"
|
||||
# zip for Vespa step futher down
|
||||
# ca-certificates for HTTPS
|
||||
RUN apt-get update && \
|
||||
apt-get install -y cmake curl zip ca-certificates libgnutls30=3.7.9-2+deb12u2 \
|
||||
libblkid1=2.38.1-5+deb12u1 libmount1=2.38.1-5+deb12u1 libsmartcols1=2.38.1-5+deb12u1 \
|
||||
libuuid1=2.38.1-5+deb12u1 && \
|
||||
apt-get install -y \
|
||||
cmake \
|
||||
curl \
|
||||
zip \
|
||||
ca-certificates \
|
||||
libgnutls30=3.7.9-2+deb12u3 \
|
||||
libblkid1=2.38.1-5+deb12u1 \
|
||||
libmount1=2.38.1-5+deb12u1 \
|
||||
libsmartcols1=2.38.1-5+deb12u1 \
|
||||
libuuid1=2.38.1-5+deb12u1 \
|
||||
libxmlsec1-dev \
|
||||
pkg-config \
|
||||
gcc && \
|
||||
rm -rf /var/lib/apt/lists/* && \
|
||||
apt-get clean
|
||||
|
||||
# Install Python dependencies
|
||||
# Remove py which is pulled in by retry, py is not needed and is a CVE
|
||||
COPY ./requirements/default.txt /tmp/requirements.txt
|
||||
RUN pip install --no-cache-dir --upgrade -r /tmp/requirements.txt && \
|
||||
COPY ./requirements/ee.txt /tmp/ee-requirements.txt
|
||||
RUN pip install --no-cache-dir --upgrade \
|
||||
-r /tmp/requirements.txt \
|
||||
-r /tmp/ee-requirements.txt && \
|
||||
pip uninstall -y py && \
|
||||
playwright install chromium && playwright install-deps chromium && \
|
||||
playwright install chromium && \
|
||||
playwright install-deps chromium && \
|
||||
ln -s /usr/local/bin/supervisord /usr/bin/supervisord
|
||||
|
||||
# Cleanup for CVEs and size reduction
|
||||
@@ -36,14 +52,25 @@ RUN pip install --no-cache-dir --upgrade -r /tmp/requirements.txt && \
|
||||
# xserver-common and xvfb included by playwright installation but not needed after
|
||||
# perl-base is part of the base Python Debian image but not needed for Danswer functionality
|
||||
# perl-base could only be removed with --allow-remove-essential
|
||||
RUN apt-get remove -y --allow-remove-essential perl-base xserver-common xvfb cmake \
|
||||
libldap-2.5-0 libldap-2.5-0 && \
|
||||
RUN apt-get update && \
|
||||
apt-get remove -y --allow-remove-essential \
|
||||
perl-base \
|
||||
xserver-common \
|
||||
xvfb \
|
||||
cmake \
|
||||
libldap-2.5-0 \
|
||||
libxmlsec1-dev \
|
||||
pkg-config \
|
||||
gcc && \
|
||||
apt-get install -y libxmlsec1-openssl && \
|
||||
apt-get autoremove -y && \
|
||||
rm -rf /var/lib/apt/lists/* && \
|
||||
rm /usr/local/lib/python3.11/site-packages/tornado/test/test.key
|
||||
rm -f /usr/local/lib/python3.11/site-packages/tornado/test/test.key
|
||||
|
||||
# Pre-downloading models for setups with limited egress
|
||||
RUN python -c "from transformers import AutoTokenizer; AutoTokenizer.from_pretrained('intfloat/e5-base-v2')"
|
||||
RUN python -c "from tokenizers import Tokenizer; \
|
||||
Tokenizer.from_pretrained('nomic-ai/nomic-embed-text-v1')"
|
||||
|
||||
|
||||
# Pre-downloading NLTK for setups with limited egress
|
||||
RUN python -c "import nltk; \
|
||||
@@ -53,12 +80,24 @@ nltk.download('punkt', quiet=True);"
|
||||
|
||||
# Set up application files
|
||||
WORKDIR /app
|
||||
|
||||
# Enterprise Version Files
|
||||
COPY ./ee /app/ee
|
||||
COPY supervisord.conf /etc/supervisor/conf.d/supervisord.conf
|
||||
|
||||
# Set up application files
|
||||
COPY ./danswer /app/danswer
|
||||
COPY ./shared_configs /app/shared_configs
|
||||
COPY ./alembic /app/alembic
|
||||
COPY ./alembic.ini /app/alembic.ini
|
||||
COPY supervisord.conf /usr/etc/supervisord.conf
|
||||
|
||||
# Escape hatch
|
||||
COPY ./scripts/force_delete_connector_by_id.py /app/scripts/force_delete_connector_by_id.py
|
||||
|
||||
# Put logo in assets
|
||||
COPY ./assets /app/assets
|
||||
|
||||
ENV PYTHONPATH /app
|
||||
|
||||
# Default command which does nothing
|
||||
|
||||
@@ -18,14 +18,17 @@ RUN apt-get remove -y --allow-remove-essential perl-base && \
|
||||
apt-get autoremove -y
|
||||
|
||||
# Pre-downloading models for setups with limited egress
|
||||
RUN python -c "from transformers import AutoModel, AutoTokenizer, TFDistilBertForSequenceClassification; \
|
||||
RUN python -c "from transformers import AutoTokenizer; \
|
||||
AutoTokenizer.from_pretrained('danswer/intent-model', cache_folder='/root/.cache/temp_huggingface/hub/'); \
|
||||
AutoTokenizer.from_pretrained('mixedbread-ai/mxbai-rerank-xsmall-v1', cache_folder='/root/.cache/temp_huggingface/hub/'); \
|
||||
from transformers import TFDistilBertForSequenceClassification; \
|
||||
TFDistilBertForSequenceClassification.from_pretrained('danswer/intent-model', cache_dir='/root/.cache/temp_huggingface/hub/'); \
|
||||
from huggingface_hub import snapshot_download; \
|
||||
AutoTokenizer.from_pretrained('danswer/intent-model'); \
|
||||
AutoTokenizer.from_pretrained('intfloat/e5-base-v2'); \
|
||||
AutoTokenizer.from_pretrained('mixedbread-ai/mxbai-rerank-xsmall-v1'); \
|
||||
snapshot_download('danswer/intent-model'); \
|
||||
snapshot_download('intfloat/e5-base-v2'); \
|
||||
snapshot_download('mixedbread-ai/mxbai-rerank-xsmall-v1')"
|
||||
snapshot_download('danswer/intent-model', cache_dir='/root/.cache/temp_huggingface/hub/'); \
|
||||
snapshot_download('nomic-ai/nomic-embed-text-v1', cache_dir='/root/.cache/temp_huggingface/hub/'); \
|
||||
snapshot_download('mixedbread-ai/mxbai-rerank-xsmall-v1', cache_dir='/root/.cache/temp_huggingface/hub/'); \
|
||||
from sentence_transformers import SentenceTransformer; \
|
||||
SentenceTransformer(model_name_or_path='nomic-ai/nomic-embed-text-v1', trust_remote_code=True, cache_folder='/root/.cache/temp_huggingface/hub/');"
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
|
||||
@@ -0,0 +1,27 @@
"""Add thread specific model selection

Revision ID: 0568ccf46a6b
Revises: e209dc5a8156
Create Date: 2024-06-19 14:25:36.376046

"""

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "0568ccf46a6b"
down_revision = "e209dc5a8156"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.add_column(
        "chat_session",
        sa.Column("current_alternate_model", sa.String(), nullable=True),
    )


def downgrade() -> None:
    op.drop_column("chat_session", "current_alternate_model")
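This revision and the ones that follow are ordinary Alembic migrations, so they are applied through the Alembic CLI rather than imported directly. A minimal sketch of exercising them, assuming it is run from `backend/` (where the alembic.ini copied by the backend Dockerfile lives) against a reachable database:

```bash
cd backend

# Apply every pending revision, including 0568ccf46a6b above
alembic upgrade head

# Roll back only the most recent revision (exercises its downgrade())
alembic downgrade -1

# Show which revision the database is currently at
alembic current
```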
@@ -0,0 +1,32 @@
|
||||
"""add search doc relevance details
|
||||
|
||||
Revision ID: 05c07bf07c00
|
||||
Revises: b896bbd0d5a7
|
||||
Create Date: 2024-07-10 17:48:15.886653
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "05c07bf07c00"
|
||||
down_revision = "b896bbd0d5a7"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"search_doc",
|
||||
sa.Column("is_relevant", sa.Boolean(), nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
"search_doc",
|
||||
sa.Column("relevance_explanation", sa.String(), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("search_doc", "relevance_explanation")
|
||||
op.drop_column("search_doc", "is_relevant")
|
||||
@@ -0,0 +1,26 @@
|
||||
"""add_indexing_start_to_connector
|
||||
|
||||
Revision ID: 08a1eda20fe1
|
||||
Revises: 8a87bd6ec550
|
||||
Create Date: 2024-07-23 11:12:39.462397
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "08a1eda20fe1"
|
||||
down_revision = "8a87bd6ec550"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"connector", sa.Column("indexing_start", sa.DateTime(), nullable=True)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("connector", "indexing_start")
|
||||
@@ -0,0 +1,86 @@
|
||||
"""remove-feedback-foreignkey-constraint
|
||||
|
||||
Revision ID: 23957775e5f5
|
||||
Revises: bc9771dccadf
|
||||
Create Date: 2024-06-27 16:04:51.480437
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "23957775e5f5"
|
||||
down_revision = "bc9771dccadf"
|
||||
branch_labels = None # type: ignore
|
||||
depends_on = None # type: ignore
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.drop_constraint(
|
||||
"chat_feedback__chat_message_fk", "chat_feedback", type_="foreignkey"
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"chat_feedback__chat_message_fk",
|
||||
"chat_feedback",
|
||||
"chat_message",
|
||||
["chat_message_id"],
|
||||
["id"],
|
||||
ondelete="SET NULL",
|
||||
)
|
||||
op.alter_column(
|
||||
"chat_feedback", "chat_message_id", existing_type=sa.Integer(), nullable=True
|
||||
)
|
||||
op.drop_constraint(
|
||||
"document_retrieval_feedback__chat_message_fk",
|
||||
"document_retrieval_feedback",
|
||||
type_="foreignkey",
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"document_retrieval_feedback__chat_message_fk",
|
||||
"document_retrieval_feedback",
|
||||
"chat_message",
|
||||
["chat_message_id"],
|
||||
["id"],
|
||||
ondelete="SET NULL",
|
||||
)
|
||||
op.alter_column(
|
||||
"document_retrieval_feedback",
|
||||
"chat_message_id",
|
||||
existing_type=sa.Integer(),
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.alter_column(
|
||||
"chat_feedback", "chat_message_id", existing_type=sa.Integer(), nullable=False
|
||||
)
|
||||
op.drop_constraint(
|
||||
"chat_feedback__chat_message_fk", "chat_feedback", type_="foreignkey"
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"chat_feedback__chat_message_fk",
|
||||
"chat_feedback",
|
||||
"chat_message",
|
||||
["chat_message_id"],
|
||||
["id"],
|
||||
)
|
||||
|
||||
op.alter_column(
|
||||
"document_retrieval_feedback",
|
||||
"chat_message_id",
|
||||
existing_type=sa.Integer(),
|
||||
nullable=False,
|
||||
)
|
||||
op.drop_constraint(
|
||||
"document_retrieval_feedback__chat_message_fk",
|
||||
"document_retrieval_feedback",
|
||||
type_="foreignkey",
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"document_retrieval_feedback__chat_message_fk",
|
||||
"document_retrieval",
|
||||
"chat_message",
|
||||
["chat_message_id"],
|
||||
["id"],
|
||||
)
|
||||
@@ -0,0 +1,70 @@
|
||||
"""Add icon_color and icon_shape to Persona
|
||||
|
||||
Revision ID: 325975216eb3
|
||||
Revises: 91ffac7e65b3
|
||||
Create Date: 2024-07-24 21:29:31.784562
|
||||
|
||||
"""
|
||||
import random
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.sql import table, column, select
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "325975216eb3"
|
||||
down_revision = "91ffac7e65b3"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
colorOptions = [
|
||||
"#FF6FBF",
|
||||
"#6FB1FF",
|
||||
"#B76FFF",
|
||||
"#FFB56F",
|
||||
"#6FFF8D",
|
||||
"#FF6F6F",
|
||||
"#6FFFFF",
|
||||
]
|
||||
|
||||
|
||||
# Function to generate a random shape ensuring at least 3 of the middle 4 squares are filled
|
||||
def generate_random_shape() -> int:
|
||||
center_squares = [12, 10, 6, 14, 13, 11, 7, 15]
|
||||
center_fill = random.choice(center_squares)
|
||||
remaining_squares = [i for i in range(16) if not (center_fill & (1 << i))]
|
||||
random.shuffle(remaining_squares)
|
||||
for i in range(10 - bin(center_fill).count("1")):
|
||||
center_fill |= 1 << remaining_squares[i]
|
||||
return center_fill
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column("persona", sa.Column("icon_color", sa.String(), nullable=True))
|
||||
op.add_column("persona", sa.Column("icon_shape", sa.Integer(), nullable=True))
|
||||
op.add_column("persona", sa.Column("uploaded_image_id", sa.String(), nullable=True))
|
||||
|
||||
persona = table(
|
||||
"persona",
|
||||
column("id", sa.Integer),
|
||||
column("icon_color", sa.String),
|
||||
column("icon_shape", sa.Integer),
|
||||
)
|
||||
|
||||
conn = op.get_bind()
|
||||
personas = conn.execute(select(persona.c.id))
|
||||
|
||||
for persona_id in personas:
|
||||
random_color = random.choice(colorOptions)
|
||||
random_shape = generate_random_shape()
|
||||
conn.execute(
|
||||
persona.update()
|
||||
.where(persona.c.id == persona_id[0])
|
||||
.values(icon_color=random_color, icon_shape=random_shape)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("persona", "icon_shape")
|
||||
op.drop_column("persona", "uploaded_image_id")
|
||||
op.drop_column("persona", "icon_color")
|
||||
@@ -11,8 +11,8 @@ import sqlalchemy as sa
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "3879338f8ba1"
|
||||
down_revision = "f1c6478c3fd8"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
"""add alternate assistant to chat message
|
||||
|
||||
Revision ID: 3a7802814195
|
||||
Revises: 23957775e5f5
|
||||
Create Date: 2024-06-05 11:18:49.966333
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "3a7802814195"
|
||||
down_revision = "23957775e5f5"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"chat_message", sa.Column("alternate_assistant_id", sa.Integer(), nullable=True)
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"fk_chat_message_persona",
|
||||
"chat_message",
|
||||
"persona",
|
||||
["alternate_assistant_id"],
|
||||
["id"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_constraint("fk_chat_message_persona", "chat_message", type_="foreignkey")
|
||||
op.drop_column("chat_message", "alternate_assistant_id")
|
||||
@@ -0,0 +1,42 @@
|
||||
"""Rename index_origin to index_recursively
|
||||
|
||||
Revision ID: 1d6ad76d1f37
|
||||
Revises: e1392f05e840
|
||||
Create Date: 2024-08-01 12:38:54.466081
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "1d6ad76d1f37"
|
||||
down_revision = "e1392f05e840"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE connector
|
||||
SET connector_specific_config = jsonb_set(
|
||||
connector_specific_config,
|
||||
'{index_recursively}',
|
||||
'true'::jsonb
|
||||
) - 'index_origin'
|
||||
WHERE connector_specific_config ? 'index_origin'
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE connector
|
||||
SET connector_specific_config = jsonb_set(
|
||||
connector_specific_config,
|
||||
'{index_origin}',
|
||||
connector_specific_config->'index_recursively'
|
||||
) - 'index_recursively'
|
||||
WHERE connector_specific_config ? 'index_recursively'
|
||||
"""
|
||||
)
|
||||
@@ -0,0 +1,65 @@
|
||||
"""add cloud embedding model and update embedding_model
|
||||
|
||||
Revision ID: 44f856ae2a4a
|
||||
Revises: d716b0791ddd
|
||||
Create Date: 2024-06-28 20:01:05.927647
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "44f856ae2a4a"
|
||||
down_revision = "d716b0791ddd"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create embedding_provider table
|
||||
op.create_table(
|
||||
"embedding_provider",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("name", sa.String(), nullable=False),
|
||||
sa.Column("api_key", sa.LargeBinary(), nullable=True),
|
||||
sa.Column("default_model_id", sa.Integer(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("name"),
|
||||
)
|
||||
|
||||
# Add cloud_provider_id to embedding_model table
|
||||
op.add_column(
|
||||
"embedding_model", sa.Column("cloud_provider_id", sa.Integer(), nullable=True)
|
||||
)
|
||||
|
||||
# Add foreign key constraints
|
||||
op.create_foreign_key(
|
||||
"fk_embedding_model_cloud_provider",
|
||||
"embedding_model",
|
||||
"embedding_provider",
|
||||
["cloud_provider_id"],
|
||||
["id"],
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"fk_embedding_provider_default_model",
|
||||
"embedding_provider",
|
||||
"embedding_model",
|
||||
["default_model_id"],
|
||||
["id"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove foreign key constraints
|
||||
op.drop_constraint(
|
||||
"fk_embedding_model_cloud_provider", "embedding_model", type_="foreignkey"
|
||||
)
|
||||
op.drop_constraint(
|
||||
"fk_embedding_provider_default_model", "embedding_provider", type_="foreignkey"
|
||||
)
|
||||
|
||||
# Remove cloud_provider_id column
|
||||
op.drop_column("embedding_model", "cloud_provider_id")
|
||||
|
||||
# Drop embedding_provider table
|
||||
op.drop_table("embedding_provider")
|
||||
@@ -0,0 +1,23 @@
|
||||
"""added is_internet to DBDoc
|
||||
|
||||
Revision ID: 4505fd7302e1
|
||||
Revises: c18cdf4b497e
|
||||
Create Date: 2024-06-18 20:46:09.095034
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "4505fd7302e1"
|
||||
down_revision = "c18cdf4b497e"
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column("search_doc", sa.Column("is_internet", sa.Boolean(), nullable=True))
|
||||
op.add_column("tool", sa.Column("display_name", sa.String(), nullable=True))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("tool", "display_name")
|
||||
op.drop_column("search_doc", "is_internet")
|
||||
@@ -0,0 +1,49 @@
|
||||
"""Add display_model_names to llm_provider
|
||||
|
||||
Revision ID: 473a1a7ca408
|
||||
Revises: 325975216eb3
|
||||
Create Date: 2024-07-25 14:31:02.002917
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "473a1a7ca408"
|
||||
down_revision = "325975216eb3"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
default_models_by_provider = {
|
||||
"openai": ["gpt-4", "gpt-4o", "gpt-4o-mini"],
|
||||
"bedrock": [
|
||||
"meta.llama3-1-70b-instruct-v1:0",
|
||||
"meta.llama3-1-8b-instruct-v1:0",
|
||||
"anthropic.claude-3-opus-20240229-v1:0",
|
||||
"mistral.mistral-large-2402-v1:0",
|
||||
"anthropic.claude-3-5-sonnet-20240620-v1:0",
|
||||
],
|
||||
"anthropic": ["claude-3-opus-20240229", "claude-3-5-sonnet-20240620"],
|
||||
}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"llm_provider",
|
||||
sa.Column("display_model_names", postgresql.ARRAY(sa.String()), nullable=True),
|
||||
)
|
||||
|
||||
connection = op.get_bind()
|
||||
for provider, models in default_models_by_provider.items():
|
||||
connection.execute(
|
||||
sa.text(
|
||||
"UPDATE llm_provider SET display_model_names = :models WHERE provider = :provider"
|
||||
),
|
||||
{"models": models, "provider": provider},
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("llm_provider", "display_model_names")
|
||||
@@ -0,0 +1,61 @@
|
||||
"""Add support for custom tools
|
||||
|
||||
Revision ID: 48d14957fe80
|
||||
Revises: b85f02ec1308
|
||||
Create Date: 2024-06-09 14:58:19.946509
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import fastapi_users_db_sqlalchemy
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "48d14957fe80"
|
||||
down_revision = "b85f02ec1308"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"tool",
|
||||
sa.Column(
|
||||
"openapi_schema",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"tool",
|
||||
sa.Column(
|
||||
"user_id",
|
||||
fastapi_users_db_sqlalchemy.generics.GUID(),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
op.create_foreign_key("tool_user_fk", "tool", "user", ["user_id"], ["id"])
|
||||
|
||||
op.create_table(
|
||||
"tool_call",
|
||||
sa.Column("id", sa.Integer(), primary_key=True),
|
||||
sa.Column("tool_id", sa.Integer(), nullable=False),
|
||||
sa.Column("tool_name", sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
"tool_arguments", postgresql.JSONB(astext_type=sa.Text()), nullable=False
|
||||
),
|
||||
sa.Column(
|
||||
"tool_result", postgresql.JSONB(astext_type=sa.Text()), nullable=False
|
||||
),
|
||||
sa.Column(
|
||||
"message_id", sa.Integer(), sa.ForeignKey("chat_message.id"), nullable=False
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("tool_call")
|
||||
|
||||
op.drop_constraint("tool_user_fk", "tool", type_="foreignkey")
|
||||
op.drop_column("tool", "user_id")
|
||||
op.drop_column("tool", "openapi_schema")
|
||||
@@ -0,0 +1,72 @@
|
||||
"""Add type to credentials
|
||||
|
||||
Revision ID: 4ea2c93919c1
|
||||
Revises: 473a1a7ca408
|
||||
Create Date: 2024-07-18 13:07:13.655895
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "4ea2c93919c1"
|
||||
down_revision = "473a1a7ca408"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add the new 'source' column to the 'credential' table
|
||||
op.add_column(
|
||||
"credential",
|
||||
sa.Column(
|
||||
"source",
|
||||
sa.String(length=100), # Use String instead of Enum
|
||||
nullable=True, # Initially allow NULL values
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"credential",
|
||||
sa.Column(
|
||||
"name",
|
||||
sa.String(),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
|
||||
# Create a temporary table that maps each credential to a single connector source.
|
||||
# This is needed because a credential can be associated with multiple connectors,
|
||||
# but we want to assign a single source to each credential.
|
||||
# We use DISTINCT ON to ensure we only get one row per credential_id.
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TEMPORARY TABLE temp_connector_credential AS
|
||||
SELECT DISTINCT ON (cc.credential_id)
|
||||
cc.credential_id,
|
||||
c.source AS connector_source
|
||||
FROM connector_credential_pair cc
|
||||
JOIN connector c ON cc.connector_id = c.id
|
||||
"""
|
||||
)
|
||||
|
||||
# Update the 'source' column in the 'credential' table
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE credential cred
|
||||
SET source = COALESCE(
|
||||
(SELECT connector_source
|
||||
FROM temp_connector_credential temp
|
||||
WHERE cred.id = temp.credential_id),
|
||||
'NOT_APPLICABLE'
|
||||
)
|
||||
"""
|
||||
)
|
||||
# If no exception was raised, alter the column
|
||||
op.alter_column("credential", "source", nullable=True) # TODO modify
|
||||
# # ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("credential", "source")
|
||||
op.drop_column("credential", "name")
|
||||
@@ -0,0 +1,68 @@
|
||||
"""More Descriptive Filestore
|
||||
|
||||
Revision ID: 70f00c45c0f2
|
||||
Revises: 3879338f8ba1
|
||||
Create Date: 2024-05-17 17:51:41.926893
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "70f00c45c0f2"
|
||||
down_revision = "3879338f8ba1"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column("file_store", sa.Column("display_name", sa.String(), nullable=True))
|
||||
op.add_column(
|
||||
"file_store",
|
||||
sa.Column(
|
||||
"file_origin",
|
||||
sa.String(),
|
||||
nullable=False,
|
||||
server_default="connector", # Default to connector
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"file_store",
|
||||
sa.Column(
|
||||
"file_type", sa.String(), nullable=False, server_default="text/plain"
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"file_store",
|
||||
sa.Column(
|
||||
"file_metadata",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE file_store
|
||||
SET file_origin = CASE
|
||||
WHEN file_name LIKE 'chat__%' THEN 'chat_upload'
|
||||
ELSE 'connector'
|
||||
END,
|
||||
file_name = CASE
|
||||
WHEN file_name LIKE 'chat__%' THEN SUBSTR(file_name, 7)
|
||||
ELSE file_name
|
||||
END,
|
||||
file_type = CASE
|
||||
WHEN file_name LIKE 'chat__%' THEN 'image/png'
|
||||
ELSE 'text/plain'
|
||||
END
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("file_store", "file_metadata")
|
||||
op.drop_column("file_store", "file_type")
|
||||
op.drop_column("file_store", "file_origin")
|
||||
op.drop_column("file_store", "display_name")
|
||||
@@ -0,0 +1,41 @@
"""add_llm_group_permissions_control

Revision ID: 795b20b85b4b
Revises: 05c07bf07c00
Create Date: 2024-07-19 11:54:35.701558

"""
from alembic import op
import sqlalchemy as sa


revision = "795b20b85b4b"
down_revision = "05c07bf07c00"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.create_table(
        "llm_provider__user_group",
        sa.Column("llm_provider_id", sa.Integer(), nullable=False),
        sa.Column("user_group_id", sa.Integer(), nullable=False),
        sa.ForeignKeyConstraint(
            ["llm_provider_id"],
            ["llm_provider.id"],
        ),
        sa.ForeignKeyConstraint(
            ["user_group_id"],
            ["user_group.id"],
        ),
        sa.PrimaryKeyConstraint("llm_provider_id", "user_group_id"),
    )
    op.add_column(
        "llm_provider",
        sa.Column("is_public", sa.Boolean(), nullable=False, server_default="true"),
    )


def downgrade() -> None:
    op.drop_table("llm_provider__user_group")
    op.drop_column("llm_provider", "is_public")

@@ -0,0 +1,35 @@
"""added slack_auto_filter

Revision ID: 7aea705850d5
Revises: 4505fd7302e1
Create Date: 2024-07-10 11:01:23.581015

"""
from alembic import op
import sqlalchemy as sa


revision = "7aea705850d5"
down_revision = "4505fd7302e1"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.add_column(
        "slack_bot_config",
        sa.Column("enable_auto_filters", sa.Boolean(), nullable=True),
    )
    op.execute(
        "UPDATE slack_bot_config SET enable_auto_filters = FALSE WHERE enable_auto_filters IS NULL"
    )
    op.alter_column(
        "slack_bot_config",
        "enable_auto_filters",
        existing_type=sa.Boolean(),
        nullable=False,
        server_default=sa.false(),
    )


def downgrade() -> None:
    op.drop_column("slack_bot_config", "enable_auto_filters")

@@ -0,0 +1,103 @@
"""associate index attempts with ccpair

Revision ID: 8a87bd6ec550
Revises: 4ea2c93919c1
Create Date: 2024-07-22 15:15:52.558451

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "8a87bd6ec550"
down_revision = "4ea2c93919c1"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    # Add the new connector_credential_pair_id column
    op.add_column(
        "index_attempt",
        sa.Column("connector_credential_pair_id", sa.Integer(), nullable=True),
    )

    # Create a foreign key constraint to the connector_credential_pair table
    op.create_foreign_key(
        "fk_index_attempt_connector_credential_pair_id",
        "index_attempt",
        "connector_credential_pair",
        ["connector_credential_pair_id"],
        ["id"],
    )

    # Populate the new connector_credential_pair_id column using existing connector_id and credential_id
    op.execute(
        """
        UPDATE index_attempt ia
        SET connector_credential_pair_id =
            CASE
                WHEN ia.credential_id IS NULL THEN
                    (SELECT id FROM connector_credential_pair
                     WHERE connector_id = ia.connector_id
                     LIMIT 1)
                ELSE
                    (SELECT id FROM connector_credential_pair
                     WHERE connector_id = ia.connector_id
                     AND credential_id = ia.credential_id)
            END
        WHERE ia.connector_id IS NOT NULL
        """
    )

    # Make the new connector_credential_pair_id column non-nullable
    op.alter_column("index_attempt", "connector_credential_pair_id", nullable=False)

    # Drop the old connector_id and credential_id columns
    op.drop_column("index_attempt", "connector_id")
    op.drop_column("index_attempt", "credential_id")

    # Update the index to use connector_credential_pair_id
    op.create_index(
        "ix_index_attempt_latest_for_connector_credential_pair",
        "index_attempt",
        ["connector_credential_pair_id", "time_created"],
    )


def downgrade() -> None:
    # Add back the old connector_id and credential_id columns
    op.add_column(
        "index_attempt", sa.Column("connector_id", sa.Integer(), nullable=True)
    )
    op.add_column(
        "index_attempt", sa.Column("credential_id", sa.Integer(), nullable=True)
    )

    # Populate the old connector_id and credential_id columns using the connector_credential_pair_id
    op.execute(
        """
        UPDATE index_attempt ia
        SET connector_id = ccp.connector_id, credential_id = ccp.credential_id
        FROM connector_credential_pair ccp
        WHERE ia.connector_credential_pair_id = ccp.id
        """
    )

    # Make the old connector_id and credential_id columns non-nullable
    op.alter_column("index_attempt", "connector_id", nullable=False)
    op.alter_column("index_attempt", "credential_id", nullable=False)

    # Drop the new connector_credential_pair_id column
    op.drop_constraint(
        "fk_index_attempt_connector_credential_pair_id",
        "index_attempt",
        type_="foreignkey",
    )
    op.drop_column("index_attempt", "connector_credential_pair_id")

    op.create_index(
        "ix_index_attempt_latest_for_connector_credential_pair",
        "index_attempt",
        ["connector_id", "credential_id", "time_created"],
    )
backend/alembic/versions/91ffac7e65b3_add_expiry_time.py (new file, 26 lines)
@@ -0,0 +1,26 @@
"""add expiry time

Revision ID: 91ffac7e65b3
Revises: bc9771dccadf
Create Date: 2024-06-24 09:39:56.462242

"""

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "91ffac7e65b3"
down_revision = "795b20b85b4b"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.add_column(
        "user", sa.Column("oidc_expiry", sa.DateTime(timezone=True), nullable=True)
    )


def downgrade() -> None:
    op.drop_column("user", "oidc_expiry")

@@ -0,0 +1,27 @@
"""Add chosen_assistants to User table

Revision ID: a3bfd0d64902
Revises: ec85f2b3c544
Create Date: 2024-05-26 17:22:24.834741

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "a3bfd0d64902"
down_revision = "ec85f2b3c544"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.add_column(
        "user",
        sa.Column("chosen_assistants", postgresql.ARRAY(sa.Integer()), nullable=True),
    )


def downgrade() -> None:
    op.drop_column("user", "chosen_assistants")

@@ -16,7 +16,6 @@ depends_on: None = None


def upgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.alter_column(
        "connector_credential_pair",
        "last_attempt_status",
@@ -29,11 +28,9 @@ def upgrade() -> None:
        ),
        nullable=True,
    )
    # ### end Alembic commands ###


def downgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.alter_column(
        "connector_credential_pair",
        "last_attempt_status",
@@ -46,4 +43,3 @@ def downgrade() -> None:
        ),
        nullable=False,
    )
    # ### end Alembic commands ###
@@ -0,0 +1,28 @@
"""fix-file-type-migration

Revision ID: b85f02ec1308
Revises: a3bfd0d64902
Create Date: 2024-05-31 18:09:26.658164

"""
from alembic import op

# revision identifiers, used by Alembic.
revision = "b85f02ec1308"
down_revision = "a3bfd0d64902"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.execute(
        """
        UPDATE file_store
        SET file_origin = UPPER(file_origin)
        """
    )


def downgrade() -> None:
    # Let's not break anything on purpose :)
    pass

@@ -0,0 +1,23 @@
"""backfill is_internet data to False

Revision ID: b896bbd0d5a7
Revises: 44f856ae2a4a
Create Date: 2024-07-16 15:21:05.718571

"""
from alembic import op


# revision identifiers, used by Alembic.
revision = "b896bbd0d5a7"
down_revision = "44f856ae2a4a"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.execute("UPDATE search_doc SET is_internet = FALSE WHERE is_internet IS NULL")


def downgrade() -> None:
    pass

@@ -0,0 +1,51 @@
"""create usage reports table

Revision ID: bc9771dccadf
Revises: 0568ccf46a6b
Create Date: 2024-06-18 10:04:26.800282

"""
from alembic import op
import sqlalchemy as sa
import fastapi_users_db_sqlalchemy

# revision identifiers, used by Alembic.
revision = "bc9771dccadf"
down_revision = "0568ccf46a6b"

branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.create_table(
        "usage_reports",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("report_name", sa.String(), nullable=False),
        sa.Column(
            "requestor_user_id",
            fastapi_users_db_sqlalchemy.generics.GUID(),
            nullable=True,
        ),
        sa.Column(
            "time_created",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.Column("period_from", sa.DateTime(timezone=True), nullable=True),
        sa.Column("period_to", sa.DateTime(timezone=True), nullable=True),
        sa.ForeignKeyConstraint(
            ["report_name"],
            ["file_store.file_name"],
        ),
        sa.ForeignKeyConstraint(
            ["requestor_user_id"],
            ["user.id"],
        ),
        sa.PrimaryKeyConstraint("id"),
    )


def downgrade() -> None:
    op.drop_table("usage_reports")

@@ -0,0 +1,75 @@
"""Add standard_answer tables

Revision ID: c18cdf4b497e
Revises: 3a7802814195
Create Date: 2024-06-06 15:15:02.000648

"""
from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "c18cdf4b497e"
down_revision = "3a7802814195"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.create_table(
        "standard_answer",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("keyword", sa.String(), nullable=False),
        sa.Column("answer", sa.String(), nullable=False),
        sa.Column("active", sa.Boolean(), nullable=False),
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint("keyword"),
    )
    op.create_table(
        "standard_answer_category",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("name", sa.String(), nullable=False),
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint("name"),
    )
    op.create_table(
        "standard_answer__standard_answer_category",
        sa.Column("standard_answer_id", sa.Integer(), nullable=False),
        sa.Column("standard_answer_category_id", sa.Integer(), nullable=False),
        sa.ForeignKeyConstraint(
            ["standard_answer_category_id"],
            ["standard_answer_category.id"],
        ),
        sa.ForeignKeyConstraint(
            ["standard_answer_id"],
            ["standard_answer.id"],
        ),
        sa.PrimaryKeyConstraint("standard_answer_id", "standard_answer_category_id"),
    )
    op.create_table(
        "slack_bot_config__standard_answer_category",
        sa.Column("slack_bot_config_id", sa.Integer(), nullable=False),
        sa.Column("standard_answer_category_id", sa.Integer(), nullable=False),
        sa.ForeignKeyConstraint(
            ["slack_bot_config_id"],
            ["slack_bot_config.id"],
        ),
        sa.ForeignKeyConstraint(
            ["standard_answer_category_id"],
            ["standard_answer_category.id"],
        ),
        sa.PrimaryKeyConstraint("slack_bot_config_id", "standard_answer_category_id"),
    )

    op.add_column(
        "chat_session", sa.Column("slack_thread_id", sa.String(), nullable=True)
    )


def downgrade() -> None:
    op.drop_column("chat_session", "slack_thread_id")

    op.drop_table("slack_bot_config__standard_answer_category")
    op.drop_table("standard_answer__standard_answer_category")
    op.drop_table("standard_answer_category")
    op.drop_table("standard_answer")

@@ -0,0 +1,45 @@
"""combined slack id fields

Revision ID: d716b0791ddd
Revises: 7aea705850d5
Create Date: 2024-07-10 17:57:45.630550

"""
from alembic import op

# revision identifiers, used by Alembic.
revision = "d716b0791ddd"
down_revision = "7aea705850d5"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.execute(
        """
        UPDATE slack_bot_config
        SET channel_config = jsonb_set(
            channel_config,
            '{respond_member_group_list}',
            coalesce(channel_config->'respond_team_member_list', '[]'::jsonb) ||
            coalesce(channel_config->'respond_slack_group_list', '[]'::jsonb)
        ) - 'respond_team_member_list' - 'respond_slack_group_list'
        """
    )


def downgrade() -> None:
    op.execute(
        """
        UPDATE slack_bot_config
        SET channel_config = jsonb_set(
            jsonb_set(
                channel_config - 'respond_member_group_list',
                '{respond_team_member_list}',
                '[]'::jsonb
            ),
            '{respond_slack_group_list}',
            '[]'::jsonb
        )
        """
    )
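The upgrade above folds two list fields inside the channel_config JSONB column into a single respond_member_group_list key (array concatenation via ||, then removal of the old keys). A minimal before/after sketch with hypothetical values, not taken from a real config:

# Illustration only: intended shape of channel_config around the backfill above.
before = {
    "respond_team_member_list": ["alice@example.com"],
    "respond_slack_group_list": ["eng-oncall"],
}
after = {
    # concatenation of the two old lists; the old keys are dropped
    "respond_member_group_list": ["alice@example.com", "eng-oncall"],
}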
backend/alembic/versions/e1392f05e840_added_input_prompts.py (new file, 58 lines)
@@ -0,0 +1,58 @@
"""Added input prompts

Revision ID: e1392f05e840
Revises: 08a1eda20fe1
Create Date: 2024-07-13 19:09:22.556224

"""

import fastapi_users_db_sqlalchemy

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "e1392f05e840"
down_revision = "08a1eda20fe1"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.create_table(
        "inputprompt",
        sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
        sa.Column("prompt", sa.String(), nullable=False),
        sa.Column("content", sa.String(), nullable=False),
        sa.Column("active", sa.Boolean(), nullable=False),
        sa.Column("is_public", sa.Boolean(), nullable=False),
        sa.Column(
            "user_id",
            fastapi_users_db_sqlalchemy.generics.GUID(),
            nullable=True,
        ),
        sa.ForeignKeyConstraint(
            ["user_id"],
            ["user.id"],
        ),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_table(
        "inputprompt__user",
        sa.Column("input_prompt_id", sa.Integer(), nullable=False),
        sa.Column("user_id", sa.Integer(), nullable=False),
        sa.ForeignKeyConstraint(
            ["input_prompt_id"],
            ["inputprompt.id"],
        ),
        sa.ForeignKeyConstraint(
            ["user_id"],
            ["inputprompt.id"],
        ),
        sa.PrimaryKeyConstraint("input_prompt_id", "user_id"),
    )


def downgrade() -> None:
    op.drop_table("inputprompt__user")
    op.drop_table("inputprompt")

@@ -0,0 +1,22 @@
"""added-prune-frequency

Revision ID: e209dc5a8156
Revises: 48d14957fe80
Create Date: 2024-06-16 16:02:35.273231

"""
from alembic import op
import sqlalchemy as sa

revision = "e209dc5a8156"
down_revision = "48d14957fe80"
branch_labels = None  # type: ignore
depends_on = None  # type: ignore


def upgrade() -> None:
    op.add_column("connector", sa.Column("prune_freq", sa.Integer(), nullable=True))


def downgrade() -> None:
    op.drop_column("connector", "prune_freq")

@@ -0,0 +1,31 @@
"""Remove Last Attempt Status from CC Pair

Revision ID: ec85f2b3c544
Revises: 3879338f8ba1
Create Date: 2024-05-23 21:39:46.126010

"""
from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "ec85f2b3c544"
down_revision = "70f00c45c0f2"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.drop_column("connector_credential_pair", "last_attempt_status")


def downgrade() -> None:
    op.add_column(
        "connector_credential_pair",
        sa.Column(
            "last_attempt_status",
            sa.VARCHAR(),
            autoincrement=False,
            nullable=True,
        ),
    )

backend/assets/.gitignore (new file, vendored, 2 lines)
@@ -0,0 +1,2 @@
*
!.gitignore

@@ -1,6 +1,7 @@
from sqlalchemy.orm import Session

from danswer.access.models import DocumentAccess
from danswer.access.utils import prefix_user
from danswer.configs.constants import PUBLIC_DOC_PAT
from danswer.db.document import get_acccess_info_for_documents
from danswer.db.models import User
@@ -19,7 +20,7 @@ def _get_access_for_documents(
        cc_pair_to_delete=cc_pair_to_delete,
    )
    return {
        document_id: DocumentAccess.build(user_ids, is_public)
        document_id: DocumentAccess.build(user_ids, [], is_public)
        for document_id, user_ids, is_public in document_access_info
    }

@@ -38,12 +39,6 @@ def get_access_for_documents(
    )  # type: ignore


def prefix_user(user_id: str) -> str:
    """Prefixes a user ID to eliminate collision with group names.
    This assumes that groups are prefixed with a different prefix."""
    return f"user_id:{user_id}"


def _get_acl_for_user(user: User | None, db_session: Session) -> set[str]:
    """Returns a list of ACL entries that the user has access to. This is meant to be
    used downstream to filter out documents that the user does not have access to. The

@@ -1,20 +1,30 @@
from dataclasses import dataclass
from uuid import UUID

from danswer.access.utils import prefix_user
from danswer.access.utils import prefix_user_group
from danswer.configs.constants import PUBLIC_DOC_PAT


@dataclass(frozen=True)
class DocumentAccess:
    user_ids: set[str]  # stringified UUIDs
    user_groups: set[str]  # names of user groups associated with this document
    is_public: bool

    def to_acl(self) -> list[str]:
        return list(self.user_ids) + ([PUBLIC_DOC_PAT] if self.is_public else [])
        return (
            [prefix_user(user_id) for user_id in self.user_ids]
            + [prefix_user_group(group_name) for group_name in self.user_groups]
            + ([PUBLIC_DOC_PAT] if self.is_public else [])
        )

    @classmethod
    def build(cls, user_ids: list[UUID | None], is_public: bool) -> "DocumentAccess":
    def build(
        cls, user_ids: list[UUID | None], user_groups: list[str], is_public: bool
    ) -> "DocumentAccess":
        return cls(
            user_ids={str(user_id) for user_id in user_ids if user_id},
            user_groups=set(user_groups),
            is_public=is_public,
        )

backend/danswer/access/utils.py (new file, 10 lines)
@@ -0,0 +1,10 @@
def prefix_user(user_id: str) -> str:
    """Prefixes a user ID to eliminate collision with group names.
    This assumes that groups are prefixed with a different prefix."""
    return f"user_id:{user_id}"


def prefix_user_group(user_group_name: str) -> str:
    """Prefixes a user group name to eliminate collision with user IDs.
    This assumes that user ids are prefixed with a different prefix."""
    return f"group:{user_group_name}"
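A small sketch of how the prefixed ACL entries above compose; the UUID and group name are hypothetical, purely illustrative:

from uuid import UUID

from danswer.access.models import DocumentAccess

access = DocumentAccess.build(
    user_ids=[UUID("12345678-1234-5678-1234-567812345678")],  # hypothetical user
    user_groups=["engineering"],  # hypothetical group
    is_public=False,
)
# access.to_acl() would yield entries along the lines of:
#   ["user_id:12345678-1234-5678-1234-567812345678", "group:engineering"]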
backend/danswer/auth/invited_users.py (new file, 21 lines)
@@ -0,0 +1,21 @@
from typing import cast

from danswer.dynamic_configs.factory import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.dynamic_configs.interface import JSON_ro

USER_STORE_KEY = "INVITED_USERS"


def get_invited_users() -> list[str]:
    try:
        store = get_dynamic_config_store()
        return cast(list, store.load(USER_STORE_KEY))
    except ConfigNotFoundError:
        return list()


def write_invited_users(emails: list[str]) -> int:
    store = get_dynamic_config_store()
    store.store(USER_STORE_KEY, cast(JSON_ro, emails))
    return len(emails)

backend/danswer/auth/noauth_user.py (new file, 40 lines)
@@ -0,0 +1,40 @@
from collections.abc import Mapping
from typing import Any
from typing import cast

from danswer.auth.schemas import UserRole
from danswer.dynamic_configs.store import ConfigNotFoundError
from danswer.dynamic_configs.store import DynamicConfigStore
from danswer.server.manage.models import UserInfo
from danswer.server.manage.models import UserPreferences


NO_AUTH_USER_PREFERENCES_KEY = "no_auth_user_preferences"


def set_no_auth_user_preferences(
    store: DynamicConfigStore, preferences: UserPreferences
) -> None:
    store.store(NO_AUTH_USER_PREFERENCES_KEY, preferences.dict())


def load_no_auth_user_preferences(store: DynamicConfigStore) -> UserPreferences:
    try:
        preferences_data = cast(
            Mapping[str, Any], store.load(NO_AUTH_USER_PREFERENCES_KEY)
        )
        return UserPreferences(**preferences_data)
    except ConfigNotFoundError:
        return UserPreferences(chosen_assistants=None)


def fetch_no_auth_user(store: DynamicConfigStore) -> UserInfo:
    return UserInfo(
        id="__no_auth_user__",
        email="anonymous@danswer.ai",
        is_active=True,
        is_superuser=False,
        is_verified=True,
        role=UserRole.ADMIN,
        preferences=load_no_auth_user_preferences(store),
    )

@@ -9,6 +9,12 @@ class UserRole(str, Enum):
    ADMIN = "admin"


class UserStatus(str, Enum):
    LIVE = "live"
    INVITED = "invited"
    DEACTIVATED = "deactivated"


class UserRead(schemas.BaseUser[uuid.UUID]):
    role: UserRole


@@ -1,7 +1,8 @@
import os
import smtplib
import uuid
from collections.abc import AsyncGenerator
from datetime import datetime
from datetime import timezone
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from typing import Optional
@@ -27,6 +28,7 @@ from fastapi_users.openapi import OpenAPIResponseType
from fastapi_users_db_sqlalchemy import SQLAlchemyUserDatabase
from sqlalchemy.orm import Session

from danswer.auth.invited_users import get_invited_users
from danswer.auth.schemas import UserCreate
from danswer.auth.schemas import UserRole
from danswer.configs.app_configs import AUTH_TYPE
@@ -46,22 +48,24 @@ from danswer.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
from danswer.configs.constants import DANSWER_API_KEY_PREFIX
from danswer.configs.constants import UNNAMED_KEY_PLACEHOLDER
from danswer.db.auth import get_access_token_db
from danswer.db.auth import get_default_admin_user_emails
from danswer.db.auth import get_user_count
from danswer.db.auth import get_user_db
from danswer.db.engine import get_session
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import AccessToken
from danswer.db.models import User
from danswer.db.users import get_user_by_email
from danswer.utils.logger import setup_logger
from danswer.utils.telemetry import optional_telemetry
from danswer.utils.telemetry import RecordType
from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import (
    fetch_versioned_implementation,
)


logger = setup_logger()

USER_WHITELIST_FILE = "/home/danswer_whitelist.txt"
_user_whitelist: list[str] | None = None


def verify_auth_setting() -> None:
    if AUTH_TYPE not in [AuthType.DISABLED, AuthType.BASIC, AuthType.GOOGLE_OAUTH]:
@@ -92,22 +96,16 @@ def user_needs_to_be_verified() -> bool:
    return AUTH_TYPE != AuthType.BASIC or REQUIRE_EMAIL_VERIFICATION


def get_user_whitelist() -> list[str]:
    global _user_whitelist
    if _user_whitelist is None:
        if os.path.exists(USER_WHITELIST_FILE):
            with open(USER_WHITELIST_FILE, "r") as file:
                _user_whitelist = [line.strip() for line in file]
        else:
            _user_whitelist = []

    return _user_whitelist
def verify_email_is_invited(email: str) -> None:
    whitelist = get_invited_users()
    if (whitelist and email not in whitelist) or not email:
        raise PermissionError("User not on allowed user whitelist")


def verify_email_in_whitelist(email: str) -> None:
    whitelist = get_user_whitelist()
    if (whitelist and email not in whitelist) or not email:
        raise PermissionError("User not on allowed user whitelist")
    with Session(get_sqlalchemy_engine()) as db_session:
        if not get_user_by_email(email, db_session):
            verify_email_is_invited(email)


def verify_email_domain(email: str) -> None:
@@ -159,11 +157,11 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
        safe: bool = False,
        request: Optional[Request] = None,
    ) -> models.UP:
        verify_email_in_whitelist(user_create.email)
        verify_email_is_invited(user_create.email)
        verify_email_domain(user_create.email)
        if hasattr(user_create, "role"):
            user_count = await get_user_count()
            if user_count == 0:
            if user_count == 0 or user_create.email in get_default_admin_user_emails():
                user_create.role = UserRole.ADMIN
            else:
                user_create.role = UserRole.BASIC
@@ -185,7 +183,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
        verify_email_in_whitelist(account_email)
        verify_email_domain(account_email)

        return await super().oauth_callback(  # type: ignore
        user = await super().oauth_callback(  # type: ignore
            oauth_name=oauth_name,
            access_token=access_token,
            account_id=account_id,
@@ -197,6 +195,14 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
            is_verified_by_default=is_verified_by_default,
        )

        # NOTE: google oauth expires after 1hr. We don't want to force the user to
        # re-authenticate that frequently, so for now we'll just ignore this for
        # google oauth users
        if expires_at and AUTH_TYPE != AuthType.GOOGLE_OAUTH:
            oidc_expiry = datetime.fromtimestamp(expires_at, tz=timezone.utc)
            await self.user_db.update(user, update_dict={"oidc_expiry": oidc_expiry})
        return user

    async def on_after_register(
        self, user: User, request: Optional[Request] = None
    ) -> None:
@@ -239,10 +245,12 @@ cookie_transport = CookieTransport(
def get_database_strategy(
    access_token_db: AccessTokenDatabase[AccessToken] = Depends(get_access_token_db),
) -> DatabaseStrategy:
    return DatabaseStrategy(
    strategy = DatabaseStrategy(
        access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS  # type: ignore
    )

    return strategy


auth_backend = AuthenticationBackend(
    name="database",
@@ -339,6 +347,12 @@ async def double_check_user(
            detail="Access denied. User is not verified.",
        )

    if user.oidc_expiry and user.oidc_expiry < datetime.now(timezone.utc):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Access denied. User's OIDC token has expired.",
        )

    return user


@@ -357,4 +371,5 @@ async def current_admin_user(user: User | None = Depends(current_user)) -> User
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Access denied. User is not an admin.",
    )

    return user
@@ -4,13 +4,23 @@ from typing import cast
from celery import Celery  # type: ignore
from sqlalchemy.orm import Session

from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector
from danswer.background.celery.celery_utils import should_prune_cc_pair
from danswer.background.celery.celery_utils import should_sync_doc_set
from danswer.background.connector_deletion import delete_connector_credential_pair
from danswer.background.connector_deletion import delete_connector_credential_pair_batch
from danswer.background.task_utils import build_celery_task_wrapper
from danswer.background.task_utils import name_cc_cleanup_task
from danswer.background.task_utils import name_cc_prune_task
from danswer.background.task_utils import name_document_set_sync_task
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import POSTGRES_CELERY_APP_NAME
from danswer.connectors.factory import instantiate_connector
from danswer.connectors.models import InputType
from danswer.db.connector_credential_pair import get_connector_credential_pair
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
from danswer.db.document import get_documents_for_connector_credential_pair
from danswer.db.document import prepare_to_modify_documents
from danswer.db.document_set import delete_document_set
from danswer.db.document_set import fetch_document_sets
@@ -22,8 +32,6 @@ from danswer.db.engine import build_connection_string
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.engine import SYNC_DB_API
from danswer.db.models import DocumentSet
from danswer.db.tasks import check_live_task_not_timed_out
from danswer.db.tasks import get_latest_task
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import UpdateRequest
@@ -31,7 +39,9 @@ from danswer.utils.logger import setup_logger

logger = setup_logger()

connection_string = build_connection_string(db_api=SYNC_DB_API)
connection_string = build_connection_string(
    db_api=SYNC_DB_API, app_name=POSTGRES_CELERY_APP_NAME
)
celery_broker_url = f"sqla+{connection_string}"
celery_backend_url = f"db+{connection_string}"
celery_app = Celery(__name__, broker=celery_broker_url, backend=celery_backend_url)
@@ -68,7 +78,9 @@ def cleanup_connector_credential_pair_task(
            f"{connector_id} and Credential ID: {credential_id} does not exist."
        )

    deletion_attempt_disallowed_reason = check_deletion_attempt_is_allowed(cc_pair)
    deletion_attempt_disallowed_reason = check_deletion_attempt_is_allowed(
        connector_credential_pair=cc_pair, db_session=db_session
    )
    if deletion_attempt_disallowed_reason:
        raise ValueError(deletion_attempt_disallowed_reason)

@@ -88,6 +100,74 @@ def cleanup_connector_credential_pair_task(
        raise e


@build_celery_task_wrapper(name_cc_prune_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def prune_documents_task(connector_id: int, credential_id: int) -> None:
    """connector pruning task. For a cc pair, this task pulls all document IDs from the source
    and compares those IDs to locally stored documents and deletes all locally stored IDs missing
    from the most recently pulled document ID list"""
    with Session(get_sqlalchemy_engine()) as db_session:
        try:
            cc_pair = get_connector_credential_pair(
                db_session=db_session,
                connector_id=connector_id,
                credential_id=credential_id,
            )

            if not cc_pair:
                logger.warning(f"ccpair not found for {connector_id} {credential_id}")
                return

            runnable_connector = instantiate_connector(
                cc_pair.connector.source,
                InputType.PRUNE,
                cc_pair.connector.connector_specific_config,
                cc_pair.credential,
                db_session,
            )

            all_connector_doc_ids: set[str] = extract_ids_from_runnable_connector(
                runnable_connector
            )

            all_indexed_document_ids = {
                doc.id
                for doc in get_documents_for_connector_credential_pair(
                    db_session=db_session,
                    connector_id=connector_id,
                    credential_id=credential_id,
                )
            }

            doc_ids_to_remove = list(all_indexed_document_ids - all_connector_doc_ids)

            curr_ind_name, sec_ind_name = get_both_index_names(db_session)
            document_index = get_default_document_index(
                primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
            )

            if len(doc_ids_to_remove) == 0:
                logger.info(
                    f"No docs to prune from {cc_pair.connector.source} connector"
                )
                return

            logger.info(
                f"pruning {len(doc_ids_to_remove)} doc(s) from {cc_pair.connector.source} connector"
            )
            delete_connector_credential_pair_batch(
                document_ids=doc_ids_to_remove,
                connector_id=connector_id,
                credential_id=credential_id,
                document_index=document_index,
            )
        except Exception as e:
            logger.exception(
                f"Failed to run pruning for connector id {connector_id} due to {e}"
            )
            raise e


@build_celery_task_wrapper(name_document_set_sync_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def sync_document_set_task(document_set_id: int) -> None:
@@ -175,32 +255,48 @@ def sync_document_set_task(document_set_id: int) -> None:
    soft_time_limit=JOB_TIMEOUT,
)
def check_for_document_sets_sync_task() -> None:
    """Runs periodically to check if any document sets are out of sync
    Creates a task to sync the set if needed"""
    """Runs periodically to check if any sync tasks should be run and adds them
    to the queue"""
    with Session(get_sqlalchemy_engine()) as db_session:
        # check if any document sets are not synced
        document_set_info = fetch_document_sets(
            user_id=None, db_session=db_session, include_outdated=True
        )
        for document_set, _ in document_set_info:
            if not document_set.is_up_to_date:
                task_name = name_document_set_sync_task(document_set.id)
                latest_sync = get_latest_task(task_name, db_session)

                if latest_sync and check_live_task_not_timed_out(
                    latest_sync, db_session
                ):
                    logger.info(
                        f"Document set '{document_set.id}' is already syncing. Skipping."
                    )
                    continue

                logger.info(f"Document set {document_set.id} syncing now!")
            if should_sync_doc_set(document_set, db_session):
                logger.info(f"Syncing the {document_set.name} document set")
                sync_document_set_task.apply_async(
                    kwargs=dict(document_set_id=document_set.id),
                )


@celery_app.task(
    name="check_for_prune_task",
    soft_time_limit=JOB_TIMEOUT,
)
def check_for_prune_task() -> None:
    """Runs periodically to check if any prune tasks should be run and adds them
    to the queue"""

    with Session(get_sqlalchemy_engine()) as db_session:
        all_cc_pairs = get_connector_credential_pairs(db_session)

        for cc_pair in all_cc_pairs:
            if should_prune_cc_pair(
                connector=cc_pair.connector,
                credential=cc_pair.credential,
                db_session=db_session,
            ):
                logger.info(f"Pruning the {cc_pair.connector.name} connector")

                prune_documents_task.apply_async(
                    kwargs=dict(
                        connector_id=cc_pair.connector.id,
                        credential_id=cc_pair.credential.id,
                    )
                )


#####
# Celery Beat (Periodic Tasks) Settings
#####
@@ -210,3 +306,11 @@ celery_app.conf.beat_schedule = {
        "schedule": timedelta(seconds=5),
    },
}
celery_app.conf.beat_schedule.update(
    {
        "check-for-prune": {
            "task": "check_for_prune_task",
            "schedule": timedelta(seconds=5),
        },
    }
)
backend/danswer/background/celery/celery_run.py (new file, 9 lines)
@@ -0,0 +1,9 @@
"""Entry point for running celery worker / celery beat."""
from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable


set_is_ee_based_on_env_variable()
celery_app = fetch_versioned_implementation(
    "danswer.background.celery.celery_app", "celery_app"
)

@@ -1,8 +1,32 @@
from datetime import datetime
from datetime import timezone

from sqlalchemy.orm import Session

from danswer.background.task_utils import name_cc_cleanup_task
from danswer.background.task_utils import name_cc_prune_task
from danswer.background.task_utils import name_document_set_sync_task
from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
    rate_limit_builder,
)
from danswer.connectors.interfaces import BaseConnector
from danswer.connectors.interfaces import IdConnector
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.models import Document
from danswer.db.engine import get_db_current_time
from danswer.db.models import Connector
from danswer.db.models import Credential
from danswer.db.models import DocumentSet
from danswer.db.tasks import check_task_is_live_and_not_timed_out
from danswer.db.tasks import get_latest_task
from danswer.db.tasks import get_latest_task_by_type
from danswer.server.documents.models import DeletionAttemptSnapshot
from danswer.utils.logger import setup_logger

logger = setup_logger()


def get_deletion_status(
@@ -21,3 +45,92 @@ def get_deletion_status(
        credential_id=credential_id,
        status=task_state.status,
    )


def should_sync_doc_set(document_set: DocumentSet, db_session: Session) -> bool:
    if document_set.is_up_to_date:
        return False

    task_name = name_document_set_sync_task(document_set.id)
    latest_sync = get_latest_task(task_name, db_session)

    if latest_sync and check_task_is_live_and_not_timed_out(latest_sync, db_session):
        logger.info(f"Document set '{document_set.id}' is already syncing. Skipping.")
        return False

    logger.info(f"Document set {document_set.id} syncing now!")
    return True


def should_prune_cc_pair(
    connector: Connector, credential: Credential, db_session: Session
) -> bool:
    if not connector.prune_freq:
        return False

    pruning_task_name = name_cc_prune_task(
        connector_id=connector.id, credential_id=credential.id
    )
    last_pruning_task = get_latest_task(pruning_task_name, db_session)
    current_db_time = get_db_current_time(db_session)

    if not last_pruning_task:
        time_since_initialization = current_db_time - connector.time_created
        if time_since_initialization.total_seconds() >= connector.prune_freq:
            return True
        return False

    if not ALLOW_SIMULTANEOUS_PRUNING:
        pruning_type_task_name = name_cc_prune_task()
        last_pruning_type_task = get_latest_task_by_type(
            pruning_type_task_name, db_session
        )

        if last_pruning_type_task and check_task_is_live_and_not_timed_out(
            last_pruning_type_task, db_session
        ):
            return False

    if check_task_is_live_and_not_timed_out(last_pruning_task, db_session):
        return False

    if not last_pruning_task.start_time:
        return False

    time_since_last_pruning = current_db_time - last_pruning_task.start_time
    return time_since_last_pruning.total_seconds() >= connector.prune_freq


def document_batch_to_ids(doc_batch: list[Document]) -> set[str]:
    return {doc.id for doc in doc_batch}


def extract_ids_from_runnable_connector(runnable_connector: BaseConnector) -> set[str]:
    """
    If the PruneConnector hasn't been implemented for the given connector, just pull
    all docs using the load_from_state and grab out the IDs
    """
    all_connector_doc_ids: set[str] = set()

    doc_batch_generator = None
    if isinstance(runnable_connector, IdConnector):
        all_connector_doc_ids = runnable_connector.retrieve_all_source_ids()
    elif isinstance(runnable_connector, LoadConnector):
        doc_batch_generator = runnable_connector.load_from_state()
    elif isinstance(runnable_connector, PollConnector):
        start = datetime(1970, 1, 1, tzinfo=timezone.utc).timestamp()
        end = datetime.now(timezone.utc).timestamp()
        doc_batch_generator = runnable_connector.poll_source(start=start, end=end)
    else:
        raise RuntimeError("Pruning job could not find a valid runnable_connector.")

    if doc_batch_generator:
        doc_batch_processing_func = document_batch_to_ids
        if MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE:
            doc_batch_processing_func = rate_limit_builder(
                max_calls=MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE, period=60
            )(document_batch_to_ids)
        for doc_batch in doc_batch_generator:
            all_connector_doc_ids.update(doc_batch_processing_func(doc_batch))

    return all_connector_doc_ids
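A minimal sketch of the "is pruning due" timing check in should_prune_cc_pair above, with hypothetical numbers (prune_freq is compared against elapsed seconds since the last prune task started):

from datetime import datetime, timezone

# Hypothetical values, for illustration only.
prune_freq = 86400  # prune at most roughly once per day
last_prune_start = datetime(2024, 7, 1, 0, 0, tzinfo=timezone.utc)
now = datetime(2024, 7, 2, 12, 0, tzinfo=timezone.utc)

time_since_last_pruning = now - last_prune_start
is_due = time_since_last_pruning.total_seconds() >= prune_freq
print(is_due)  # True: 36 hours elapsed, which exceeds the 24-hour prune_freq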
@@ -41,7 +41,7 @@ logger = setup_logger()
_DELETION_BATCH_SIZE = 1000


def _delete_connector_credential_pair_batch(
def delete_connector_credential_pair_batch(
    document_ids: list[str],
    connector_id: int,
    credential_id: int,
@@ -169,7 +169,7 @@ def delete_connector_credential_pair(
        if not documents:
            break

        _delete_connector_credential_pair_batch(
        delete_connector_credential_pair_batch(
            document_ids=[document.id for document in documents],
            connector_id=connector_id,
            credential_id=credential_id,

@@ -105,7 +105,9 @@ class SimpleJobClient:
        """NOTE: `pure` arg is needed so this can be a drop in replacement for Dask"""
        self._cleanup_completed_jobs()
        if len(self.jobs) >= self.n_workers:
            logger.debug("No available workers to run job")
            logger.debug(
                f"No available workers to run job. Currently running '{len(self.jobs)}' jobs, with a limit of '{self.n_workers}'."
            )
            return None

        job_id = self.job_id_counter
@@ -6,11 +6,7 @@ from datetime import timezone

from sqlalchemy.orm import Session

from danswer.background.connector_deletion import (
    _delete_connector_credential_pair_batch,
)
from danswer.background.indexing.checkpointing import get_time_windows_for_index_attempt
from danswer.configs.app_configs import DISABLE_DOCUMENT_CLEANUP
from danswer.configs.app_configs import POLL_CONNECTOR_OFFSET
from danswer.connectors.factory import instantiate_connector
from danswer.connectors.interfaces import GenerateDocumentsOutput
@@ -21,12 +17,10 @@ from danswer.connectors.models import InputType
from danswer.db.connector import disable_connector
from danswer.db.connector_credential_pair import get_last_successful_attempt_time
from danswer.db.connector_credential_pair import update_connector_credential_pair
from danswer.db.credentials import backend_update_credential_json
from danswer.db.document import get_documents_for_connector_credential_pair
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.index_attempt import mark_attempt_in_progress__no_commit
from danswer.db.index_attempt import mark_attempt_in_progress
from danswer.db.index_attempt import mark_attempt_succeeded
from danswer.db.index_attempt import update_docs_indexed
from danswer.db.models import IndexAttempt
@@ -37,6 +31,7 @@ from danswer.indexing.embedder import DefaultIndexingEmbedder
from danswer.indexing.indexing_pipeline import build_indexing_pipeline
from danswer.utils.logger import IndexAttemptSingleton
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import global_version

logger = setup_logger()

@@ -46,7 +41,7 @@ def _get_document_generator(
    attempt: IndexAttempt,
    start_time: datetime,
    end_time: datetime,
) -> tuple[GenerateDocumentsOutput, bool]:
) -> GenerateDocumentsOutput:
    """
    NOTE: `start_time` and `end_time` are only used for poll connectors

@@ -54,31 +49,31 @@ def _get_document_generator(
    are the complete list of existing documents of the connector. If the task
    of type LOAD_STATE, the list will be considered complete and otherwise incomplete.
    """
    task = attempt.connector.input_type
    task = attempt.connector_credential_pair.connector.input_type

    try:
        runnable_connector, new_credential_json = instantiate_connector(
            attempt.connector.source,
        runnable_connector = instantiate_connector(
            attempt.connector_credential_pair.connector.source,
            task,
            attempt.connector.connector_specific_config,
            attempt.credential.credential_json,
            attempt.connector_credential_pair.connector.connector_specific_config,
            attempt.connector_credential_pair.credential,
            db_session,
        )
        if new_credential_json is not None:
            backend_update_credential_json(
                attempt.credential, new_credential_json, db_session
            )
    except Exception as e:
        logger.exception(f"Unable to instantiate connector due to {e}")
        disable_connector(attempt.connector.id, db_session)
        disable_connector(attempt.connector_credential_pair.connector.id, db_session)
        raise e

    if task == InputType.LOAD_STATE:
        assert isinstance(runnable_connector, LoadConnector)
        doc_batch_generator = runnable_connector.load_from_state()
        is_listing_complete = True

    elif task == InputType.POLL:
        assert isinstance(runnable_connector, PollConnector)
        if attempt.connector_id is None or attempt.credential_id is None:
        if (
            attempt.connector_credential_pair.connector_id is None
            or attempt.connector_credential_pair.connector_id is None
        ):
            raise ValueError(
                f"Polling attempt {attempt.id} is missing connector_id or credential_id, "
                f"can't fetch time range."
@@ -88,13 +83,12 @@ def _get_document_generator(
        doc_batch_generator = runnable_connector.poll_source(
            start=start_time.timestamp(), end=end_time.timestamp()
        )
        is_listing_complete = False

    else:
        # Event types cannot be handled by a background type
        raise RuntimeError(f"Invalid task type: {task}")

    return doc_batch_generator, is_listing_complete
    return doc_batch_generator


def _run_indexing(
@@ -107,7 +101,6 @@ def _run_indexing(
    3. Updates Postgres to record the indexed documents + the outcome of this run
    """
    start_time = time.time()

    db_embedding_model = index_attempt.embedding_model
    index_name = db_embedding_model.index_name

@@ -125,6 +118,8 @@ def _run_indexing(
        normalize=db_embedding_model.normalize,
        query_prefix=db_embedding_model.query_prefix,
        passage_prefix=db_embedding_model.passage_prefix,
        api_key=db_embedding_model.api_key,
        provider_type=db_embedding_model.provider_type,
    )

    indexing_pipeline = build_indexing_pipeline(
@@ -135,16 +130,21 @@ def _run_indexing(
        db_session=db_session,
    )

    db_connector = index_attempt.connector
    db_credential = index_attempt.credential
    db_connector = index_attempt.connector_credential_pair.connector
    db_credential = index_attempt.connector_credential_pair.credential

    last_successful_index_time = (
        0.0
        if index_attempt.from_beginning
        else get_last_successful_attempt_time(
            connector_id=db_connector.id,
            credential_id=db_credential.id,
            embedding_model=index_attempt.embedding_model,
            db_session=db_session,
        db_connector.indexing_start.timestamp()
        if index_attempt.from_beginning and db_connector.indexing_start is not None
        else (
            0.0
            if index_attempt.from_beginning
            else get_last_successful_attempt_time(
                connector_id=db_connector.id,
                credential_id=db_credential.id,
                embedding_model=index_attempt.embedding_model,
                db_session=db_session,
            )
        )
    )

@@ -160,19 +160,19 @@ def _run_indexing(
            source_type=db_connector.source,
        )
    ):
        window_start = max(
            window_start - timedelta(minutes=POLL_CONNECTOR_OFFSET),
            datetime(1970, 1, 1, tzinfo=timezone.utc),
        )

        doc_batch_generator, is_listing_complete = _get_document_generator(
            db_session=db_session,
            attempt=index_attempt,
            start_time=window_start,
            end_time=window_end,
        )

        try:
            window_start = max(
                window_start - timedelta(minutes=POLL_CONNECTOR_OFFSET),
                datetime(1970, 1, 1, tzinfo=timezone.utc),
            )

            doc_batch_generator = _get_document_generator(
                db_session=db_session,
                attempt=index_attempt,
                start_time=window_start,
                end_time=window_end,
            )

            all_connector_doc_ids: set[str] = set()
            for doc_batch in doc_batch_generator:
                # Check if connector is disabled mid run and stop if so unless it's the secondary
@@ -197,7 +197,7 @@ def _run_indexing(
                )

                new_docs, total_batch_chunks = indexing_pipeline(
                    documents=doc_batch,
                    document_batch=doc_batch,
                    index_attempt_metadata=IndexAttemptMetadata(
                        connector_id=db_connector.id,
                        credential_id=db_credential.id,
@@ -224,46 +224,12 @@ def _run_indexing(
                    docs_removed_from_index=0,
                )

            if is_listing_complete and not DISABLE_DOCUMENT_CLEANUP:
                # clean up all documents from the index that have not been returned from the connector
                all_indexed_document_ids = {
                    d.id
                    for d in get_documents_for_connector_credential_pair(
                        db_session=db_session,
                        connector_id=db_connector.id,
                        credential_id=db_credential.id,
                    )
                }
                doc_ids_to_remove = list(
                    all_indexed_document_ids - all_connector_doc_ids
                )
                logger.debug(
                    f"Cleaning up {len(doc_ids_to_remove)} documents that are not contained in the newest connector state"
                )

                # delete docs from cc-pair and receive the number of completely deleted docs in return
                _delete_connector_credential_pair_batch(
                    document_ids=doc_ids_to_remove,
                    connector_id=db_connector.id,
                    credential_id=db_credential.id,
                    document_index=document_index,
                )

                update_docs_indexed(
                    db_session=db_session,
                    index_attempt=index_attempt,
                    total_docs_indexed=document_count,
                    new_docs_indexed=net_doc_change,
                    docs_removed_from_index=len(doc_ids_to_remove),
                )

            run_end_dt = window_end
            if is_primary:
                update_connector_credential_pair(
                    db_session=db_session,
                    connector_id=db_connector.id,
                    credential_id=db_credential.id,
                    attempt_status=IndexingStatus.IN_PROGRESS,
                    net_docs=net_doc_change,
                    run_dt=run_end_dt,
                )
@@ -292,9 +258,8 @@ def _run_indexing(
            if is_primary:
                update_connector_credential_pair(
                    db_session=db_session,
                    connector_id=index_attempt.connector.id,
                    credential_id=index_attempt.credential.id,
                    attempt_status=IndexingStatus.FAILED,
                    connector_id=index_attempt.connector_credential_pair.connector.id,
                    credential_id=index_attempt.connector_credential_pair.credential.id,
                    net_docs=net_doc_change,
                )
            raise e
@@ -309,7 +274,6 @@ def _run_indexing(
        db_session=db_session,
        connector_id=db_connector.id,
        credential_id=db_credential.id,
        attempt_status=IndexingStatus.SUCCESS,
        run_dt=run_end_dt,
    )

@@ -332,6 +296,7 @@ def _prepare_index_attempt(db_session: Session, index_attempt_id: int) -> IndexA
        db_session=db_session,
        index_attempt_id=index_attempt_id,
    )

    if attempt is None:
        raise RuntimeError(f"Unable to find IndexAttempt for ID '{index_attempt_id}'")

@@ -342,26 +307,19 @@ def _prepare_index_attempt(db_session: Session, index_attempt_id: int) -> IndexA
    )

    # only commit once, to make sure this all happens in a single transaction
    mark_attempt_in_progress__no_commit(attempt)
    is_primary = attempt.embedding_model.status == IndexModelStatus.PRESENT
    if is_primary:
        update_connector_credential_pair(
            db_session=db_session,
            connector_id=attempt.connector.id,
            credential_id=attempt.credential.id,
            attempt_status=IndexingStatus.IN_PROGRESS,
        )
    else:
        db_session.commit()
    mark_attempt_in_progress(attempt, db_session)

    return attempt


def run_indexing_entrypoint(index_attempt_id: int) -> None:
def run_indexing_entrypoint(index_attempt_id: int, is_ee: bool = False) -> None:
    """Entrypoint for indexing run when using dask distributed.
    Wraps the actual logic in a `try` block so that we can catch any exceptions
    and mark the attempt as failed."""
    try:
        if is_ee:
            global_version.set_ee()

        # set the indexing attempt ID so that all log messages from this process
        # will have it added as a prefix
        IndexAttemptSingleton.set_index_attempt_id(index_attempt_id)
@@ -372,17 +330,17 @@ def run_indexing_entrypoint(index_attempt_id: int) -> None:
        attempt = _prepare_index_attempt(db_session, index_attempt_id)

        logger.info(
            f"Running indexing attempt for connector: '{attempt.connector.name}', "
            f"with config: '{attempt.connector.connector_specific_config}', and "
            f"with credentials: '{attempt.credential_id}'"
            f"Running indexing attempt for connector: '{attempt.connector_credential_pair.connector.name}', "
            f"with config: '{attempt.connector_credential_pair.connector.connector_specific_config}', and "
            f"with credentials: '{attempt.connector_credential_pair.connector_id}'"
        )

        _run_indexing(db_session, attempt)

        logger.info(
            f"Completed indexing attempt for connector: '{attempt.connector.name}', "
            f"with config: '{attempt.connector.connector_specific_config}', and "
            f"with credentials: '{attempt.credential_id}'"
            f"Completed indexing attempt for connector: '{attempt.connector_credential_pair.connector.name}', "
            f"with config: '{attempt.connector_credential_pair.connector.connector_specific_config}', and "
            f"with credentials: '{attempt.connector_credential_pair.connector_id}'"
        )
    except Exception as e:
        logger.exception(f"Indexing job with ID '{index_attempt_id}' failed due to {e}")
@@ -22,6 +22,15 @@ def name_document_set_sync_task(document_set_id: int) -> str:
    return f"sync_doc_set_{document_set_id}"


def name_cc_prune_task(
    connector_id: int | None = None, credential_id: int | None = None
) -> str:
    task_name = f"prune_connector_credential_pair_{connector_id}_{credential_id}"
    if not connector_id or not credential_id:
        task_name = "prune_connector_credential_pair"
    return task_name
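A quick illustration (hypothetical IDs, not part of the diff) of what the new helper above returns:

# Hypothetical usage of name_cc_prune_task
name_cc_prune_task(connector_id=3, credential_id=7)  # -> "prune_connector_credential_pair_3_7"
name_cc_prune_task()  # -> "prune_connector_credential_pair" (generic name when either ID is missing)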

T = TypeVar("T", bound=Callable)

@@ -16,17 +16,19 @@ from danswer.configs.app_configs import CLEANUP_INDEXING_JOBS_TIMEOUT
|
||||
from danswer.configs.app_configs import DASK_JOB_CLIENT_ENABLED
|
||||
from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
|
||||
from danswer.configs.app_configs import NUM_INDEXING_WORKERS
|
||||
from danswer.configs.app_configs import NUM_SECONDARY_INDEXING_WORKERS
|
||||
from danswer.configs.constants import POSTGRES_INDEXER_APP_NAME
|
||||
from danswer.db.connector import fetch_connectors
|
||||
from danswer.db.connector_credential_pair import mark_all_in_progress_cc_pairs_failed
|
||||
from danswer.db.connector_credential_pair import update_connector_credential_pair
|
||||
from danswer.db.connector_credential_pair import fetch_connector_credential_pairs
|
||||
from danswer.db.embedding_model import get_current_db_embedding_model
|
||||
from danswer.db.embedding_model import get_secondary_db_embedding_model
|
||||
from danswer.db.engine import get_db_current_time
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.engine import init_sqlalchemy_engine
|
||||
from danswer.db.index_attempt import create_index_attempt
|
||||
from danswer.db.index_attempt import get_index_attempt
|
||||
from danswer.db.index_attempt import get_inprogress_index_attempts
|
||||
from danswer.db.index_attempt import get_last_attempt
|
||||
from danswer.db.index_attempt import get_last_attempt_for_cc_pair
|
||||
from danswer.db.index_attempt import get_not_started_index_attempts
|
||||
from danswer.db.index_attempt import mark_attempt_failed
|
||||
from danswer.db.models import Connector
|
||||
@@ -35,8 +37,10 @@ from danswer.db.models import IndexAttempt
|
||||
from danswer.db.models import IndexingStatus
|
||||
from danswer.db.models import IndexModelStatus
|
||||
from danswer.db.swap_index import check_index_swap
|
||||
from danswer.search.search_nlp_models import warm_up_encoders
|
||||
from danswer.natural_language_processing.search_nlp_models import warm_up_encoders
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import global_version
|
||||
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
|
||||
from shared_configs.configs import LOG_LEVEL
|
||||
from shared_configs.configs import MODEL_SERVER_PORT
|
||||
@@ -66,20 +70,26 @@ def _should_create_new_indexing(
|
||||
return False
|
||||
|
||||
# When switching over models, always index at least once
|
||||
if model.status == IndexModelStatus.FUTURE and not last_index:
|
||||
if connector.id == 0: # Ingestion API
|
||||
return False
|
||||
if model.status == IndexModelStatus.FUTURE:
|
||||
if last_index:
|
||||
# secondary indexes should not index again after success
|
||||
# or else the model will never be able to swap
|
||||
if last_index.status == IndexingStatus.SUCCESS:
|
||||
return False
|
||||
else:
|
||||
if connector.id == 0: # Ingestion API
|
||||
return False
|
||||
return True
|
||||
|
||||
# If the connector is disabled, don't index
|
||||
# NOTE: during an embedding model switch over, we ignore this
|
||||
# and index the disabled connectors as well (which is why this if
|
||||
# statement is below the first condition above)
|
||||
# NOTE: during an embedding model switch over, the following logic
|
||||
# is bypassed by the above check for a future model
|
||||
if connector.disabled:
|
||||
return False
|
||||
|
||||
if connector.refresh_freq is None:
|
||||
return False
|
||||
|
||||
if not last_index:
|
||||
return True
|
||||
|
||||
@@ -111,25 +121,14 @@ def _mark_run_failed(
|
||||
"""Marks the `index_attempt` row as failed + updates the `
|
||||
connector_credential_pair` to reflect that the run failed"""
|
||||
logger.warning(
|
||||
f"Marking in-progress attempt 'connector: {index_attempt.connector_id}, "
|
||||
f"credential: {index_attempt.credential_id}' as failed due to {failure_reason}"
|
||||
f"Marking in-progress attempt 'connector: {index_attempt.connector_credential_pair.connector_id}, "
|
||||
f"credential: {index_attempt.connector_credential_pair.credential_id}' as failed due to {failure_reason}"
|
||||
)
|
||||
mark_attempt_failed(
|
||||
index_attempt=index_attempt,
|
||||
db_session=db_session,
|
||||
failure_reason=failure_reason,
|
||||
)
|
||||
if (
|
||||
index_attempt.connector_id is not None
|
||||
and index_attempt.credential_id is not None
|
||||
and index_attempt.embedding_model.status == IndexModelStatus.PRESENT
|
||||
):
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=index_attempt.connector_id,
|
||||
credential_id=index_attempt.credential_id,
|
||||
attempt_status=IndexingStatus.FAILED,
|
||||
)
|
||||
|
||||
|
||||
"""Main funcs"""
|
||||
@@ -142,7 +141,7 @@ def create_indexing_jobs(existing_jobs: dict[int, Future | SimpleJob]) -> None:
|
||||
3. There is not already an ongoing indexing attempt for this pair
|
||||
"""
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
ongoing: set[tuple[int | None, int | None, int]] = set()
|
||||
ongoing: set[tuple[int | None, int]] = set()
|
||||
for attempt_id in existing_jobs:
|
||||
attempt = get_index_attempt(
|
||||
db_session=db_session, index_attempt_id=attempt_id
|
||||
@@ -155,8 +154,7 @@ def create_indexing_jobs(existing_jobs: dict[int, Future | SimpleJob]) -> None:
|
||||
continue
|
||||
ongoing.add(
|
||||
(
|
||||
attempt.connector_id,
|
||||
attempt.credential_id,
|
||||
attempt.connector_credential_pair_id,
|
||||
attempt.embedding_model_id,
|
||||
)
|
||||
)
|
||||
@@ -166,41 +164,26 @@ def create_indexing_jobs(existing_jobs: dict[int, Future | SimpleJob]) -> None:
|
||||
if secondary_embedding_model is not None:
|
||||
embedding_models.append(secondary_embedding_model)
|
||||
|
||||
all_connectors = fetch_connectors(db_session)
|
||||
for connector in all_connectors:
|
||||
for association in connector.credentials:
|
||||
for model in embedding_models:
|
||||
credential = association.credential
|
||||
all_connector_credential_pairs = fetch_connector_credential_pairs(db_session)
|
||||
for cc_pair in all_connector_credential_pairs:
|
||||
for model in embedding_models:
|
||||
# Check if there is an ongoing indexing attempt for this connector credential pair
|
||||
if (cc_pair.id, model.id) in ongoing:
|
||||
continue
|
||||
|
||||
# Check if there is an ongoing indexing attempt for this connector + credential pair
|
||||
if (connector.id, credential.id, model.id) in ongoing:
|
||||
continue
|
||||
last_attempt = get_last_attempt_for_cc_pair(
|
||||
cc_pair.id, model.id, db_session
|
||||
)
|
||||
if not _should_create_new_indexing(
|
||||
connector=cc_pair.connector,
|
||||
last_index=last_attempt,
|
||||
model=model,
|
||||
secondary_index_building=len(embedding_models) > 1,
|
||||
db_session=db_session,
|
||||
):
|
||||
continue
|
||||
|
||||
last_attempt = get_last_attempt(
|
||||
connector.id, credential.id, model.id, db_session
|
||||
)
|
||||
if not _should_create_new_indexing(
|
||||
connector=connector,
|
||||
last_index=last_attempt,
|
||||
model=model,
|
||||
secondary_index_building=len(embedding_models) > 1,
|
||||
db_session=db_session,
|
||||
):
|
||||
continue
|
||||
|
||||
create_index_attempt(
|
||||
connector.id, credential.id, model.id, db_session
|
||||
)
|
||||
|
||||
# CC-Pair will have the status that it should for the primary index
|
||||
# Will be re-sync-ed once the indices are swapped
|
||||
if model.status == IndexModelStatus.PRESENT:
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=connector.id,
|
||||
credential_id=credential.id,
|
||||
attempt_status=IndexingStatus.NOT_STARTED,
|
||||
)
|
||||
create_index_attempt(cc_pair.id, model.id, db_session)
|
||||
|
||||
|
||||
def cleanup_indexing_jobs(
|
||||
@@ -292,6 +275,8 @@ def kickoff_indexing_jobs(
|
||||
# Don't include jobs waiting in the Dask queue that just haven't started running
|
||||
# Also (rarely) don't include for jobs that started but haven't updated the indexing tables yet
|
||||
with Session(engine) as db_session:
|
||||
# get_not_started_index_attempts orders its returned results from oldest to newest
|
||||
# we must process attempts in a FIFO manner to prevent connector starvation
|
||||
new_indexing_attempts = [
|
||||
(attempt, attempt.embedding_model)
|
||||
for attempt in get_not_started_index_attempts(db_session)
|
||||
@@ -309,7 +294,7 @@ def kickoff_indexing_jobs(
|
||||
if embedding_model is not None
|
||||
else False
|
||||
)
|
||||
if attempt.connector is None:
|
||||
if attempt.connector_credential_pair.connector is None:
|
||||
logger.warning(
|
||||
f"Skipping index attempt as Connector has been deleted: {attempt}"
|
||||
)
|
||||
@@ -318,7 +303,7 @@ def kickoff_indexing_jobs(
|
||||
attempt, db_session, failure_reason="Connector is null"
|
||||
)
|
||||
continue
|
||||
if attempt.credential is None:
|
||||
if attempt.connector_credential_pair.credential is None:
|
||||
logger.warning(
|
||||
f"Skipping index attempt as Credential has been deleted: {attempt}"
|
||||
)
|
||||
@@ -330,39 +315,52 @@ def kickoff_indexing_jobs(
|
||||
|
||||
if use_secondary_index:
|
||||
run = secondary_client.submit(
|
||||
run_indexing_entrypoint, attempt.id, pure=False
|
||||
run_indexing_entrypoint,
|
||||
attempt.id,
|
||||
global_version.get_is_ee_version(),
|
||||
pure=False,
|
||||
)
|
||||
else:
|
||||
run = client.submit(run_indexing_entrypoint, attempt.id, pure=False)
|
||||
run = client.submit(
|
||||
run_indexing_entrypoint,
|
||||
attempt.id,
|
||||
global_version.get_is_ee_version(),
|
||||
pure=False,
|
||||
)
|
||||
|
||||
if run:
|
||||
secondary_str = "(secondary index) " if use_secondary_index else ""
|
||||
logger.info(
|
||||
f"Kicked off {secondary_str}"
|
||||
f"indexing attempt for connector: '{attempt.connector.name}', "
|
||||
f"with config: '{attempt.connector.connector_specific_config}', and "
|
||||
f"with credentials: '{attempt.credential_id}'"
|
||||
f"indexing attempt for connector: '{attempt.connector_credential_pair.connector.name}', "
|
||||
f"with config: '{attempt.connector_credential_pair.connector.connector_specific_config}', and "
|
||||
f"with credentials: '{attempt.connector_credential_pair.credential_id}'"
|
||||
)
|
||||
existing_jobs_copy[attempt.id] = run
|
||||
|
||||
return existing_jobs_copy
|
||||
|
||||
|
||||
def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> None:
|
||||
def update_loop(
|
||||
delay: int = 10,
|
||||
num_workers: int = NUM_INDEXING_WORKERS,
|
||||
num_secondary_workers: int = NUM_SECONDARY_INDEXING_WORKERS,
|
||||
) -> None:
|
||||
engine = get_sqlalchemy_engine()
|
||||
with Session(engine) as db_session:
|
||||
check_index_swap(db_session=db_session)
|
||||
db_embedding_model = get_current_db_embedding_model(db_session)
|
||||
|
||||
# So that the first time users aren't surprised by really slow speed of first
|
||||
# batch of documents indexed
|
||||
logger.info("Running a first inference to warm up embedding model")
|
||||
warm_up_encoders(
|
||||
model_name=db_embedding_model.model_name,
|
||||
normalize=db_embedding_model.normalize,
|
||||
model_server_host=INDEXING_MODEL_SERVER_HOST,
|
||||
model_server_port=MODEL_SERVER_PORT,
|
||||
)
|
||||
# So that the first time users aren't surprised by really slow speed of first
|
||||
# batch of documents indexed
|
||||
|
||||
if db_embedding_model.cloud_provider_id is None:
|
||||
logger.info("Running a first inference to warm up embedding model")
|
||||
warm_up_encoders(
|
||||
embedding_model=db_embedding_model,
|
||||
model_server_host=INDEXING_MODEL_SERVER_HOST,
|
||||
model_server_port=MODEL_SERVER_PORT,
|
||||
)
|
||||
|
||||
client_primary: Client | SimpleJobClient
|
||||
client_secondary: Client | SimpleJobClient
|
||||
@@ -377,7 +375,7 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non
|
||||
silence_logs=logging.ERROR,
|
||||
)
|
||||
cluster_secondary = LocalCluster(
|
||||
n_workers=num_workers,
|
||||
n_workers=num_secondary_workers,
|
||||
threads_per_worker=1,
|
||||
silence_logs=logging.ERROR,
|
||||
)
|
||||
@@ -387,15 +385,10 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non
|
||||
client_primary.register_worker_plugin(ResourceLogger())
|
||||
else:
|
||||
client_primary = SimpleJobClient(n_workers=num_workers)
|
||||
client_secondary = SimpleJobClient(n_workers=num_workers)
|
||||
client_secondary = SimpleJobClient(n_workers=num_secondary_workers)
|
||||
|
||||
existing_jobs: dict[int, Future | SimpleJob] = {}
|
||||
|
||||
with Session(engine) as db_session:
|
||||
# Previous version did not always clean up cc-pairs well leaving some connectors undeleteable
|
||||
# This ensures that bad states get cleaned up
|
||||
mark_all_in_progress_cc_pairs_failed(db_session)
|
||||
|
||||
while True:
|
||||
start = time.time()
|
||||
start_time_utc = datetime.utcfromtimestamp(start).strftime("%Y-%m-%d %H:%M:%S")
|
||||
@@ -426,6 +419,9 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non
|
||||
|
||||
|
||||
def update__main() -> None:
|
||||
set_is_ee_based_on_env_variable()
|
||||
init_sqlalchemy_engine(POSTGRES_INDEXER_APP_NAME)
|
||||
|
||||
logger.info("Starting Indexing Loop")
|
||||
update_loop()
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import re
|
||||
from collections.abc import Sequence
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -9,45 +8,34 @@ from danswer.chat.models import LlmDoc
|
||||
from danswer.db.chat import get_chat_messages_by_session
|
||||
from danswer.db.models import ChatMessage
|
||||
from danswer.llm.answering.models import PreviousMessage
|
||||
from danswer.search.models import InferenceChunk
|
||||
from danswer.search.models import InferenceSection
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def llm_doc_from_inference_section(inf_chunk: InferenceSection) -> LlmDoc:
|
||||
def llm_doc_from_inference_section(inference_section: InferenceSection) -> LlmDoc:
|
||||
return LlmDoc(
|
||||
document_id=inf_chunk.document_id,
|
||||
document_id=inference_section.center_chunk.document_id,
|
||||
# This one is using the combined content of all the chunks of the section
|
||||
# In default settings, this is the same as just the content of base chunk
|
||||
content=inf_chunk.combined_content,
|
||||
blurb=inf_chunk.blurb,
|
||||
semantic_identifier=inf_chunk.semantic_identifier,
|
||||
source_type=inf_chunk.source_type,
|
||||
metadata=inf_chunk.metadata,
|
||||
updated_at=inf_chunk.updated_at,
|
||||
link=inf_chunk.source_links[0] if inf_chunk.source_links else None,
|
||||
source_links=inf_chunk.source_links,
|
||||
content=inference_section.combined_content,
|
||||
blurb=inference_section.center_chunk.blurb,
|
||||
semantic_identifier=inference_section.center_chunk.semantic_identifier,
|
||||
source_type=inference_section.center_chunk.source_type,
|
||||
metadata=inference_section.center_chunk.metadata,
|
||||
updated_at=inference_section.center_chunk.updated_at,
|
||||
link=inference_section.center_chunk.source_links[0]
|
||||
if inference_section.center_chunk.source_links
|
||||
else None,
|
||||
source_links=inference_section.center_chunk.source_links,
|
||||
)
|
||||
|
||||
|
||||
def map_document_id_order(
    chunks: Sequence[InferenceChunk | LlmDoc], one_indexed: bool = True
) -> dict[str, int]:
    order_mapping = {}
    current = 1 if one_indexed else 0
    for chunk in chunks:
        if chunk.document_id not in order_mapping:
            order_mapping[chunk.document_id] = current
            current += 1

    return order_mapping
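A worked example (hypothetical document IDs) of the ordering helper above:

# Chunks referencing documents "a", "b", "a", "c" in stream order:
# map_document_id_order(chunks) -> {"a": 1, "b": 2, "c": 3}   (one_indexed=True, the default)
# map_document_id_order(chunks, one_indexed=False) -> {"a": 0, "b": 1, "c": 2}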
|
||||
|
||||
def create_chat_chain(
|
||||
chat_session_id: int,
|
||||
db_session: Session,
|
||||
prefetch_tool_calls: bool = True,
|
||||
) -> tuple[ChatMessage, list[ChatMessage]]:
|
||||
"""Build the linear chain of messages without including the root message"""
|
||||
mainline_messages: list[ChatMessage] = []
|
||||
@@ -56,6 +44,7 @@ def create_chat_chain(
|
||||
user_id=None,
|
||||
db_session=db_session,
|
||||
skip_permission_check=True,
|
||||
prefetch_tool_calls=prefetch_tool_calls,
|
||||
)
|
||||
id_to_msg = {msg.id: msg for msg in all_chat_messages}
|
||||
|
||||
|
||||
backend/danswer/chat/input_prompts.yaml (new file, 24 lines)
@@ -0,0 +1,24 @@
input_prompts:
  - id: -5
    prompt: "Elaborate"
    content: "Elaborate on the above, give me a more in depth explanation."
    active: true
    is_public: true

  - id: -4
    prompt: "Reword"
    content: "Help me rewrite the following politely and concisely for professional communication:\n"
    active: true
    is_public: true

  - id: -3
    prompt: "Email"
    content: "Write a professional email for me including a subject line, signature, etc. Template the parts that need editing with [ ]. The email should cover the following points:\n"
    active: true
    is_public: true

  - id: -2
    prompt: "Debug"
    content: "Provide step-by-step troubleshooting instructions for the following issue:\n"
    active: true
    is_public: true
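For orientation, a small illustration (not part of the diff) of what the new load_input_prompts_from_yaml helper further down in this changeset reads from this file; yaml.safe_load yields plain dicts, e.g. the first entry becomes:

# First element of yaml.safe_load(...)["input_prompts"] for the file above
{"id": -5, "prompt": "Elaborate",
 "content": "Elaborate on the above, give me a more in depth explanation.",
 "active": True, "is_public": True}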
@@ -1,18 +1,18 @@
|
||||
from typing import cast
|
||||
|
||||
import yaml
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.chat_configs import INPUT_PROMPT_YAML
|
||||
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
|
||||
from danswer.configs.chat_configs import PERSONAS_YAML
|
||||
from danswer.configs.chat_configs import PROMPTS_YAML
|
||||
from danswer.db.chat import get_prompt_by_name
|
||||
from danswer.db.chat import upsert_persona
|
||||
from danswer.db.chat import upsert_prompt
|
||||
from danswer.db.document_set import get_or_create_document_set_by_name
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.input_prompt import insert_input_prompt_if_not_exists
|
||||
from danswer.db.models import DocumentSet as DocumentSetDBModel
|
||||
from danswer.db.models import Prompt as PromptDBModel
|
||||
from danswer.db.persona import get_prompt_by_name
|
||||
from danswer.db.persona import upsert_persona
|
||||
from danswer.db.persona import upsert_prompt
|
||||
from danswer.search.enums import RecencyBiasSetting
|
||||
|
||||
|
||||
@@ -50,7 +50,7 @@ def load_personas_from_yaml(
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
for persona in all_personas:
|
||||
doc_set_names = persona["document_sets"]
|
||||
doc_sets: list[DocumentSetDBModel] | None = [
|
||||
doc_sets: list[DocumentSetDBModel] = [
|
||||
get_or_create_document_set_by_name(db_session, name)
|
||||
for name in doc_set_names
|
||||
]
|
||||
@@ -58,22 +58,24 @@ def load_personas_from_yaml(
|
||||
# Assume if user hasn't set any document sets for the persona, the user may want
|
||||
# to later attach document sets to the persona manually, therefore, don't overwrite/reset
|
||||
# the document sets for the persona
|
||||
if not doc_sets:
|
||||
doc_sets = None
|
||||
|
||||
prompt_set_names = persona["prompts"]
|
||||
if not prompt_set_names:
|
||||
prompts: list[PromptDBModel | None] | None = None
|
||||
doc_set_ids: list[int] | None = None
|
||||
if doc_sets:
|
||||
doc_set_ids = [doc_set.id for doc_set in doc_sets]
|
||||
else:
|
||||
prompts = [
|
||||
doc_set_ids = None
|
||||
|
||||
prompt_ids: list[int] | None = None
|
||||
prompt_set_names = persona["prompts"]
|
||||
if prompt_set_names:
|
||||
prompts: list[PromptDBModel | None] = [
|
||||
get_prompt_by_name(prompt_name, user=None, db_session=db_session)
|
||||
for prompt_name in prompt_set_names
|
||||
]
|
||||
if any([prompt is None for prompt in prompts]):
|
||||
raise ValueError("Invalid Persona configs, not all prompts exist")
|
||||
|
||||
if not prompts:
|
||||
prompts = None
|
||||
if prompts:
|
||||
prompt_ids = [prompt.id for prompt in prompts if prompt is not None]
|
||||
|
||||
p_id = persona.get("id")
|
||||
upsert_persona(
|
||||
@@ -88,20 +90,45 @@ def load_personas_from_yaml(
|
||||
llm_relevance_filter=persona.get("llm_relevance_filter"),
|
||||
starter_messages=persona.get("starter_messages"),
|
||||
llm_filter_extraction=persona.get("llm_filter_extraction"),
|
||||
icon_shape=persona.get("icon_shape"),
|
||||
icon_color=persona.get("icon_color"),
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
recency_bias=RecencyBiasSetting(persona["recency_bias"]),
|
||||
prompts=cast(list[PromptDBModel] | None, prompts),
|
||||
document_sets=doc_sets,
|
||||
prompt_ids=prompt_ids,
|
||||
document_set_ids=doc_set_ids,
|
||||
default_persona=True,
|
||||
is_public=True,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
|
||||
def load_input_prompts_from_yaml(input_prompts_yaml: str = INPUT_PROMPT_YAML) -> None:
|
||||
with open(input_prompts_yaml, "r") as file:
|
||||
data = yaml.safe_load(file)
|
||||
|
||||
all_input_prompts = data.get("input_prompts", [])
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
for input_prompt in all_input_prompts:
|
||||
# If these prompts are deleted (which is a hard delete in the DB), on server startup
|
||||
# they will be recreated, but the user can always just deactivate them, just a light inconvenience
|
||||
insert_input_prompt_if_not_exists(
|
||||
user=None,
|
||||
input_prompt_id=input_prompt.get("id"),
|
||||
prompt=input_prompt["prompt"],
|
||||
content=input_prompt["content"],
|
||||
is_public=input_prompt["is_public"],
|
||||
active=input_prompt.get("active", True),
|
||||
db_session=db_session,
|
||||
commit=True,
|
||||
)
|
||||
|
||||
|
||||
def load_chat_yamls(
|
||||
prompt_yaml: str = PROMPTS_YAML,
|
||||
personas_yaml: str = PERSONAS_YAML,
|
||||
input_prompts_yaml: str = INPUT_PROMPT_YAML,
|
||||
) -> None:
|
||||
load_prompts_from_yaml(prompt_yaml)
|
||||
load_personas_from_yaml(personas_yaml)
|
||||
load_input_prompts_from_yaml(input_prompts_yaml)
|
||||
|
||||
@@ -42,11 +42,21 @@ class QADocsResponse(RetrievalDocs):
        return initial_dict


# Second chunk of info for streaming QA
class LLMRelevanceFilterResponse(BaseModel):
    relevant_chunk_indices: list[int]


class RelevanceChunk(BaseModel):
    # TODO make this document level. Also slight misnomer here as this is actually
    # done at the section level currently rather than the chunk
    relevant: bool | None = None
    content: str | None = None


class LLMRelevanceSummaryResponse(BaseModel):
    relevance_summaries: dict[str, RelevanceChunk]


class DanswerAnswerPiece(BaseModel):
    # A small piece of a complete answer. Used for streaming back answers.
    answer_piece: str | None  # if None, specifies the end of an Answer
@@ -106,12 +116,18 @@ class ImageGenerationDisplay(BaseModel):
    file_ids: list[str]


class CustomToolResponse(BaseModel):
    response: dict
    tool_name: str


AnswerQuestionPossibleReturn = (
    DanswerAnswerPiece
    | DanswerQuotes
    | CitationInfo
    | DanswerContexts
    | ImageGenerationDisplay
    | CustomToolResponse
    | StreamingError
)
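To make the new relevance models concrete, a minimal illustration (the document-ID key is hypothetical; the model fields are as defined above):

# Illustrative payload only, not code from this diff
LLMRelevanceSummaryResponse(
    relevance_summaries={
        "doc_abc123": RelevanceChunk(relevant=True, content="matched snippet ..."),
    }
)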
@@ -37,7 +37,8 @@ personas:
    #  - "Engineer Onboarding"
    #  - "Benefits"
    document_sets: []

    icon_shape: 23013
    icon_color: "#6FB1FF"

  - id: 1
    name: "GPT"
@@ -50,6 +51,8 @@ personas:
    llm_filter_extraction: true
    recency_bias: "auto"
    document_sets: []
    icon_shape: 50910
    icon_color: "#FF6F6F"


  - id: 2
@@ -63,3 +66,6 @@ personas:
    llm_filter_extraction: true
    recency_bias: "auto"
    document_sets: []
    icon_shape: 45519
    icon_color: "#6FFF8D"

@@ -7,15 +7,19 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.chat.chat_utils import create_chat_chain
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import CustomToolResponse
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import ImageGenerationDisplay
|
||||
from danswer.chat.models import LlmDoc
|
||||
from danswer.chat.models import LLMRelevanceFilterResponse
|
||||
from danswer.chat.models import QADocsResponse
|
||||
from danswer.chat.models import StreamingError
|
||||
from danswer.configs.chat_configs import BING_API_KEY
|
||||
from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
|
||||
from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
|
||||
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
|
||||
from danswer.db.chat import attach_files_to_chat_message
|
||||
from danswer.db.chat import create_db_search_doc
|
||||
from danswer.db.chat import create_new_chat_message
|
||||
from danswer.db.chat import get_chat_message
|
||||
@@ -29,7 +33,9 @@ from danswer.db.embedding_model import get_current_db_embedding_model
|
||||
from danswer.db.engine import get_session_context_manager
|
||||
from danswer.db.llm import fetch_existing_llm_providers
|
||||
from danswer.db.models import SearchDoc as DbSearchDoc
|
||||
from danswer.db.models import ToolCall
|
||||
from danswer.db.models import User
|
||||
from danswer.db.persona import get_persona_by_id
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.file_store.models import ChatFileType
|
||||
from danswer.file_store.models import FileDescriptor
|
||||
@@ -42,25 +48,44 @@ from danswer.llm.answering.models import DocumentPruningConfig
|
||||
from danswer.llm.answering.models import PreviousMessage
|
||||
from danswer.llm.answering.models import PromptConfig
|
||||
from danswer.llm.exceptions import GenAIDisabledException
|
||||
from danswer.llm.factory import get_llm_for_persona
|
||||
from danswer.llm.utils import get_default_llm_tokenizer
|
||||
from danswer.llm.factory import get_llms_for_persona
|
||||
from danswer.llm.factory import get_main_llm_from_tuple
|
||||
from danswer.llm.interfaces import LLMConfig
|
||||
from danswer.natural_language_processing.utils import get_tokenizer
|
||||
from danswer.search.enums import OptionalSearchSetting
|
||||
from danswer.search.retrieval.search_runner import inference_documents_from_ids
|
||||
from danswer.search.enums import QueryFlow
|
||||
from danswer.search.enums import SearchType
|
||||
from danswer.search.models import InferenceSection
|
||||
from danswer.search.retrieval.search_runner import inference_sections_from_ids
|
||||
from danswer.search.utils import chunks_or_sections_to_search_docs
|
||||
from danswer.search.utils import dedupe_documents
|
||||
from danswer.search.utils import drop_llm_indices
|
||||
from danswer.server.query_and_chat.models import ChatMessageDetail
|
||||
from danswer.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from danswer.server.utils import get_json_line
|
||||
from danswer.tools.factory import get_tool_cls
|
||||
from danswer.tools.built_in_tools import get_built_in_tool_by_id
|
||||
from danswer.tools.custom.custom_tool import build_custom_tools_from_openapi_schema
|
||||
from danswer.tools.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID
|
||||
from danswer.tools.custom.custom_tool import CustomToolCallSummary
|
||||
from danswer.tools.force import ForceUseTool
|
||||
from danswer.tools.images.image_generation_tool import IMAGE_GENERATION_RESPONSE_ID
|
||||
from danswer.tools.images.image_generation_tool import ImageGenerationResponse
|
||||
from danswer.tools.images.image_generation_tool import ImageGenerationTool
|
||||
from danswer.tools.internet_search.internet_search_tool import (
|
||||
INTERNET_SEARCH_RESPONSE_ID,
|
||||
)
|
||||
from danswer.tools.internet_search.internet_search_tool import (
|
||||
internet_search_response_to_search_docs,
|
||||
)
|
||||
from danswer.tools.internet_search.internet_search_tool import InternetSearchResponse
|
||||
from danswer.tools.internet_search.internet_search_tool import InternetSearchTool
|
||||
from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID
|
||||
from danswer.tools.search.search_tool import SearchResponseSummary
|
||||
from danswer.tools.search.search_tool import SearchTool
|
||||
from danswer.tools.search.search_tool import SECTION_RELEVANCE_LIST_ID
|
||||
from danswer.tools.tool import Tool
|
||||
from danswer.tools.tool import ToolResponse
|
||||
from danswer.tools.tool_runner import ToolCallFinalResult
|
||||
from danswer.tools.utils import compute_all_tool_tokens
|
||||
from danswer.tools.utils import explicit_tool_calling_supported
|
||||
from danswer.utils.logger import setup_logger
|
||||
@@ -93,14 +118,21 @@ def _handle_search_tool_response_summary(
|
||||
packet: ToolResponse,
|
||||
db_session: Session,
|
||||
selected_search_docs: list[DbSearchDoc] | None,
|
||||
) -> tuple[QADocsResponse, list[DbSearchDoc]]:
|
||||
dedupe_docs: bool = False,
|
||||
) -> tuple[QADocsResponse, list[DbSearchDoc], list[int] | None]:
|
||||
response_sumary = cast(SearchResponseSummary, packet.response)
|
||||
|
||||
dropped_inds = None
|
||||
if not selected_search_docs:
|
||||
top_docs = chunks_or_sections_to_search_docs(response_sumary.top_sections)
|
||||
|
||||
deduped_docs = top_docs
|
||||
if dedupe_docs:
|
||||
deduped_docs, dropped_inds = dedupe_documents(top_docs)
|
||||
|
||||
reference_db_search_docs = [
|
||||
create_db_search_doc(server_search_doc=top_doc, db_session=db_session)
|
||||
for top_doc in top_docs
|
||||
create_db_search_doc(server_search_doc=doc, db_session=db_session)
|
||||
for doc in deduped_docs
|
||||
]
|
||||
else:
|
||||
reference_db_search_docs = selected_search_docs
|
||||
@@ -120,35 +152,81 @@ def _handle_search_tool_response_summary(
|
||||
recency_bias_multiplier=response_sumary.recency_bias_multiplier,
|
||||
),
|
||||
reference_db_search_docs,
|
||||
dropped_inds,
|
||||
)
|
||||
|
||||
|
||||
def _check_should_force_search(
|
||||
new_msg_req: CreateChatMessageRequest,
|
||||
) -> ForceUseTool | None:
|
||||
if (
|
||||
new_msg_req.query_override
|
||||
or (
|
||||
new_msg_req.retrieval_options
|
||||
and new_msg_req.retrieval_options.run_search == OptionalSearchSetting.ALWAYS
|
||||
)
|
||||
or new_msg_req.search_doc_ids
|
||||
):
|
||||
args = (
|
||||
{"query": new_msg_req.query_override}
|
||||
if new_msg_req.query_override
|
||||
else None
|
||||
)
|
||||
# if we are using selected docs, just put something here so the Tool doesn't need
|
||||
# to build its own args via an LLM call
|
||||
if new_msg_req.search_doc_ids:
|
||||
args = {"query": new_msg_req.message}
|
||||
def _handle_internet_search_tool_response_summary(
|
||||
packet: ToolResponse,
|
||||
db_session: Session,
|
||||
) -> tuple[QADocsResponse, list[DbSearchDoc]]:
|
||||
internet_search_response = cast(InternetSearchResponse, packet.response)
|
||||
server_search_docs = internet_search_response_to_search_docs(
|
||||
internet_search_response
|
||||
)
|
||||
|
||||
return ForceUseTool(
|
||||
tool_name=SearchTool.name(),
|
||||
args=args,
|
||||
)
|
||||
return None
|
||||
reference_db_search_docs = [
|
||||
create_db_search_doc(server_search_doc=doc, db_session=db_session)
|
||||
for doc in server_search_docs
|
||||
]
|
||||
response_docs = [
|
||||
translate_db_search_doc_to_server_search_doc(db_search_doc)
|
||||
for db_search_doc in reference_db_search_docs
|
||||
]
|
||||
return (
|
||||
QADocsResponse(
|
||||
rephrased_query=internet_search_response.revised_query,
|
||||
top_documents=response_docs,
|
||||
predicted_flow=QueryFlow.QUESTION_ANSWER,
|
||||
predicted_search=SearchType.HYBRID,
|
||||
applied_source_filters=[],
|
||||
applied_time_cutoff=None,
|
||||
recency_bias_multiplier=1.0,
|
||||
),
|
||||
reference_db_search_docs,
|
||||
)
|
||||
|
||||
|
||||
def _get_force_search_settings(
|
||||
new_msg_req: CreateChatMessageRequest, tools: list[Tool]
|
||||
) -> ForceUseTool:
|
||||
internet_search_available = any(
|
||||
isinstance(tool, InternetSearchTool) for tool in tools
|
||||
)
|
||||
search_tool_available = any(isinstance(tool, SearchTool) for tool in tools)
|
||||
|
||||
if not internet_search_available and not search_tool_available:
|
||||
# Does not matter much which tool is set here as force is false and neither tool is available
|
||||
return ForceUseTool(force_use=False, tool_name=SearchTool._NAME)
|
||||
|
||||
tool_name = SearchTool._NAME if search_tool_available else InternetSearchTool._NAME
|
||||
# Currently, the internet search tool does not support query override
|
||||
args = (
|
||||
{"query": new_msg_req.query_override}
|
||||
if new_msg_req.query_override and tool_name == SearchTool._NAME
|
||||
else None
|
||||
)
|
||||
|
||||
if new_msg_req.file_descriptors:
|
||||
# If user has uploaded files they're using, don't run any of the search tools
|
||||
return ForceUseTool(force_use=False, tool_name=tool_name)
|
||||
|
||||
should_force_search = any(
|
||||
[
|
||||
new_msg_req.retrieval_options
|
||||
and new_msg_req.retrieval_options.run_search
|
||||
== OptionalSearchSetting.ALWAYS,
|
||||
new_msg_req.search_doc_ids,
|
||||
DISABLE_LLM_CHOOSE_SEARCH,
|
||||
]
|
||||
)
|
||||
|
||||
if should_force_search:
|
||||
# If we are using selected docs, just put something here so the Tool doesn't need to build its own args via an LLM call
|
||||
args = {"query": new_msg_req.message} if new_msg_req.search_doc_ids else args
|
||||
return ForceUseTool(force_use=True, tool_name=tool_name, args=args)
|
||||
|
||||
return ForceUseTool(force_use=False, tool_name=tool_name, args=args)
|
||||
|
||||
|
||||
ChatPacket = (
|
||||
@@ -159,6 +237,7 @@ ChatPacket = (
|
||||
| DanswerAnswerPiece
|
||||
| CitationInfo
|
||||
| ImageGenerationDisplay
|
||||
| CustomToolResponse
|
||||
)
|
||||
ChatPacketStream = Iterator[ChatPacket]
|
||||
|
||||
@@ -176,13 +255,13 @@ def stream_chat_message_objects(
|
||||
# on the `new_msg_req.message`. Currently, requires a state where the last message is a
|
||||
# user message (e.g. this can only be used for the chat-seeding flow).
|
||||
use_existing_user_message: bool = False,
|
||||
litellm_additional_headers: dict[str, str] | None = None,
|
||||
) -> ChatPacketStream:
|
||||
"""Streams in order:
|
||||
1. [conditional] Retrieved documents if a search needs to be run
|
||||
2. [conditional] LLM selected chunk indices if LLM chunk filtering is turned on
|
||||
3. [always] A set of streamed LLM tokens or an error anywhere along the line if something fails
|
||||
4. [always] Details on the final AI response message that is created
|
||||
|
||||
"""
|
||||
try:
|
||||
user_id = user.id if user is not None else None
|
||||
@@ -198,7 +277,18 @@ def stream_chat_message_objects(
|
||||
parent_id = new_msg_req.parent_message_id
|
||||
reference_doc_ids = new_msg_req.search_doc_ids
|
||||
retrieval_options = new_msg_req.retrieval_options
|
||||
persona = chat_session.persona
|
||||
alternate_assistant_id = new_msg_req.alternate_assistant_id
|
||||
|
||||
# use alternate persona if alternative assistant id is passed in
|
||||
if alternate_assistant_id is not None:
|
||||
persona = get_persona_by_id(
|
||||
alternate_assistant_id,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
is_for_edit=False,
|
||||
)
|
||||
else:
|
||||
persona = chat_session.persona
|
||||
|
||||
prompt_id = new_msg_req.prompt_id
|
||||
if prompt_id is None and persona.prompts:
|
||||
@@ -210,13 +300,21 @@ def stream_chat_message_objects(
|
||||
)
|
||||
|
||||
try:
|
||||
llm = get_llm_for_persona(
|
||||
persona, new_msg_req.llm_override or chat_session.llm_override
|
||||
llm, fast_llm = get_llms_for_persona(
|
||||
persona=persona,
|
||||
llm_override=new_msg_req.llm_override or chat_session.llm_override,
|
||||
additional_headers=litellm_additional_headers,
|
||||
)
|
||||
except GenAIDisabledException:
|
||||
raise RuntimeError("LLM is disabled. Can't use chat flow without LLM.")
|
||||
|
||||
llm_tokenizer = get_default_llm_tokenizer()
|
||||
llm_provider = llm.config.model_provider
|
||||
llm_model_name = llm.config.model_name
|
||||
|
||||
llm_tokenizer = get_tokenizer(
|
||||
model_name=llm_model_name,
|
||||
provider_type=llm_provider,
|
||||
)
|
||||
llm_tokenizer_encode_func = cast(
|
||||
Callable[[str], list[int]], llm_tokenizer.encode
|
||||
)
|
||||
@@ -240,6 +338,7 @@ def stream_chat_message_objects(
|
||||
else:
|
||||
parent_message = root_message
|
||||
|
||||
user_message = None
|
||||
if not use_existing_user_message:
|
||||
# Create new message at the right place in the tree and update the parent's child pointer
|
||||
# Don't commit yet until we verify the chat message chain
|
||||
@@ -250,10 +349,7 @@ def stream_chat_message_objects(
|
||||
message=message_text,
|
||||
token_count=len(llm_tokenizer_encode_func(message_text)),
|
||||
message_type=MessageType.USER,
|
||||
files=[
|
||||
{"id": str(file_id), "type": ChatFileType.IMAGE}
|
||||
for file_id in new_msg_req.file_ids
|
||||
],
|
||||
files=None, # Need to attach later for optimization to only load files once in parallel
|
||||
db_session=db_session,
|
||||
commit=False,
|
||||
)
|
||||
@@ -268,8 +364,8 @@ def stream_chat_message_objects(
|
||||
"Be sure to update the chat pointers before calling this."
|
||||
)
|
||||
|
||||
# Save now to save the latest chat message
|
||||
db_session.commit()
|
||||
# NOTE: do not commit user message - it will be committed when the
|
||||
# assistant message is successfully generated
|
||||
else:
|
||||
# re-create linear history of messages
|
||||
final_msg, history_msgs = create_chat_chain(
|
||||
@@ -282,14 +378,36 @@ def stream_chat_message_objects(
|
||||
"when the last message is not a user message."
|
||||
)
|
||||
|
||||
# Disable Query Rephrasing for the first message
|
||||
# This leads to a better first response since the LLM rephrasing the question
|
||||
# leads to worst search quality
|
||||
if not history_msgs:
|
||||
new_msg_req.query_override = (
|
||||
new_msg_req.query_override or new_msg_req.message
|
||||
)
|
||||
|
||||
# load all files needed for this chat chain in memory
|
||||
files = load_all_chat_files(history_msgs, new_msg_req.file_ids, db_session)
|
||||
files = load_all_chat_files(
|
||||
history_msgs, new_msg_req.file_descriptors, db_session
|
||||
)
|
||||
latest_query_files = [
|
||||
file for file in files if file.file_id in new_msg_req.file_ids
|
||||
file
|
||||
for file in files
|
||||
if file.file_id in [f["id"] for f in new_msg_req.file_descriptors]
|
||||
]
|
||||
|
||||
if user_message:
|
||||
attach_files_to_chat_message(
|
||||
chat_message=user_message,
|
||||
files=[
|
||||
new_file.to_file_descriptor() for new_file in latest_query_files
|
||||
],
|
||||
db_session=db_session,
|
||||
commit=False,
|
||||
)
|
||||
|
||||
selected_db_search_docs = None
|
||||
selected_llm_docs: list[LlmDoc] | None = None
|
||||
selected_sections: list[InferenceSection] | None = None
|
||||
if reference_doc_ids:
|
||||
identifier_tuples = get_doc_query_identifiers_from_model(
|
||||
search_doc_ids=reference_doc_ids,
|
||||
@@ -299,8 +417,8 @@ def stream_chat_message_objects(
|
||||
)
|
||||
|
||||
# Generates full documents currently
|
||||
# May extend to include chunk ranges
|
||||
selected_llm_docs = inference_documents_from_ids(
|
||||
# May extend to use sections instead in the future
|
||||
selected_sections = inference_sections_from_ids(
|
||||
doc_identifiers=identifier_tuples,
|
||||
document_index=document_index,
|
||||
)
|
||||
@@ -341,77 +459,123 @@ def stream_chat_message_objects(
|
||||
# rephrased_query=,
|
||||
# token_count=,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
alternate_assistant_id=new_msg_req.alternate_assistant_id,
|
||||
# error=,
|
||||
# reference_docs=,
|
||||
db_session=db_session,
|
||||
commit=True,
|
||||
commit=False,
|
||||
)
|
||||
|
||||
if not final_msg.prompt:
|
||||
raise RuntimeError("No Prompt found")
|
||||
|
||||
prompt_config = PromptConfig.from_model(
|
||||
final_msg.prompt,
|
||||
prompt_override=(
|
||||
new_msg_req.prompt_override or chat_session.prompt_override
|
||||
),
|
||||
prompt_config = (
|
||||
PromptConfig.from_model(
|
||||
final_msg.prompt,
|
||||
prompt_override=(
|
||||
new_msg_req.prompt_override or chat_session.prompt_override
|
||||
),
|
||||
)
|
||||
if not persona
|
||||
else PromptConfig.from_model(persona.prompts[0])
|
||||
)
|
||||
|
||||
persona_tool_classes = [
|
||||
get_tool_cls(tool, db_session) for tool in persona.tools
|
||||
]
|
||||
# find out what tools to use
|
||||
search_tool: SearchTool | None = None
|
||||
tool_dict: dict[int, list[Tool]] = {} # tool_id to tool
|
||||
for db_tool_model in persona.tools:
|
||||
# handle in-code tools specially
|
||||
if db_tool_model.in_code_tool_id:
|
||||
tool_cls = get_built_in_tool_by_id(db_tool_model.id, db_session)
|
||||
if tool_cls.__name__ == SearchTool.__name__ and not latest_query_files:
|
||||
search_tool = SearchTool(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
persona=persona,
|
||||
retrieval_options=retrieval_options,
|
||||
prompt_config=prompt_config,
|
||||
llm=llm,
|
||||
fast_llm=fast_llm,
|
||||
pruning_config=document_pruning_config,
|
||||
selected_sections=selected_sections,
|
||||
chunks_above=new_msg_req.chunks_above,
|
||||
chunks_below=new_msg_req.chunks_below,
|
||||
full_doc=new_msg_req.full_doc,
|
||||
)
|
||||
tool_dict[db_tool_model.id] = [search_tool]
|
||||
elif tool_cls.__name__ == ImageGenerationTool.__name__:
|
||||
img_generation_llm_config: LLMConfig | None = None
|
||||
if (
|
||||
llm
|
||||
and llm.config.api_key
|
||||
and llm.config.model_provider == "openai"
|
||||
):
|
||||
img_generation_llm_config = llm.config
|
||||
else:
|
||||
llm_providers = fetch_existing_llm_providers(db_session)
|
||||
openai_provider = next(
|
||||
iter(
|
||||
[
|
||||
llm_provider
|
||||
for llm_provider in llm_providers
|
||||
if llm_provider.provider == "openai"
|
||||
]
|
||||
),
|
||||
None,
|
||||
)
|
||||
if not openai_provider or not openai_provider.api_key:
|
||||
raise ValueError(
|
||||
"Image generation tool requires an OpenAI API key"
|
||||
)
|
||||
img_generation_llm_config = LLMConfig(
|
||||
model_provider=openai_provider.provider,
|
||||
model_name=openai_provider.default_model_name,
|
||||
temperature=GEN_AI_TEMPERATURE,
|
||||
api_key=openai_provider.api_key,
|
||||
api_base=openai_provider.api_base,
|
||||
api_version=openai_provider.api_version,
|
||||
)
|
||||
tool_dict[db_tool_model.id] = [
|
||||
ImageGenerationTool(
|
||||
api_key=cast(str, img_generation_llm_config.api_key),
|
||||
api_base=img_generation_llm_config.api_base,
|
||||
api_version=img_generation_llm_config.api_version,
|
||||
additional_headers=litellm_additional_headers,
|
||||
)
|
||||
]
|
||||
elif tool_cls.__name__ == InternetSearchTool.__name__:
|
||||
bing_api_key = BING_API_KEY
|
||||
if not bing_api_key:
|
||||
raise ValueError(
|
||||
"Internet search tool requires a Bing API key, please contact your Danswer admin to get it added!"
|
||||
)
|
||||
tool_dict[db_tool_model.id] = [
|
||||
InternetSearchTool(api_key=bing_api_key)
|
||||
]
|
||||
|
||||
continue
|
||||
|
||||
# handle all custom tools
|
||||
if db_tool_model.openapi_schema:
|
||||
tool_dict[db_tool_model.id] = cast(
|
||||
list[Tool],
|
||||
build_custom_tools_from_openapi_schema(
|
||||
db_tool_model.openapi_schema
|
||||
),
|
||||
)
|
||||
|
||||
tools: list[Tool] = []
|
||||
for tool_list in tool_dict.values():
|
||||
tools.extend(tool_list)
|
||||
|
||||
# factor in tool definition size when pruning
|
||||
document_pruning_config.tool_num_tokens = compute_all_tool_tokens(
|
||||
persona_tool_classes
|
||||
tools, llm_tokenizer
|
||||
)
|
||||
document_pruning_config.using_tool_message = explicit_tool_calling_supported(
|
||||
llm.config.model_provider, llm.config.model_name
|
||||
llm_provider, llm_model_name
|
||||
)
|
||||
|
||||
# NOTE: for now, only support SearchTool and ImageGenerationTool
|
||||
# in the future, will support arbitrary user-defined tools
|
||||
search_tool: SearchTool | None = None
|
||||
tools: list[Tool] = []
|
||||
for tool_cls in persona_tool_classes:
|
||||
if tool_cls.__name__ == SearchTool.__name__:
|
||||
search_tool = SearchTool(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
persona=persona,
|
||||
retrieval_options=retrieval_options,
|
||||
prompt_config=prompt_config,
|
||||
llm_config=llm.config,
|
||||
pruning_config=document_pruning_config,
|
||||
selected_docs=selected_llm_docs,
|
||||
chunks_above=new_msg_req.chunks_above,
|
||||
chunks_below=new_msg_req.chunks_below,
|
||||
full_doc=new_msg_req.full_doc,
|
||||
)
|
||||
tools.append(search_tool)
|
||||
elif tool_cls.__name__ == ImageGenerationTool.__name__:
|
||||
dalle_key = None
|
||||
if llm and llm.config.api_key and llm.config.model_provider == "openai":
|
||||
dalle_key = llm.config.api_key
|
||||
else:
|
||||
llm_providers = fetch_existing_llm_providers(db_session)
|
||||
openai_provider = next(
|
||||
iter(
|
||||
[
|
||||
llm_provider
|
||||
for llm_provider in llm_providers
|
||||
if llm_provider.provider == "openai"
|
||||
]
|
||||
),
|
||||
None,
|
||||
)
|
||||
if not openai_provider or not openai_provider.api_key:
|
||||
raise ValueError(
|
||||
"Image generation tool requires an OpenAI API key"
|
||||
)
|
||||
dalle_key = openai_provider.api_key
|
||||
tools.append(ImageGenerationTool(api_key=dalle_key))
|
||||
|
||||
# LLM prompt building, response capturing, etc.
|
||||
answer = Answer(
|
||||
question=final_msg.message,
|
||||
@@ -425,33 +589,57 @@ def stream_chat_message_objects(
|
||||
prompt_config=prompt_config,
|
||||
llm=(
|
||||
llm
|
||||
or get_llm_for_persona(
|
||||
persona, new_msg_req.llm_override or chat_session.llm_override
|
||||
or get_main_llm_from_tuple(
|
||||
get_llms_for_persona(
|
||||
persona=persona,
|
||||
llm_override=(
|
||||
new_msg_req.llm_override or chat_session.llm_override
|
||||
),
|
||||
additional_headers=litellm_additional_headers,
|
||||
)
|
||||
)
|
||||
),
|
||||
message_history=[
|
||||
PreviousMessage.from_chat_message(msg, files) for msg in history_msgs
|
||||
],
|
||||
tools=tools,
|
||||
force_use_tool=_check_should_force_search(new_msg_req),
|
||||
force_use_tool=_get_force_search_settings(new_msg_req, tools),
|
||||
)
|
||||
|
||||
reference_db_search_docs = None
|
||||
qa_docs_response = None
|
||||
ai_message_files = None # any files to associate with the AI message e.g. dall-e generated images
|
||||
dropped_indices = None
|
||||
tool_result = None
|
||||
for packet in answer.processed_streamed_output:
|
||||
if isinstance(packet, ToolResponse):
|
||||
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
|
||||
(
|
||||
qa_docs_response,
|
||||
reference_db_search_docs,
|
||||
dropped_indices,
|
||||
) = _handle_search_tool_response_summary(
|
||||
packet, db_session, selected_db_search_docs
|
||||
packet=packet,
|
||||
db_session=db_session,
|
||||
selected_search_docs=selected_db_search_docs,
|
||||
# Deduping happens at the last step to avoid harming quality by dropping content early on
|
||||
dedupe_docs=retrieval_options.dedupe_docs
|
||||
if retrieval_options
|
||||
else False,
|
||||
)
|
||||
yield qa_docs_response
|
||||
elif packet.id == SECTION_RELEVANCE_LIST_ID:
|
||||
chunk_indices = packet.response
|
||||
|
||||
if reference_db_search_docs is not None and dropped_indices:
|
||||
chunk_indices = drop_llm_indices(
|
||||
llm_indices=chunk_indices,
|
||||
search_docs=reference_db_search_docs,
|
||||
dropped_indices=dropped_indices,
|
||||
)
|
||||
|
||||
yield LLMRelevanceFilterResponse(
|
||||
relevant_chunk_indices=packet.response
|
||||
relevant_chunk_indices=chunk_indices
|
||||
)
|
||||
elif packet.id == IMAGE_GENERATION_RESPONSE_ID:
|
||||
img_generation_response = cast(
|
||||
@@ -468,16 +656,41 @@ def stream_chat_message_objects(
|
||||
yield ImageGenerationDisplay(
|
||||
file_ids=[str(file_id) for file_id in file_ids]
|
||||
)
|
||||
elif packet.id == INTERNET_SEARCH_RESPONSE_ID:
|
||||
(
|
||||
qa_docs_response,
|
||||
reference_db_search_docs,
|
||||
) = _handle_internet_search_tool_response_summary(
|
||||
packet=packet,
|
||||
db_session=db_session,
|
||||
)
|
||||
yield qa_docs_response
|
||||
elif packet.id == CUSTOM_TOOL_RESPONSE_ID:
|
||||
custom_tool_response = cast(CustomToolCallSummary, packet.response)
|
||||
yield CustomToolResponse(
|
||||
response=custom_tool_response.tool_result,
|
||||
tool_name=custom_tool_response.tool_name,
|
||||
)
|
||||
|
||||
else:
|
||||
if isinstance(packet, ToolCallFinalResult):
|
||||
tool_result = packet
|
||||
yield cast(ChatPacket, packet)
|
||||
|
||||
    except Exception as e:
        logger.exception(e)
        logger.exception("Failed to process chat message")

        # Frontend will erase whatever answer and show this instead
        # This will be the issue 99% of the time
        yield StreamingError(error="LLM failed to respond, have you set your API key?")
        # Don't leak the API key
        error_msg = str(e)
        if llm.config.api_key and llm.config.api_key.lower() in error_msg.lower():
            error_msg = (
                f"LLM failed to respond. Invalid API "
                f"key error from '{llm.config.model_provider}'."
            )

        yield StreamingError(error=error_msg)
        # Cancel the transaction so that no messages are saved
        db_session.rollback()
        return
|
||||
# Post-LLM answer processing
|
||||
@@ -490,6 +703,11 @@ def stream_chat_message_objects(
|
||||
)
|
||||
|
||||
# Saving Gen AI answer and responding with message info
|
||||
tool_name_to_tool_id: dict[str, int] = {}
|
||||
for tool_id, tool_list in tool_dict.items():
|
||||
for tool in tool_list:
|
||||
tool_name_to_tool_id[tool.name] = tool_id
|
||||
|
||||
gen_ai_response_message = partial_response(
|
||||
message=answer.llm_answer,
|
||||
rephrased_query=(
|
||||
@@ -500,7 +718,18 @@ def stream_chat_message_objects(
|
||||
token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
|
||||
citations=db_citations,
|
||||
error=None,
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
tool_id=tool_name_to_tool_id[tool_result.tool_name],
|
||||
tool_name=tool_result.tool_name,
|
||||
tool_arguments=tool_result.tool_args,
|
||||
tool_result=tool_result.tool_result,
|
||||
)
|
||||
]
|
||||
if tool_result
|
||||
else [],
|
||||
)
|
||||
db_session.commit() # actually save user / assistant message
|
||||
|
||||
msg_detail_response = translate_db_message_to_chat_message_detail(
|
||||
gen_ai_response_message
|
||||
@@ -519,6 +748,7 @@ def stream_chat_message(
|
||||
new_msg_req: CreateChatMessageRequest,
|
||||
user: User | None,
|
||||
use_existing_user_message: bool = False,
|
||||
litellm_additional_headers: dict[str, str] | None = None,
|
||||
) -> Iterator[str]:
|
||||
with get_session_context_manager() as db_session:
|
||||
objects = stream_chat_message_objects(
|
||||
@@ -526,6 +756,7 @@ def stream_chat_message(
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
use_existing_user_message=use_existing_user_message,
|
||||
litellm_additional_headers=litellm_additional_headers,
|
||||
)
|
||||
for obj in objects:
|
||||
yield get_json_line(obj.dict())
|
||||
|
||||
@@ -8,6 +8,7 @@ prompts:
    # System Prompt (as shown in UI)
    system: >
      You are a question answering system that is constantly learning and improving.
      The current date is DANSWER_DATETIME_REPLACEMENT.

      You can process and comprehend vast amounts of text and utilize this knowledge to provide
      grounded, accurate, and concise answers to diverse queries.
@@ -21,8 +22,9 @@ prompts:

      I have not read or seen any of the documents and do not want to read them.

      If there are no relevant documents, refer to the chat history and existing knowledge.
      If there are no relevant documents, refer to the chat history and your internal knowledge.
    # Inject a statement at the end of system prompt to inform the LLM of the current date/time
    # If the DANSWER_DATETIME_REPLACEMENT is set, the date/time is inserted there instead
    # Format looks like: "October 16, 2023 14:30"
    datetime_aware: true
    # Prompts the LLM to include citations in the for [1], [2] etc.
@@ -32,7 +34,16 @@ prompts:

  - name: "OnlyLLM"
    description: "Chat directly with the LLM!"
    system: "You are a helpful assistant."
    system: >
      You are a helpful AI assistant. The current date is DANSWER_DATETIME_REPLACEMENT


      You give concise responses to very simple questions, but provide more thorough responses to
      more complex and open-ended questions.


      You are happy to help with writing, analysis, question answering, math, coding and all sorts
      of other tasks. You use markdown where reasonable and also for coding.
    task: ""
    datetime_aware: true
    include_citations: true
@@ -43,10 +54,11 @@ prompts:
    system: >
      You are a text summarizing assistant that highlights the most important knowledge from the
      context provided, prioritizing the information that relates to the user query.
      The current date is DANSWER_DATETIME_REPLACEMENT.

      You ARE NOT creative and always stick to the provided documents.
      If there are no documents, refer to the conversation history.


      IMPORTANT: YOU ONLY SUMMARIZE THE IMPORTANT INFORMATION FROM THE PROVIDED DOCUMENTS,
      NEVER USE YOUR OWN KNOWLEDGE.
    task: >
@@ -61,7 +73,8 @@ prompts:
    description: "Recites information from retrieved context! Least creative but most safe!"
    system: >
      Quote and cite relevant information from provided context based on the user query.

      The current date is DANSWER_DATETIME_REPLACEMENT.

      You only provide quotes that are EXACT substrings from provided documents!

      If there are no documents provided,

@@ -4,6 +4,7 @@ import urllib.parse
|
||||
|
||||
from danswer.configs.constants import AuthType
|
||||
from danswer.configs.constants import DocumentIndexType
|
||||
from danswer.file_processing.enums import HtmlBasedConnectorTransformLinksStrategy
|
||||
|
||||
#####
|
||||
# App Configs
|
||||
@@ -45,13 +46,14 @@ DISABLE_AUTH = AUTH_TYPE == AuthType.DISABLED
|
||||
# Encryption key secret is used to encrypt connector credentials, api keys, and other sensitive
|
||||
# information. This provides an extra layer of security on top of Postgres access controls
|
||||
# and is available in Danswer EE
|
||||
ENCRYPTION_KEY_SECRET = os.environ.get("ENCRYPTION_KEY_SECRET")
|
||||
ENCRYPTION_KEY_SECRET = os.environ.get("ENCRYPTION_KEY_SECRET") or ""
|
||||
|
||||
# Turn off mask if admin users should see full credentials for data connectors.
|
||||
MASK_CREDENTIAL_PREFIX = (
|
||||
os.environ.get("MASK_CREDENTIAL_PREFIX", "True").lower() != "false"
|
||||
)
|
||||
|
||||
|
||||
SESSION_EXPIRE_TIME_SECONDS = int(
|
||||
os.environ.get("SESSION_EXPIRE_TIME_SECONDS") or 86400 * 7
|
||||
) # 7 days
|
||||
@@ -160,6 +162,11 @@ WEB_CONNECTOR_OAUTH_CLIENT_SECRET = os.environ.get("WEB_CONNECTOR_OAUTH_CLIENT_S
|
||||
WEB_CONNECTOR_OAUTH_TOKEN_URL = os.environ.get("WEB_CONNECTOR_OAUTH_TOKEN_URL")
|
||||
WEB_CONNECTOR_VALIDATE_URLS = os.environ.get("WEB_CONNECTOR_VALIDATE_URLS")
|
||||
|
||||
HTML_BASED_CONNECTOR_TRANSFORM_LINKS_STRATEGY = os.environ.get(
|
||||
"HTML_BASED_CONNECTOR_TRANSFORM_LINKS_STRATEGY",
|
||||
HtmlBasedConnectorTransformLinksStrategy.STRIP,
|
||||
)
|
||||
|
||||
NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP = (
|
||||
os.environ.get("NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP", "").lower()
|
||||
== "true"
|
||||
@@ -178,6 +185,12 @@ CONFLUENCE_CONNECTOR_INDEX_ONLY_ACTIVE_PAGES = (
|
||||
os.environ.get("CONFLUENCE_CONNECTOR_INDEX_ONLY_ACTIVE_PAGES", "").lower() == "true"
|
||||
)
|
||||
|
||||
# Save page labels as Danswer metadata tags
|
||||
# The reason to skip this would be to reduce the number of calls to Confluence due to rate limit concerns
|
||||
CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING = (
|
||||
os.environ.get("CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING", "").lower() == "true"
|
||||
)
|
||||
|
||||
JIRA_CONNECTOR_LABELS_TO_SKIP = [
|
||||
ignored_tag
|
||||
for ignored_tag in os.environ.get("JIRA_CONNECTOR_LABELS_TO_SKIP", "").split(",")
|
||||
@@ -188,6 +201,10 @@ GONG_CONNECTOR_START_TIME = os.environ.get("GONG_CONNECTOR_START_TIME")
|
||||
|
||||
GITHUB_CONNECTOR_BASE_URL = os.environ.get("GITHUB_CONNECTOR_BASE_URL") or None
|
||||
|
||||
GITLAB_CONNECTOR_INCLUDE_CODE_FILES = (
|
||||
os.environ.get("GITLAB_CONNECTOR_INCLUDE_CODE_FILES", "").lower() == "true"
|
||||
)
|
||||
|
||||
DASK_JOB_CLIENT_ENABLED = (
|
||||
os.environ.get("DASK_JOB_CLIENT_ENABLED", "").lower() == "true"
|
||||
)
|
||||
@@ -195,6 +212,22 @@ EXPERIMENTAL_CHECKPOINTING_ENABLED = (
|
||||
os.environ.get("EXPERIMENTAL_CHECKPOINTING_ENABLED", "").lower() == "true"
|
||||
)
|
||||
|
||||
DEFAULT_PRUNING_FREQ = 60 * 60 * 24 # Once a day
|
||||
|
||||
ALLOW_SIMULTANEOUS_PRUNING = (
|
||||
os.environ.get("ALLOW_SIMULTANEOUS_PRUNING", "").lower() == "true"
|
||||
)
|
||||
|
||||
# This is the maximum rate at which documents are queried for a pruning job. 0 disables the limit.
|
||||
MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE = int(
|
||||
os.environ.get("MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE", 0)
|
||||
)
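A minimal sketch of the throttling this setting implies (the helper below is hypothetical; it is not the actual pruning job code):

import time
from collections.abc import Iterator

def throttle_document_fetch(
    batches: Iterator[list[dict]], max_docs_per_minute: int
) -> Iterator[list[dict]]:
    # Yield document batches, sleeping so the per-minute rate is respected; 0 disables the limit
    fetched = 0
    window_start = time.monotonic()
    for batch in batches:
        yield batch
        if max_docs_per_minute <= 0:
            continue
        fetched += len(batch)
        if fetched >= max_docs_per_minute:
            elapsed = time.monotonic() - window_start
            if elapsed < 60:
                time.sleep(60 - elapsed)
            fetched = 0
            window_start = time.monotonic()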
|
||||
|
||||
# comma delimited list of zendesk article labels to skip indexing for
|
||||
ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS = os.environ.get(
|
||||
"ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS", ""
|
||||
).split(",")
|
||||
|
||||
|
||||
#####
|
||||
# Indexing Configs
|
||||
@@ -215,19 +248,20 @@ DISABLE_INDEX_UPDATE_ON_SWAP = (
|
||||
# fairly large amount of memory in order to increase substantially, since
|
||||
# each worker loads the embedding models into memory.
|
||||
NUM_INDEXING_WORKERS = int(os.environ.get("NUM_INDEXING_WORKERS") or 1)
|
||||
CHUNK_OVERLAP = 0
|
||||
NUM_SECONDARY_INDEXING_WORKERS = int(
|
||||
os.environ.get("NUM_SECONDARY_INDEXING_WORKERS") or NUM_INDEXING_WORKERS
|
||||
)
|
||||
# More accurate results at the expense of indexing speed and index size (stores an additional 4 MINI_CHUNK vectors)
|
||||
ENABLE_MINI_CHUNK = os.environ.get("ENABLE_MINI_CHUNK", "").lower() == "true"
|
||||
# Finer grained chunking for more detail retention
|
||||
# Slightly larger since the sentence-aware split is a max cutoff, so most mini-chunks will be under MINI_CHUNK_SIZE
|
||||
# tokens. But we need it to be at least as big as 1/4th chunk size to avoid having a tiny mini-chunk at the end
|
||||
MINI_CHUNK_SIZE = 150
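To illustrate what mini-chunking means here (a sketch only: the real splitter is sentence-aware and counts tokens with the embedding model's tokenizer rather than whitespace):

MINI_CHUNK_TOKENS = 150  # mirrors MINI_CHUNK_SIZE above

def split_into_mini_chunks(chunk_text: str, size: int = MINI_CHUNK_TOKENS) -> list[str]:
    # Each base chunk is additionally split into smaller pieces so finer-grained
    # vectors can be stored alongside the full chunk's vector
    words = chunk_text.split()
    return [" ".join(words[i : i + size]) for i in range(0, len(words), size)]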
|
||||
# Include the document level metadata in each chunk. If the metadata is too long, then it is thrown out
|
||||
# We don't want the metadata to overwhelm the actual contents of the chunk
|
||||
SKIP_METADATA_IN_CHUNK = os.environ.get("SKIP_METADATA_IN_CHUNK", "").lower() == "true"
|
||||
# Timeout to wait for job's last update before killing it, in hours
|
||||
CLEANUP_INDEXING_JOBS_TIMEOUT = int(os.environ.get("CLEANUP_INDEXING_JOBS_TIMEOUT", 3))
|
||||
# If set to true, then will not clean up documents that "no longer exist" when running Load connectors
|
||||
DISABLE_DOCUMENT_CLEANUP = (
|
||||
os.environ.get("DISABLE_DOCUMENT_CLEANUP", "").lower() == "true"
|
||||
)
|
||||
|
||||
|
||||
#####
|
||||
@@ -242,15 +276,20 @@ JOB_TIMEOUT = 60 * 60 * 6 # 6 hours default
|
||||
CURRENT_PROCESS_IS_AN_INDEXING_JOB = (
|
||||
os.environ.get("CURRENT_PROCESS_IS_AN_INDEXING_JOB", "").lower() == "true"
|
||||
)
|
||||
# Logs every model prompt and output, mostly used for development or exploration purposes
|
||||
# Sets LiteLLM to verbose logging
|
||||
LOG_ALL_MODEL_INTERACTIONS = (
|
||||
os.environ.get("LOG_ALL_MODEL_INTERACTIONS", "").lower() == "true"
|
||||
)
|
||||
# Logs Danswer only model interactions like prompts, responses, messages etc.
|
||||
LOG_DANSWER_MODEL_INTERACTIONS = (
|
||||
os.environ.get("LOG_DANSWER_MODEL_INTERACTIONS", "").lower() == "true"
|
||||
)
|
||||
# If set to `true` will enable additional logs about Vespa query performance
|
||||
# (time spent on finding the right docs + time spent fetching summaries from disk)
|
||||
LOG_VESPA_TIMING_INFORMATION = (
|
||||
os.environ.get("LOG_VESPA_TIMING_INFORMATION", "").lower() == "true"
|
||||
)
|
||||
LOG_ENDPOINT_LATENCY = os.environ.get("LOG_ENDPOINT_LATENCY", "").lower() == "true"
|
||||
# Anonymous usage telemetry
|
||||
DISABLE_TELEMETRY = os.environ.get("DISABLE_TELEMETRY", "").lower() == "true"
|
||||
|
||||
@@ -263,3 +302,15 @@ TOKEN_BUDGET_GLOBALLY_ENABLED = (
|
||||
CUSTOM_ANSWER_VALIDITY_CONDITIONS = json.loads(
|
||||
os.environ.get("CUSTOM_ANSWER_VALIDITY_CONDITIONS", "[]")
|
||||
)
|
||||
|
||||
|
||||
#####
|
||||
# Enterprise Edition Configs
|
||||
#####
|
||||
# NOTE: this should only be enabled if you have purchased an enterprise license.
|
||||
# if you're interested in an enterprise license, please reach out to us at
|
||||
# founders@danswer.ai OR message Chris Weaver or Yuhong Sun in the Danswer
|
||||
# Slack community (https://join.slack.com/t/danswer/shared_invite/zt-1w76msxmd-HJHLe3KNFIAIzk_0dSOKaQ)
|
||||
ENTERPRISE_EDITION_ENABLED = (
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() == "true"
|
||||
)
|
||||
|
||||
@@ -3,9 +3,13 @@ import os
|
||||
|
||||
PROMPTS_YAML = "./danswer/chat/prompts.yaml"
|
||||
PERSONAS_YAML = "./danswer/chat/personas.yaml"
|
||||
INPUT_PROMPT_YAML = "./danswer/chat/input_prompts.yaml"
|
||||
|
||||
NUM_RETURNED_HITS = 50
|
||||
NUM_RERANKED_RESULTS = 15
|
||||
# Used for LLM filtering and reranking
|
||||
# We want this to be approximately the number of results we want to show on the first page
|
||||
# It cannot be too large due to cost and latency implications
|
||||
NUM_RERANKED_RESULTS = 20
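The relationship between these two knobs, roughly (an illustrative pipeline, not the actual search code):

def retrieve_then_rerank(query: str, index_search, rerank) -> list[dict]:
    # Cast a wide net first, then let the more expensive reranker order a small subset
    hits = index_search(query, limit=50)   # NUM_RETURNED_HITS
    return rerank(query, hits)[:20]        # NUM_RERANKED_RESULTS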
|
||||
|
||||
# May be less depending on model
|
||||
MAX_CHUNKS_FED_TO_CHAT = float(os.environ.get("MAX_CHUNKS_FED_TO_CHAT") or 10.0)
|
||||
@@ -25,9 +29,10 @@ BASE_RECENCY_DECAY = 0.5
|
||||
FAVOR_RECENT_DECAY_MULTIPLIER = 2.0
|
||||
# Currently this next one is not configurable via env
|
||||
DISABLE_LLM_QUERY_ANSWERABILITY = QA_PROMPT_OVERRIDE == "weak"
|
||||
DISABLE_LLM_FILTER_EXTRACTION = (
|
||||
os.environ.get("DISABLE_LLM_FILTER_EXTRACTION", "").lower() == "true"
|
||||
)
|
||||
# For the highest matching base size chunk, how many chunks above and below do we pull in by default
|
||||
# Note this is not in any of the deployment configs yet
|
||||
CONTEXT_CHUNKS_ABOVE = int(os.environ.get("CONTEXT_CHUNKS_ABOVE") or 0)
|
||||
CONTEXT_CHUNKS_BELOW = int(os.environ.get("CONTEXT_CHUNKS_BELOW") or 0)
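In other words, the best-matching chunk can be padded with its neighbors from the same document; roughly (illustrative helper, not the retrieval code itself):

def expand_with_neighbors(
    doc_chunks: list[str], hit_index: int, above: int, below: int
) -> list[str]:
    # Pull in `above` chunks before and `below` chunks after the matched chunk
    start = max(0, hit_index - above)
    end = min(len(doc_chunks), hit_index + below + 1)
    return doc_chunks[start:end]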
|
||||
# Whether the LLM should evaluate all of the document chunks passed in for usefulness
|
||||
# in relation to the user query
|
||||
DISABLE_LLM_CHUNK_FILTER = (
|
||||
@@ -43,8 +48,6 @@ DISABLE_LLM_QUERY_REPHRASE = (
|
||||
# 1 edit per 20 characters, currently unused due to fuzzy match being too slow
|
||||
QUOTE_ALLOWED_ERROR_PERCENT = 0.05
|
||||
QA_TIMEOUT = int(os.environ.get("QA_TIMEOUT") or "60") # 60 seconds
|
||||
# Include additional document/chunk metadata in prompt to GenerativeAI
|
||||
INCLUDE_METADATA = False
|
||||
# Keyword Search Drop Stopwords
|
||||
# If the user has changed the default model, it was most likely to switch to a multilingual
|
||||
# model; the stopwords are NLTK English stopwords, so in that case we would not want to drop the keywords
|
||||
@@ -64,9 +67,31 @@ TITLE_CONTENT_RATIO = max(
|
||||
# A list of languages passed to the LLM to rephrase the query
|
||||
# For example "English,French,Spanish", be sure to use the "," separator
|
||||
MULTILINGUAL_QUERY_EXPANSION = os.environ.get("MULTILINGUAL_QUERY_EXPANSION") or None
|
||||
LANGUAGE_HINT = "\n" + (
|
||||
os.environ.get("LANGUAGE_HINT")
|
||||
or "IMPORTANT: Respond in the same language as my query!"
|
||||
)
|
||||
LANGUAGE_CHAT_NAMING_HINT = (
|
||||
os.environ.get("LANGUAGE_CHAT_NAMING_HINT")
|
||||
or "The name of the conversation must be in the same language as the user query."
|
||||
)
|
||||
|
||||
|
||||
# Agentic search takes significantly more tokens and therefore has much higher cost.
|
||||
# This configuration allows users to get a search-only experience with instant results
|
||||
# and no involvement from the LLM.
|
||||
# Additionally, some LLM providers have strict rate limits which may prohibit
|
||||
# sending many API requests at once (as is done in agentic search).
|
||||
DISABLE_AGENTIC_SEARCH = (
|
||||
os.environ.get("DISABLE_AGENTIC_SEARCH") or "false"
|
||||
).lower() == "true"
|
||||
|
||||
|
||||
# Stops streaming answers back to the UI if this pattern is seen:
|
||||
STOP_STREAM_PAT = os.environ.get("STOP_STREAM_PAT") or None
|
||||
|
||||
# The backend logic for this being True isn't fully supported yet
|
||||
HARD_DELETE_CHATS = False
|
||||
|
||||
# Internet Search
|
||||
BING_API_KEY = os.environ.get("BING_API_KEY") or None
|
||||
|
||||
@@ -19,6 +19,7 @@ DOCUMENT_SETS = "document_sets"
|
||||
TIME_FILTER = "time_filter"
|
||||
METADATA = "metadata"
|
||||
METADATA_LIST = "metadata_list"
|
||||
METADATA_SUFFIX = "metadata_suffix"
|
||||
MATCH_HIGHLIGHTS = "match_highlights"
|
||||
# stored in the `metadata` of a chunk. Used to signify that this chunk should
|
||||
# not be used for QA. For example, Google Drive file types which can't be parsed
|
||||
@@ -41,17 +42,15 @@ DEFAULT_BOOST = 0
|
||||
SESSION_KEY = "session"
|
||||
QUERY_EVENT_ID = "query_event_id"
|
||||
LLM_CHUNKS = "llm_chunks"
|
||||
TOKEN_BUDGET = "token_budget"
|
||||
TOKEN_BUDGET_TIME_PERIOD = "token_budget_time_period"
|
||||
ENABLE_TOKEN_BUDGET = "enable_token_budget"
|
||||
TOKEN_BUDGET_SETTINGS = "token_budget_settings"
|
||||
|
||||
# For chunking/processing chunks
|
||||
TITLE_SEPARATOR = "\n\r\n"
|
||||
RETURN_SEPARATOR = "\n\r\n"
|
||||
SECTION_SEPARATOR = "\n\n"
|
||||
# For combining attributes, doesn't have to be unique/perfect to work
|
||||
INDEX_SEPARATOR = "==="
|
||||
|
||||
# For File Connector Metadata override file
|
||||
DANSWER_METADATA_FILENAME = ".danswer_metadata.json"
|
||||
|
||||
# Messages
|
||||
DISABLED_GEN_AI_MSG = (
|
||||
@@ -60,6 +59,14 @@ DISABLED_GEN_AI_MSG = (
|
||||
"You can still use Danswer as a search engine."
|
||||
)
|
||||
|
||||
# Postgres connection constants for application_name
|
||||
POSTGRES_WEB_APP_NAME = "web"
|
||||
POSTGRES_INDEXER_APP_NAME = "indexer"
|
||||
POSTGRES_CELERY_APP_NAME = "celery"
|
||||
POSTGRES_CELERY_BEAT_APP_NAME = "celery_beat"
|
||||
POSTGRES_CELERY_WORKER_APP_NAME = "celery_worker"
|
||||
POSTGRES_PERMISSIONS_APP_NAME = "permissions"
|
||||
POSTGRES_UNKNOWN_APP_NAME = "unknown"
|
||||
|
||||
# API Keys
|
||||
DANSWER_API_KEY_PREFIX = "API_KEY__"
|
||||
@@ -93,9 +100,30 @@ class DocumentSource(str, Enum):
|
||||
GOOGLE_SITES = "google_sites"
|
||||
ZENDESK = "zendesk"
|
||||
LOOPIO = "loopio"
|
||||
DROPBOX = "dropbox"
|
||||
SHAREPOINT = "sharepoint"
|
||||
TEAMS = "teams"
|
||||
SALESFORCE = "salesforce"
|
||||
DISCOURSE = "discourse"
|
||||
AXERO = "axero"
|
||||
CLICKUP = "clickup"
|
||||
MEDIAWIKI = "mediawiki"
|
||||
WIKIPEDIA = "wikipedia"
|
||||
S3 = "s3"
|
||||
R2 = "r2"
|
||||
GOOGLE_CLOUD_STORAGE = "google_cloud_storage"
|
||||
OCI_STORAGE = "oci_storage"
|
||||
NOT_APPLICABLE = "not_applicable"
|
||||
|
||||
|
||||
class BlobType(str, Enum):
|
||||
R2 = "r2"
|
||||
S3 = "s3"
|
||||
GOOGLE_CLOUD_STORAGE = "google_cloud_storage"
|
||||
OCI_STORAGE = "oci_storage"
|
||||
|
||||
# Special case, for internet search
|
||||
NOT_APPLICABLE = "not_applicable"
|
||||
|
||||
|
||||
class DocumentIndexType(str, Enum):
|
||||
@@ -111,6 +139,11 @@ class AuthType(str, Enum):
|
||||
SAML = "saml"
|
||||
|
||||
|
||||
class QAFeedbackType(str, Enum):
|
||||
LIKE = "like" # User likes the answer, used for metrics
|
||||
DISLIKE = "dislike" # User dislikes the answer, used for metrics
|
||||
|
||||
|
||||
class SearchFeedbackType(str, Enum):
|
||||
ENDORSE = "endorse" # boost this document for all future queries
|
||||
REJECT = "reject" # down-boost this document for all future queries
|
||||
@@ -130,3 +163,11 @@ class TokenRateLimitScope(str, Enum):
|
||||
USER = "user"
|
||||
USER_GROUP = "user_group"
|
||||
GLOBAL = "global"
|
||||
|
||||
|
||||
class FileOrigin(str, Enum):
|
||||
CHAT_UPLOAD = "chat_upload"
|
||||
CHAT_IMAGE_GEN = "chat_image_gen"
|
||||
CONNECTOR = "connector"
|
||||
GENERATED_REPORT = "generated_report"
|
||||
OTHER = "other"
|
||||
|
||||
@@ -47,10 +47,6 @@ DANSWER_BOT_DISPLAY_ERROR_MSGS = os.environ.get(
|
||||
DANSWER_BOT_RESPOND_EVERY_CHANNEL = (
|
||||
os.environ.get("DANSWER_BOT_RESPOND_EVERY_CHANNEL", "").lower() == "true"
|
||||
)
|
||||
# Auto detect query options like time cutoff or heavily favor recently updated docs
|
||||
DISABLE_DANSWER_BOT_FILTER_DETECT = (
|
||||
os.environ.get("DISABLE_DANSWER_BOT_FILTER_DETECT", "").lower() == "true"
|
||||
)
|
||||
# Add a second LLM call post Answer to verify if the Answer is valid
|
||||
# Throws out answers that don't directly or fully answer the user query
|
||||
# This is the default for all DanswerBot channels unless the channel is configured individually
|
||||
@@ -73,3 +69,7 @@ DANSWER_BOT_MAX_WAIT_TIME = int(os.environ.get("DANSWER_BOT_MAX_WAIT_TIME") or 1
|
||||
DANSWER_BOT_FEEDBACK_REMINDER = int(
|
||||
os.environ.get("DANSWER_BOT_FEEDBACK_REMINDER") or 0
|
||||
)
|
||||
# Set to True to rephrase the Slack user's messages
|
||||
DANSWER_BOT_REPHRASE_MESSAGE = (
|
||||
os.environ.get("DANSWER_BOT_REPHRASE_MESSAGE", "").lower() == "true"
|
||||
)
|
||||
|
||||
@@ -12,7 +12,7 @@ import os
|
||||
# The usable models configured as below must be SentenceTransformer compatible
|
||||
# NOTE: DO NOT CHANGE OR SET THESE UNLESS YOU KNOW WHAT YOU ARE DOING
|
||||
# IDEALLY, YOU SHOULD CHANGE EMBEDDING MODELS VIA THE UI
|
||||
DEFAULT_DOCUMENT_ENCODER_MODEL = "intfloat/e5-base-v2"
|
||||
DEFAULT_DOCUMENT_ENCODER_MODEL = "nomic-ai/nomic-embed-text-v1"
|
||||
DOCUMENT_ENCODER_MODEL = (
|
||||
os.environ.get("DOCUMENT_ENCODER_MODEL") or DEFAULT_DOCUMENT_ENCODER_MODEL
|
||||
)
|
||||
@@ -34,13 +34,13 @@ OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS = False
|
||||
SIM_SCORE_RANGE_LOW = float(os.environ.get("SIM_SCORE_RANGE_LOW") or 0.0)
|
||||
SIM_SCORE_RANGE_HIGH = float(os.environ.get("SIM_SCORE_RANGE_HIGH") or 1.0)
|
||||
# Certain models like e5, BGE, etc use a prefix for asymmetric retrievals (query generally shorter than docs)
|
||||
ASYM_QUERY_PREFIX = os.environ.get("ASYM_QUERY_PREFIX", "query: ")
|
||||
ASYM_PASSAGE_PREFIX = os.environ.get("ASYM_PASSAGE_PREFIX", "passage: ")
|
||||
ASYM_QUERY_PREFIX = os.environ.get("ASYM_QUERY_PREFIX", "search_query: ")
|
||||
ASYM_PASSAGE_PREFIX = os.environ.get("ASYM_PASSAGE_PREFIX", "search_document: ")
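These prefixes matter because the new default model is trained with asymmetric task prefixes; roughly how they are applied at encode time (a sketch against the SentenceTransformers API; exact loading arguments for the nomic model may differ):

from sentence_transformers import SentenceTransformer

query_prefix = "search_query: "
passage_prefix = "search_document: "

model = SentenceTransformer("nomic-ai/nomic-embed-text-v1", trust_remote_code=True)

# Queries and passages get different prefixes so the model treats them asymmetrically
query_vec = model.encode(query_prefix + "how do I rotate my API key?")
passage_vecs = model.encode(
    [passage_prefix + p for p in ["API keys can be rotated from the admin panel.", "Contact support to revoke a key."]]
)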
|
||||
# Purely an optimization, memory limitation consideration
|
||||
BATCH_SIZE_ENCODE_CHUNKS = 8
|
||||
# For score display purposes, only way is to know the expected ranges
|
||||
CROSS_ENCODER_RANGE_MAX = 12
|
||||
CROSS_ENCODER_RANGE_MIN = -12
|
||||
CROSS_ENCODER_RANGE_MAX = 1
|
||||
CROSS_ENCODER_RANGE_MIN = 0
|
||||
|
||||
# Unused currently, can't be used with the current default encoder model due to its output range
|
||||
SEARCH_DISTANCE_CUTOFF = 0
|
||||
@@ -100,7 +100,7 @@ DISABLE_LITELLM_STREAMING = (
|
||||
).lower() == "true"
|
||||
|
||||
# extra headers to pass to LiteLLM
|
||||
LITELLM_EXTRA_HEADERS = None
|
||||
LITELLM_EXTRA_HEADERS: dict[str, str] | None = None
|
||||
_LITELLM_EXTRA_HEADERS_RAW = os.environ.get("LITELLM_EXTRA_HEADERS")
|
||||
if _LITELLM_EXTRA_HEADERS_RAW:
|
||||
try:
|
||||
@@ -113,3 +113,18 @@ if _LITELLM_EXTRA_HEADERS_RAW:
|
||||
logger.error(
|
||||
"Failed to parse LITELLM_EXTRA_HEADERS, must be a valid JSON object"
|
||||
)
|
||||
|
||||
# if specified, will pass through request headers to the call to the LLM
|
||||
LITELLM_PASS_THROUGH_HEADERS: list[str] | None = None
|
||||
_LITELLM_PASS_THROUGH_HEADERS_RAW = os.environ.get("LITELLM_PASS_THROUGH_HEADERS")
|
||||
if _LITELLM_PASS_THROUGH_HEADERS_RAW:
|
||||
try:
|
||||
LITELLM_PASS_THROUGH_HEADERS = json.loads(_LITELLM_PASS_THROUGH_HEADERS_RAW)
|
||||
except Exception:
|
||||
# need to import here to avoid circular imports
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
logger.error(
|
||||
"Failed to parse LITELLM_PASS_THROUGH_HEADERS, must be a valid JSON object"
|
||||
)
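Taken together with LITELLM_EXTRA_HEADERS above, the intent is roughly the following (a hypothetical helper; the actual plumbing lives in the chat backend):

def build_llm_headers(
    request_headers: dict[str, str],
    pass_through: list[str] | None,
    extra_headers: dict[str, str] | None,
) -> dict[str, str]:
    # Static extra headers first, then copy over only the whitelisted request headers
    headers: dict[str, str] = dict(extra_headers or {})
    for name in pass_through or []:
        if name in request_headers:
            headers[name] = request_headers[name]
    return headers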
|
||||
|
||||
backend/danswer/connectors/blob/__init__.py (new file, 0 lines)
backend/danswer/connectors/blob/connector.py (new file, 277 lines)
@@ -0,0 +1,277 @@
|
||||
import os
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from io import BytesIO
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
|
||||
import boto3
|
||||
from botocore.client import Config
|
||||
from mypy_boto3_s3 import S3Client
|
||||
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import BlobType
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from danswer.connectors.models import ConnectorMissingCredentialError
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.file_processing.extract_file_text import extract_file_text
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class BlobStorageConnector(LoadConnector, PollConnector):
|
||||
def __init__(
|
||||
self,
|
||||
bucket_type: str,
|
||||
bucket_name: str,
|
||||
prefix: str = "",
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
) -> None:
|
||||
self.bucket_type: BlobType = BlobType(bucket_type)
|
||||
self.bucket_name = bucket_name
|
||||
self.prefix = prefix if not prefix or prefix.endswith("/") else prefix + "/"
|
||||
self.batch_size = batch_size
|
||||
self.s3_client: Optional[S3Client] = None
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""Checks for boto3 credentials based on the bucket type.
|
||||
(1) R2: Access Key ID, Secret Access Key, Account ID
|
||||
(2) S3: AWS Access Key ID, AWS Secret Access Key
|
||||
(3) GOOGLE_CLOUD_STORAGE: Access Key ID, Secret Access Key, Project ID
|
||||
(4) OCI_STORAGE: Namespace, Region, Access Key ID, Secret Access Key
|
||||
|
||||
For each bucket type, the method initializes the appropriate S3 client:
|
||||
- R2: Uses Cloudflare R2 endpoint with S3v4 signature
|
||||
- S3: Creates a standard boto3 S3 client
|
||||
- GOOGLE_CLOUD_STORAGE: Uses Google Cloud Storage endpoint
|
||||
- OCI_STORAGE: Uses Oracle Cloud Infrastructure Object Storage endpoint
|
||||
|
||||
Raises ConnectorMissingCredentialError if required credentials are missing.
|
||||
Raises ValueError for unsupported bucket types.
|
||||
"""
|
||||
|
||||
logger.info(
|
||||
f"Loading credentials for {self.bucket_name} or type {self.bucket_type}"
|
||||
)
|
||||
|
||||
if self.bucket_type == BlobType.R2:
|
||||
if not all(
|
||||
credentials.get(key)
|
||||
for key in ["r2_access_key_id", "r2_secret_access_key", "account_id"]
|
||||
):
|
||||
raise ConnectorMissingCredentialError("Cloudflare R2")
|
||||
self.s3_client = boto3.client(
|
||||
"s3",
|
||||
endpoint_url=f"https://{credentials['account_id']}.r2.cloudflarestorage.com",
|
||||
aws_access_key_id=credentials["r2_access_key_id"],
|
||||
aws_secret_access_key=credentials["r2_secret_access_key"],
|
||||
region_name="auto",
|
||||
config=Config(signature_version="s3v4"),
|
||||
)
|
||||
|
||||
elif self.bucket_type == BlobType.S3:
|
||||
if not all(
|
||||
credentials.get(key)
|
||||
for key in ["aws_access_key_id", "aws_secret_access_key"]
|
||||
):
|
||||
raise ConnectorMissingCredentialError("Google Cloud Storage")
|
||||
|
||||
session = boto3.Session(
|
||||
aws_access_key_id=credentials["aws_access_key_id"],
|
||||
aws_secret_access_key=credentials["aws_secret_access_key"],
|
||||
)
|
||||
self.s3_client = session.client("s3")
|
||||
|
||||
elif self.bucket_type == BlobType.GOOGLE_CLOUD_STORAGE:
|
||||
if not all(
|
||||
credentials.get(key) for key in ["access_key_id", "secret_access_key"]
|
||||
):
|
||||
raise ConnectorMissingCredentialError("Google Cloud Storage")
|
||||
|
||||
self.s3_client = boto3.client(
|
||||
"s3",
|
||||
endpoint_url="https://storage.googleapis.com",
|
||||
aws_access_key_id=credentials["access_key_id"],
|
||||
aws_secret_access_key=credentials["secret_access_key"],
|
||||
region_name="auto",
|
||||
)
|
||||
|
||||
elif self.bucket_type == BlobType.OCI_STORAGE:
|
||||
if not all(
|
||||
credentials.get(key)
|
||||
for key in ["namespace", "region", "access_key_id", "secret_access_key"]
|
||||
):
|
||||
raise ConnectorMissingCredentialError("Oracle Cloud Infrastructure")
|
||||
|
||||
self.s3_client = boto3.client(
|
||||
"s3",
|
||||
endpoint_url=f"https://{credentials['namespace']}.compat.objectstorage.{credentials['region']}.oraclecloud.com",
|
||||
aws_access_key_id=credentials["access_key_id"],
|
||||
aws_secret_access_key=credentials["secret_access_key"],
|
||||
region_name=credentials["region"],
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported bucket type: {self.bucket_type}")
|
||||
|
||||
return None
|
||||
|
||||
def _download_object(self, key: str) -> bytes:
|
||||
if self.s3_client is None:
|
||||
raise ConnectorMissingCredentialError("Blob storage")
|
||||
object = self.s3_client.get_object(Bucket=self.bucket_name, Key=key)
|
||||
return object["Body"].read()
|
||||
|
||||
# NOTE: Left in as it may be useful for one-off access to documents and sharing across orgs.
|
||||
# def _get_presigned_url(self, key: str) -> str:
|
||||
# if self.s3_client is None:
|
||||
# raise ConnectorMissingCredentialError("Blob storage")
|
||||
|
||||
# url = self.s3_client.generate_presigned_url(
|
||||
# "get_object",
|
||||
# Params={"Bucket": self.bucket_name, "Key": key},
|
||||
# ExpiresIn=self.presign_length,
|
||||
# )
|
||||
# return url
|
||||
|
||||
def _get_blob_link(self, key: str) -> str:
|
||||
if self.s3_client is None:
|
||||
raise ConnectorMissingCredentialError("Blob storage")
|
||||
|
||||
if self.bucket_type == BlobType.R2:
|
||||
account_id = self.s3_client.meta.endpoint_url.split("//")[1].split(".")[0]
|
||||
return f"https://{account_id}.r2.cloudflarestorage.com/{self.bucket_name}/{key}"
|
||||
|
||||
elif self.bucket_type == BlobType.S3:
|
||||
region = self.s3_client.meta.region_name
|
||||
return f"https://{self.bucket_name}.s3.{region}.amazonaws.com/{key}"
|
||||
|
||||
elif self.bucket_type == BlobType.GOOGLE_CLOUD_STORAGE:
|
||||
return f"https://storage.cloud.google.com/{self.bucket_name}/{key}"
|
||||
|
||||
elif self.bucket_type == BlobType.OCI_STORAGE:
|
||||
namespace = self.s3_client.meta.endpoint_url.split("//")[1].split(".")[0]
|
||||
region = self.s3_client.meta.region_name
|
||||
return f"https://objectstorage.{region}.oraclecloud.com/n/{namespace}/b/{self.bucket_name}/o/{key}"
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported bucket type: {self.bucket_type}")
|
||||
|
||||
def _yield_blob_objects(
|
||||
self,
|
||||
start: datetime,
|
||||
end: datetime,
|
||||
) -> GenerateDocumentsOutput:
|
||||
if self.s3_client is None:
|
||||
raise ConnectorMissingCredentialError("Blog storage")
|
||||
|
||||
paginator = self.s3_client.get_paginator("list_objects_v2")
|
||||
pages = paginator.paginate(Bucket=self.bucket_name, Prefix=self.prefix)
|
||||
|
||||
batch: list[Document] = []
|
||||
for page in pages:
|
||||
if "Contents" not in page:
|
||||
continue
|
||||
|
||||
for obj in page["Contents"]:
|
||||
if obj["Key"].endswith("/"):
|
||||
continue
|
||||
|
||||
last_modified = obj["LastModified"].replace(tzinfo=timezone.utc)
|
||||
|
||||
if not start <= last_modified <= end:
|
||||
continue
|
||||
|
||||
downloaded_file = self._download_object(obj["Key"])
|
||||
link = self._get_blob_link(obj["Key"])
|
||||
name = os.path.basename(obj["Key"])
|
||||
|
||||
try:
|
||||
text = extract_file_text(
|
||||
name,
|
||||
BytesIO(downloaded_file),
|
||||
break_on_unprocessable=False,
|
||||
)
|
||||
batch.append(
|
||||
Document(
|
||||
id=f"{self.bucket_type}:{self.bucket_name}:{obj['Key']}",
|
||||
sections=[Section(link=link, text=text)],
|
||||
source=DocumentSource(self.bucket_type.value),
|
||||
semantic_identifier=name,
|
||||
doc_updated_at=last_modified,
|
||||
metadata={},
|
||||
)
|
||||
)
|
||||
if len(batch) == self.batch_size:
|
||||
yield batch
|
||||
batch = []
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Error decoding object {obj['Key']} as UTF-8: {e}"
|
||||
)
|
||||
if batch:
|
||||
yield batch
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
logger.info("Loading blob objects")
|
||||
return self._yield_blob_objects(
|
||||
start=datetime(1970, 1, 1, tzinfo=timezone.utc),
|
||||
end=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
if self.s3_client is None:
|
||||
raise ConnectorMissingCredentialError("Blog storage")
|
||||
|
||||
start_datetime = datetime.fromtimestamp(start, tz=timezone.utc)
|
||||
end_datetime = datetime.fromtimestamp(end, tz=timezone.utc)
|
||||
|
||||
for batch in self._yield_blob_objects(start_datetime, end_datetime):
|
||||
yield batch
|
||||
|
||||
return None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
credentials_dict = {
|
||||
"aws_access_key_id": os.environ.get("AWS_ACCESS_KEY_ID"),
|
||||
"aws_secret_access_key": os.environ.get("AWS_SECRET_ACCESS_KEY"),
|
||||
}
|
||||
|
||||
# Initialize the connector
|
||||
connector = BlobStorageConnector(
|
||||
bucket_type=os.environ.get("BUCKET_TYPE") or "s3",
|
||||
bucket_name=os.environ.get("BUCKET_NAME") or "test",
|
||||
prefix="",
|
||||
)
|
||||
|
||||
try:
|
||||
connector.load_credentials(credentials_dict)
|
||||
document_batch_generator = connector.load_from_state()
|
||||
for document_batch in document_batch_generator:
|
||||
print("First batch of documents:")
|
||||
for doc in document_batch:
|
||||
print(f"Document ID: {doc.id}")
|
||||
print(f"Semantic Identifier: {doc.semantic_identifier}")
|
||||
print(f"Source: {doc.source}")
|
||||
print(f"Updated At: {doc.doc_updated_at}")
|
||||
print("Sections:")
|
||||
for section in doc.sections:
|
||||
print(f" - Link: {section.link}")
|
||||
print(f" - Text: {section.text[:100]}...")
|
||||
print("---")
|
||||
break
|
||||
|
||||
except ConnectorMissingCredentialError as e:
|
||||
print(f"Error: {e}")
|
||||
except Exception as e:
|
||||
print(f"An unexpected error occurred: {e}")
|
||||
backend/danswer/connectors/clickup/__init__.py (new file, 0 lines)
backend/danswer/connectors/clickup/connector.py (new file, 216 lines)
@@ -0,0 +1,216 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||
rate_limit_builder,
|
||||
)
|
||||
from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from danswer.connectors.models import BasicExpertInfo
|
||||
from danswer.connectors.models import ConnectorMissingCredentialError
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
|
||||
|
||||
CLICKUP_API_BASE_URL = "https://api.clickup.com/api/v2"
|
||||
|
||||
|
||||
class ClickupConnector(LoadConnector, PollConnector):
|
||||
def __init__(
|
||||
self,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
api_token: str | None = None,
|
||||
team_id: str | None = None,
|
||||
connector_type: str | None = None,
|
||||
connector_ids: list[str] | None = None,
|
||||
retrieve_task_comments: bool = True,
|
||||
) -> None:
|
||||
self.batch_size = batch_size
|
||||
self.api_token = api_token
|
||||
self.team_id = team_id
|
||||
self.connector_type = connector_type if connector_type else "workspace"
|
||||
self.connector_ids = connector_ids
|
||||
self.retrieve_task_comments = retrieve_task_comments
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
self.api_token = credentials["clickup_api_token"]
|
||||
self.team_id = credentials["clickup_team_id"]
|
||||
return None
|
||||
|
||||
@retry_builder()
|
||||
@rate_limit_builder(max_calls=100, period=60)
|
||||
def _make_request(self, endpoint: str, params: Optional[dict] = None) -> Any:
|
||||
if not self.api_token:
|
||||
raise ConnectorMissingCredentialError("Clickup")
|
||||
|
||||
headers = {"Authorization": self.api_token}
|
||||
|
||||
response = requests.get(
|
||||
f"{CLICKUP_API_BASE_URL}/{endpoint}", headers=headers, params=params
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
return response.json()
|
||||
|
||||
def _get_task_comments(self, task_id: str) -> list[Section]:
|
||||
url_endpoint = f"/task/{task_id}/comment"
|
||||
response = self._make_request(url_endpoint)
|
||||
comments = [
|
||||
Section(
|
||||
link=f'https://app.clickup.com/t/{task_id}?comment={comment_dict["id"]}',
|
||||
text=comment_dict["comment_text"],
|
||||
)
|
||||
for comment_dict in response["comments"]
|
||||
]
|
||||
|
||||
return comments
|
||||
|
||||
def _get_all_tasks_filtered(
|
||||
self,
|
||||
start: int | None = None,
|
||||
end: int | None = None,
|
||||
) -> GenerateDocumentsOutput:
|
||||
doc_batch: list[Document] = []
|
||||
page: int = 0
|
||||
params = {
|
||||
"include_markdown_description": "true",
|
||||
"include_closed": "true",
|
||||
"page": page,
|
||||
}
|
||||
|
||||
if start is not None:
|
||||
params["date_updated_gt"] = start
|
||||
if end is not None:
|
||||
params["date_updated_lt"] = end
|
||||
|
||||
if self.connector_type == "list":
|
||||
params["list_ids[]"] = self.connector_ids
|
||||
elif self.connector_type == "folder":
|
||||
params["project_ids[]"] = self.connector_ids
|
||||
elif self.connector_type == "space":
|
||||
params["space_ids[]"] = self.connector_ids
|
||||
|
||||
url_endpoint = f"/team/{self.team_id}/task"
|
||||
|
||||
while True:
|
||||
response = self._make_request(url_endpoint, params)
|
||||
|
||||
page += 1
|
||||
params["page"] = page
|
||||
|
||||
for task in response["tasks"]:
|
||||
document = Document(
|
||||
id=task["id"],
|
||||
source=DocumentSource.CLICKUP,
|
||||
semantic_identifier=task["name"],
|
||||
doc_updated_at=(
|
||||
datetime.fromtimestamp(
|
||||
round(float(task["date_updated"]) / 1000, 3)
|
||||
).replace(tzinfo=timezone.utc)
|
||||
),
|
||||
primary_owners=[
|
||||
BasicExpertInfo(
|
||||
display_name=task["creator"]["username"],
|
||||
email=task["creator"]["email"],
|
||||
)
|
||||
],
|
||||
secondary_owners=[
|
||||
BasicExpertInfo(
|
||||
display_name=assignee["username"],
|
||||
email=assignee["email"],
|
||||
)
|
||||
for assignee in task["assignees"]
|
||||
],
|
||||
title=task["name"],
|
||||
sections=[
|
||||
Section(
|
||||
link=task["url"],
|
||||
text=(
|
||||
task["markdown_description"]
|
||||
if "markdown_description" in task
|
||||
else task["description"]
|
||||
),
|
||||
)
|
||||
],
|
||||
metadata={
|
||||
"id": task["id"],
|
||||
"status": task["status"]["status"],
|
||||
"list": task["list"]["name"],
|
||||
"project": task["project"]["name"],
|
||||
"folder": task["folder"]["name"],
|
||||
"space_id": task["space"]["id"],
|
||||
"tags": [tag["name"] for tag in task["tags"]],
|
||||
"priority": (
|
||||
task["priority"]["priority"]
|
||||
if "priority" in task and task["priority"] is not None
|
||||
else ""
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
extra_fields = [
|
||||
"date_created",
|
||||
"date_updated",
|
||||
"date_closed",
|
||||
"date_done",
|
||||
"due_date",
|
||||
]
|
||||
for extra_field in extra_fields:
|
||||
if extra_field in task and task[extra_field] is not None:
|
||||
document.metadata[extra_field] = task[extra_field]
|
||||
|
||||
if self.retrieve_task_comments:
|
||||
document.sections.extend(self._get_task_comments(task["id"]))
|
||||
|
||||
doc_batch.append(document)
|
||||
|
||||
if len(doc_batch) >= self.batch_size:
|
||||
yield doc_batch
|
||||
doc_batch = []
|
||||
|
||||
if response.get("last_page") is True or len(response["tasks"]) < 100:
|
||||
break
|
||||
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
if self.api_token is None:
|
||||
raise ConnectorMissingCredentialError("Clickup")
|
||||
|
||||
return self._get_all_tasks_filtered(None, None)
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
if self.api_token is None:
|
||||
raise ConnectorMissingCredentialError("Clickup")
|
||||
|
||||
return self._get_all_tasks_filtered(int(start * 1000), int(end * 1000))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
|
||||
clickup_connector = ClickupConnector()
|
||||
|
||||
clickup_connector.load_credentials(
|
||||
{
|
||||
"clickup_api_token": os.environ["clickup_api_token"],
|
||||
"clickup_team_id": os.environ["clickup_team_id"],
|
||||
}
|
||||
)
|
||||
latest_docs = clickup_connector.load_from_state()
|
||||
|
||||
for doc in latest_docs:
|
||||
print(doc)
|
||||
@@ -1,3 +1,5 @@
|
||||
import io
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Collection
|
||||
from datetime import datetime
|
||||
@@ -13,6 +15,7 @@ from requests import HTTPError
|
||||
|
||||
from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_INDEX_ONLY_ACTIVE_PAGES
|
||||
from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_LABELS_TO_SKIP
|
||||
from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING
|
||||
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
@@ -27,22 +30,25 @@ from danswer.connectors.models import BasicExpertInfo
|
||||
from danswer.connectors.models import ConnectorMissingCredentialError
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.file_processing.extract_file_text import extract_file_text
|
||||
from danswer.file_processing.html_utils import format_document_soup
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Potential Improvements
|
||||
# 1. If wiki page instead of space, do a search of all the children of the page instead of index all in the space
|
||||
# 2. Include attachments, etc
|
||||
# 3. Segment into Sections for more accurate linking, can split by headers but make sure no text/ordering is lost
|
||||
# 1. Include attachments, etc
|
||||
# 2. Segment into Sections for more accurate linking, can split by headers but make sure no text/ordering is lost
|
||||
|
||||
|
||||
def _extract_confluence_keys_from_cloud_url(wiki_url: str) -> tuple[str, str]:
|
||||
def _extract_confluence_keys_from_cloud_url(wiki_url: str) -> tuple[str, str, str]:
|
||||
"""Sample
|
||||
https://danswer.atlassian.net/wiki/spaces/1234abcd/overview
|
||||
URL w/ page: https://danswer.atlassian.net/wiki/spaces/1234abcd/pages/5678efgh/overview
|
||||
URL w/o page: https://danswer.atlassian.net/wiki/spaces/ASAM/overview
|
||||
|
||||
wiki_base is https://danswer.atlassian.net/wiki
|
||||
space is 1234abcd
|
||||
page_id is 5678efgh
|
||||
"""
|
||||
parsed_url = urlparse(wiki_url)
|
||||
wiki_base = (
|
||||
@@ -51,18 +57,25 @@ def _extract_confluence_keys_from_cloud_url(wiki_url: str) -> tuple[str, str]:
|
||||
+ parsed_url.netloc
|
||||
+ parsed_url.path.split("/spaces")[0]
|
||||
)
|
||||
space = parsed_url.path.split("/")[3]
|
||||
return wiki_base, space
|
||||
|
||||
path_parts = parsed_url.path.split("/")
|
||||
space = path_parts[3]
|
||||
|
||||
page_id = path_parts[5] if len(path_parts) > 5 else ""
|
||||
return wiki_base, space, page_id
|
||||
|
||||
|
||||
def _extract_confluence_keys_from_datacenter_url(wiki_url: str) -> tuple[str, str]:
|
||||
def _extract_confluence_keys_from_datacenter_url(wiki_url: str) -> tuple[str, str, str]:
|
||||
"""Sample
|
||||
https://danswer.ai/confluence/display/1234abcd/overview
|
||||
URL w/ page https://danswer.ai/confluence/display/1234abcd/pages/5678efgh/overview
|
||||
URL w/o page https://danswer.ai/confluence/display/1234abcd/overview
|
||||
wiki_base is https://danswer.ai/confluence
|
||||
space is 1234abcd
|
||||
page_id is 5678efgh
|
||||
"""
|
||||
# /display/ is always right before the space and at the end of the base url
|
||||
# /display/ is always right before the space and at the end of the base url
|
||||
DISPLAY = "/display/"
|
||||
PAGE = "/pages/"
|
||||
|
||||
parsed_url = urlparse(wiki_url)
|
||||
wiki_base = (
|
||||
@@ -72,10 +85,13 @@ def _extract_confluence_keys_from_datacenter_url(wiki_url: str) -> tuple[str, st
|
||||
+ parsed_url.path.split(DISPLAY)[0]
|
||||
)
|
||||
space = DISPLAY.join(parsed_url.path.split(DISPLAY)[1:]).split("/")[0]
|
||||
return wiki_base, space
|
||||
page_id = ""
|
||||
if (content := parsed_url.path.split(PAGE)) and len(content) > 1:
|
||||
page_id = content[1]
|
||||
return wiki_base, space, page_id
|
||||
|
||||
|
||||
def extract_confluence_keys_from_url(wiki_url: str) -> tuple[str, str, bool]:
|
||||
def extract_confluence_keys_from_url(wiki_url: str) -> tuple[str, str, str, bool]:
|
||||
is_confluence_cloud = (
|
||||
".atlassian.net/wiki/spaces/" in wiki_url
|
||||
or ".jira.com/wiki/spaces/" in wiki_url
|
||||
@@ -83,15 +99,19 @@ def extract_confluence_keys_from_url(wiki_url: str) -> tuple[str, str, bool]:
|
||||
|
||||
try:
|
||||
if is_confluence_cloud:
|
||||
wiki_base, space = _extract_confluence_keys_from_cloud_url(wiki_url)
|
||||
wiki_base, space, page_id = _extract_confluence_keys_from_cloud_url(
|
||||
wiki_url
|
||||
)
|
||||
else:
|
||||
wiki_base, space = _extract_confluence_keys_from_datacenter_url(wiki_url)
|
||||
wiki_base, space, page_id = _extract_confluence_keys_from_datacenter_url(
|
||||
wiki_url
|
||||
)
|
||||
except Exception as e:
|
||||
error_msg = f"Not a valid Confluence Wiki Link, unable to extract wiki base and space names. Exception: {e}"
|
||||
error_msg = f"Not a valid Confluence Wiki Link, unable to extract wiki base, space, and page id. Exception: {e}"
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
return wiki_base, space, is_confluence_cloud
|
||||
return wiki_base, space, page_id, is_confluence_cloud
|
||||
|
||||
|
||||
@lru_cache()
|
||||
@@ -147,6 +167,24 @@ def parse_html_page(text: str, confluence_client: Confluence) -> str:
|
||||
return format_document_soup(soup)
|
||||
|
||||
|
||||
def get_used_attachments(text: str, confluence_client: Confluence) -> list[str]:
|
||||
"""Parse a Confluence html page to generate a list of current
|
||||
attachment in used
|
||||
|
||||
Args:
|
||||
text (str): The page content
|
||||
confluence_client (Confluence): Confluence client
|
||||
|
||||
Returns:
|
||||
list[str]: List of filenames currently in use
|
||||
"""
|
||||
files_in_used = []
|
||||
soup = bs4.BeautifulSoup(text, "html.parser")
|
||||
for attachment in soup.findAll("ri:attachment"):
|
||||
files_in_used.append(attachment.attrs["ri:filename"])
|
||||
return files_in_used
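For example, a standalone equivalent of the parsing above (illustrative; the ri:attachment markup is Confluence storage format):

import bs4

sample_html = '<p><ri:attachment ri:filename="design_doc.pdf" /> see attachment</p>'
soup = bs4.BeautifulSoup(sample_html, "html.parser")
print([tag.attrs["ri:filename"] for tag in soup.findAll("ri:attachment")])  # ['design_doc.pdf']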
|
||||
|
||||
|
||||
def _comment_dfs(
|
||||
comments_str: str,
|
||||
comment_pages: Collection[dict[str, Any]],
|
||||
@@ -174,10 +212,137 @@ def _comment_dfs(
|
||||
return comments_str
|
||||
|
||||
|
||||
class RecursiveIndexer:
|
||||
def __init__(
|
||||
self,
|
||||
batch_size: int,
|
||||
confluence_client: Confluence,
|
||||
index_recursively: bool,
|
||||
origin_page_id: str,
|
||||
) -> None:
|
||||
self.batch_size = 1  # NOTE: the passed-in batch_size is intentionally not used here; pages are fetched one at a time
|
||||
self.confluence_client = confluence_client
|
||||
self.index_recursively = index_recursively
|
||||
self.origin_page_id = origin_page_id
|
||||
self.pages = self.recurse_children_pages(0, self.origin_page_id)
|
||||
|
||||
def get_origin_page(self) -> list[dict[str, Any]]:
|
||||
return [self._fetch_origin_page()]
|
||||
|
||||
def get_pages(self, ind: int, size: int) -> list[dict]:
|
||||
if ind * size > len(self.pages):
|
||||
return []
|
||||
return self.pages[ind * size : (ind + 1) * size]
|
||||
|
||||
def _fetch_origin_page(
|
||||
self,
|
||||
) -> dict[str, Any]:
|
||||
get_page_by_id = make_confluence_call_handle_rate_limit(
|
||||
self.confluence_client.get_page_by_id
|
||||
)
|
||||
try:
|
||||
origin_page = get_page_by_id(
|
||||
self.origin_page_id, expand="body.storage.value,version"
|
||||
)
|
||||
return origin_page
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Appending orgin page with id {self.origin_page_id} failed: {e}"
|
||||
)
|
||||
return {}
|
||||
|
||||
def recurse_children_pages(
|
||||
self,
|
||||
start_ind: int,
|
||||
page_id: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
pages: list[dict[str, Any]] = []
|
||||
current_level_pages: list[dict[str, Any]] = []
|
||||
next_level_pages: list[dict[str, Any]] = []
|
||||
|
||||
# Initial fetch of first level children
|
||||
index = start_ind
|
||||
while batch := self._fetch_single_depth_child_pages(
|
||||
index, self.batch_size, page_id
|
||||
):
|
||||
current_level_pages.extend(batch)
|
||||
index += len(batch)
|
||||
|
||||
pages.extend(current_level_pages)
|
||||
|
||||
# Recursively index children and children's children, etc.
|
||||
while current_level_pages:
|
||||
for child in current_level_pages:
|
||||
child_index = 0
|
||||
while child_batch := self._fetch_single_depth_child_pages(
|
||||
child_index, self.batch_size, child["id"]
|
||||
):
|
||||
next_level_pages.extend(child_batch)
|
||||
child_index += len(child_batch)
|
||||
|
||||
pages.extend(next_level_pages)
|
||||
current_level_pages = next_level_pages
|
||||
next_level_pages = []
|
||||
|
||||
try:
|
||||
origin_page = self._fetch_origin_page()
|
||||
pages.append(origin_page)
|
||||
except Exception as e:
|
||||
logger.warning(f"Appending origin page with id {page_id} failed: {e}")
|
||||
|
||||
return pages
|
||||
|
||||
def _fetch_single_depth_child_pages(
|
||||
self, start_ind: int, batch_size: int, page_id: str
|
||||
) -> list[dict[str, Any]]:
|
||||
child_pages: list[dict[str, Any]] = []
|
||||
|
||||
get_page_child_by_type = make_confluence_call_handle_rate_limit(
|
||||
self.confluence_client.get_page_child_by_type
|
||||
)
|
||||
|
||||
try:
|
||||
child_page = get_page_child_by_type(
|
||||
page_id,
|
||||
type="page",
|
||||
start=start_ind,
|
||||
limit=batch_size,
|
||||
expand="body.storage.value,version",
|
||||
)
|
||||
|
||||
child_pages.extend(child_page)
|
||||
return child_pages
|
||||
|
||||
except Exception:
|
||||
logger.warning(
|
||||
f"Batch failed with page {page_id} at offset {start_ind} "
|
||||
f"with size {batch_size}, processing pages individually..."
|
||||
)
|
||||
|
||||
for i in range(batch_size):
|
||||
ind = start_ind + i
|
||||
try:
|
||||
child_page = get_page_child_by_type(
|
||||
page_id,
|
||||
type="page",
|
||||
start=ind,
|
||||
limit=1,
|
||||
expand="body.storage.value,version",
|
||||
)
|
||||
child_pages.extend(child_page)
|
||||
except Exception as e:
|
||||
logger.warning(f"Page {page_id} at offset {ind} failed: {e}")
|
||||
raise e
|
||||
|
||||
return child_pages
|
||||
|
||||
|
||||
class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
def __init__(
|
||||
self,
|
||||
wiki_page_url: str,
|
||||
index_recursively: bool = True,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE,
|
||||
# if a page has one of the labels specified in this list, we will just
|
||||
@@ -188,11 +353,27 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
self.batch_size = batch_size
|
||||
self.continue_on_failure = continue_on_failure
|
||||
self.labels_to_skip = set(labels_to_skip)
|
||||
self.wiki_base, self.space, self.is_cloud = extract_confluence_keys_from_url(
|
||||
wiki_page_url
|
||||
)
|
||||
self.recursive_indexer: RecursiveIndexer | None = None
|
||||
self.index_recursively = index_recursively
|
||||
(
|
||||
self.wiki_base,
|
||||
self.space,
|
||||
self.page_id,
|
||||
self.is_cloud,
|
||||
) = extract_confluence_keys_from_url(wiki_page_url)
|
||||
|
||||
self.space_level_scan = False
|
||||
|
||||
self.confluence_client: Confluence | None = None
|
||||
|
||||
if self.page_id is None or self.page_id == "":
|
||||
self.space_level_scan = True
|
||||
|
||||
logger.info(
|
||||
f"wiki_base: {self.wiki_base}, space: {self.space}, page_id: {self.page_id},"
|
||||
+ f" space_level_scan: {self.space_level_scan}, index_recursively: {self.index_recursively}"
|
||||
)
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
username = credentials["confluence_username"]
|
||||
access_token = credentials["confluence_access_token"]
|
||||
@@ -210,8 +391,8 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
self,
|
||||
confluence_client: Confluence,
|
||||
start_ind: int,
|
||||
) -> Collection[dict[str, Any]]:
|
||||
def _fetch(start_ind: int, batch_size: int) -> Collection[dict[str, Any]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
def _fetch_space(start_ind: int, batch_size: int) -> list[dict[str, Any]]:
|
||||
get_all_pages_from_space = make_confluence_call_handle_rate_limit(
|
||||
confluence_client.get_all_pages_from_space
|
||||
)
|
||||
@@ -220,9 +401,11 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
self.space,
|
||||
start=start_ind,
|
||||
limit=batch_size,
|
||||
status="current"
|
||||
if CONFLUENCE_CONNECTOR_INDEX_ONLY_ACTIVE_PAGES
|
||||
else None,
|
||||
status=(
|
||||
"current"
|
||||
if CONFLUENCE_CONNECTOR_INDEX_ONLY_ACTIVE_PAGES
|
||||
else None
|
||||
),
|
||||
expand="body.storage.value,version",
|
||||
)
|
||||
except Exception:
|
||||
@@ -241,9 +424,11 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
self.space,
|
||||
start=start_ind + i,
|
||||
limit=1,
|
||||
status="current"
|
||||
if CONFLUENCE_CONNECTOR_INDEX_ONLY_ACTIVE_PAGES
|
||||
else None,
|
||||
status=(
|
||||
"current"
|
||||
if CONFLUENCE_CONNECTOR_INDEX_ONLY_ACTIVE_PAGES
|
||||
else None
|
||||
),
|
||||
expand="body.storage.value,version",
|
||||
)
|
||||
)
|
||||
@@ -264,17 +449,44 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
|
||||
return view_pages
|
||||
|
||||
def _fetch_page(start_ind: int, batch_size: int) -> list[dict[str, Any]]:
|
||||
if self.recursive_indexer is None:
|
||||
self.recursive_indexer = RecursiveIndexer(
|
||||
origin_page_id=self.page_id,
|
||||
batch_size=self.batch_size,
|
||||
confluence_client=self.confluence_client,
|
||||
index_recursively=self.index_recursively,
|
||||
)
|
||||
|
||||
if self.index_recursively:
|
||||
return self.recursive_indexer.get_pages(start_ind, batch_size)
|
||||
else:
|
||||
return self.recursive_indexer.get_origin_page()
|
||||
|
||||
pages: list[dict[str, Any]] = []
|
||||
|
||||
try:
|
||||
return _fetch(start_ind, self.batch_size)
|
||||
pages = (
|
||||
_fetch_space(start_ind, self.batch_size)
|
||||
if self.space_level_scan
|
||||
else _fetch_page(start_ind, self.batch_size)
|
||||
)
|
||||
return pages
|
||||
|
||||
except Exception as e:
|
||||
if not self.continue_on_failure:
|
||||
raise e
|
||||
|
||||
# error checking phase, only reachable if `self.continue_on_failure=True`
|
||||
pages: list[dict[str, Any]] = []
|
||||
for i in range(self.batch_size):
|
||||
try:
|
||||
pages.extend(_fetch(start_ind + i, 1))
|
||||
pages = (
|
||||
_fetch_space(start_ind, self.batch_size)
|
||||
if self.space_level_scan
|
||||
else _fetch_page(start_ind, self.batch_size)
|
||||
)
|
||||
return pages
|
||||
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Ran into exception when fetching pages from Confluence"
|
||||
@@ -286,6 +498,7 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
get_page_child_by_type = make_confluence_call_handle_rate_limit(
|
||||
confluence_client.get_page_child_by_type
|
||||
)
|
||||
|
||||
try:
|
||||
comment_pages = cast(
|
||||
Collection[dict[str, Any]],
|
||||
@@ -321,6 +534,50 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
logger.exception("Ran into exception when fetching labels from Confluence")
|
||||
return []
|
||||
|
||||
def _fetch_attachments(
|
||||
self, confluence_client: Confluence, page_id: str, files_in_used: list[str]
|
||||
) -> str:
|
||||
get_attachments_from_content = make_confluence_call_handle_rate_limit(
|
||||
confluence_client.get_attachments_from_content
|
||||
)
|
||||
files_attachment_content: list = []
|
||||
|
||||
try:
|
||||
attachments_container = get_attachments_from_content(
|
||||
page_id, start=0, limit=500
|
||||
)
|
||||
for attachment in attachments_container["results"]:
|
||||
if attachment["metadata"]["mediaType"] in [
|
||||
"image/jpeg",
|
||||
"image/png",
|
||||
"image/gif",
|
||||
"image/svg+xml",
|
||||
"video/mp4",
|
||||
"video/quicktime",
|
||||
]:
|
||||
continue
|
||||
|
||||
if attachment["title"] not in files_in_used:
|
||||
continue
|
||||
|
||||
download_link = confluence_client.url + attachment["_links"]["download"]
|
||||
response = confluence_client._session.get(download_link)
|
||||
|
||||
if response.status_code == 200:
|
||||
extract = extract_file_text(
|
||||
attachment["title"], io.BytesIO(response.content), False
|
||||
)
|
||||
files_attachment_content.append(extract)
|
||||
|
||||
except Exception as e:
|
||||
if not self.continue_on_failure:
|
||||
raise e
|
||||
logger.exception(
|
||||
f"Ran into exception when fetching attachments from Confluence: {e}"
|
||||
)
|
||||
|
||||
return "\n".join(files_attachment_content)
|
||||
|
||||
def _get_doc_batch(
|
||||
self, start_ind: int, time_filter: Callable[[datetime], bool] | None = None
|
||||
) -> tuple[list[Document], int]:
|
||||
@@ -328,8 +585,8 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
|
||||
if self.confluence_client is None:
|
||||
raise ConnectorMissingCredentialError("Confluence")
|
||||
|
||||
batch = self._fetch_pages(self.confluence_client, start_ind)
|
||||
|
||||
for page in batch:
|
||||
last_modified_str = page["version"]["when"]
|
||||
author = cast(str | None, page["version"].get("by", {}).get("email"))
|
||||
@@ -345,15 +602,18 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
if time_filter is None or time_filter(last_modified):
|
||||
page_id = page["id"]
|
||||
|
||||
if self.labels_to_skip or not CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING:
|
||||
page_labels = self._fetch_labels(self.confluence_client, page_id)
|
||||
|
||||
# check disallowed labels
|
||||
if self.labels_to_skip:
|
||||
page_labels = self._fetch_labels(self.confluence_client, page_id)
|
||||
label_intersection = self.labels_to_skip.intersection(page_labels)
|
||||
if label_intersection:
|
||||
logger.info(
|
||||
f"Page with ID '{page_id}' has a label which has been "
|
||||
f"designated as disallowed: {label_intersection}. Skipping."
|
||||
)
|
||||
|
||||
continue
|
||||
|
||||
page_html = (
|
||||
@@ -366,8 +626,19 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
logger.debug("Page is empty, skipping: %s", page_url)
|
||||
continue
|
||||
page_text = parse_html_page(page_html, self.confluence_client)
|
||||
|
||||
files_in_used = get_used_attachments(page_html, self.confluence_client)
|
||||
attachment_text = self._fetch_attachments(
|
||||
self.confluence_client, page_id, files_in_used
|
||||
)
|
||||
page_text += attachment_text
|
||||
comments_text = self._fetch_comments(self.confluence_client, page_id)
|
||||
page_text += comments_text
|
||||
doc_metadata: dict[str, str | list[str]] = {
|
||||
"Wiki Space Name": self.space
|
||||
}
|
||||
if not CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING and page_labels:
|
||||
doc_metadata["labels"] = page_labels
|
||||
|
||||
doc_batch.append(
|
||||
Document(
|
||||
@@ -376,12 +647,10 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
source=DocumentSource.CONFLUENCE,
|
||||
semantic_identifier=page["title"],
|
||||
doc_updated_at=last_modified,
|
||||
primary_owners=[BasicExpertInfo(email=author)]
|
||||
if author
|
||||
else None,
|
||||
metadata={
|
||||
"Wiki Space Name": self.space,
|
||||
},
|
||||
primary_owners=(
|
||||
[BasicExpertInfo(email=author)] if author else None
|
||||
),
|
||||
metadata=doc_metadata,
|
||||
)
|
||||
)
|
||||
return doc_batch, len(batch)
|
||||
@@ -423,8 +692,6 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
|
||||
connector = ConfluenceConnector(os.environ["CONFLUENCE_TEST_SPACE_URL"])
|
||||
connector.load_credentials(
|
||||
{
|
||||
|
||||
@@ -1,10 +1,14 @@
import time
from collections.abc import Callable
from typing import Any
from typing import cast
from typing import TypeVar

from requests import HTTPError
from retry import retry

from danswer.utils.logger import setup_logger

logger = setup_logger()


F = TypeVar("F", bound=Callable[..., Any])
@@ -18,23 +22,38 @@ class ConfluenceRateLimitError(Exception):


def make_confluence_call_handle_rate_limit(confluence_call: F) -> F:
    @retry(
        exceptions=ConfluenceRateLimitError,
        tries=10,
        delay=1,
        max_delay=600, # 10 minutes
        backoff=2,
        jitter=1,
    )
    def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
        try:
            return confluence_call(*args, **kwargs)
        except HTTPError as e:
            if (
                e.response.status_code == 429
                or RATE_LIMIT_MESSAGE_LOWERCASE in e.response.text.lower()
            ):
                raise ConfluenceRateLimitError()
            raise
        starting_delay = 5
        backoff = 2
        max_delay = 600

        for attempt in range(10):
            try:
                return confluence_call(*args, **kwargs)
            except HTTPError as e:
                if (
                    e.response.status_code == 429
                    or RATE_LIMIT_MESSAGE_LOWERCASE in e.response.text.lower()
                ):
                    retry_after = None
                    try:
                        retry_after = int(e.response.headers.get("Retry-After"))
                    except (ValueError, TypeError):
                        pass

                    if retry_after:
                        logger.warning(
                            f"Rate limit hit. Retrying after {retry_after} seconds..."
                        )
                        time.sleep(retry_after)
                    else:
                        logger.warning(
                            "Rate limit hit. Retrying with exponential backoff..."
                        )
                        delay = min(starting_delay * (backoff**attempt), max_delay)
                        time.sleep(delay)
                else:
                    # re-raise, let caller handle
                    raise

    return cast(F, wrapped_call)

@@ -6,6 +6,7 @@ from typing import TypeVar
|
||||
|
||||
from dateutil.parser import parse
|
||||
|
||||
from danswer.configs.constants import IGNORE_FOR_QA
|
||||
from danswer.connectors.models import BasicExpertInfo
|
||||
from danswer.utils.text_processing import is_valid_email
|
||||
|
||||
@@ -57,3 +58,7 @@ def process_in_batches(
|
||||
) -> Iterator[list[U]]:
|
||||
for i in range(0, len(objects), batch_size):
|
||||
yield [process_function(obj) for obj in objects[i : i + batch_size]]
|
||||
|
||||
|
||||
def get_metadata_keys_to_ignore() -> list[str]:
|
||||
return [IGNORE_FOR_QA]
|
||||
|
||||
@@ -56,6 +56,16 @@ def extract_text_from_content(content: dict) -> str:
|
||||
return " ".join(texts)
|
||||
|
||||
|
||||
def best_effort_get_field_from_issue(jira_issue: Issue, field: str) -> Any:
|
||||
if hasattr(jira_issue.fields, field):
|
||||
return getattr(jira_issue.fields, field)
|
||||
|
||||
try:
|
||||
return jira_issue.raw["fields"][field]
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _get_comment_strs(
|
||||
jira: Issue, comment_email_blacklist: tuple[str, ...] = ()
|
||||
) -> list[str]:
|
||||
@@ -117,8 +127,10 @@ def fetch_jira_issues_batch(
|
||||
continue
|
||||
|
||||
comments = _get_comment_strs(jira, comment_email_blacklist)
|
||||
semantic_rep = f"{jira.fields.description}\n" + "\n".join(
|
||||
[f"Comment: {comment}" for comment in comments]
|
||||
semantic_rep = (
|
||||
f"{jira.fields.description}\n"
|
||||
if jira.fields.description
|
||||
else "" + "\n".join([f"Comment: {comment}" for comment in comments])
|
||||
)
|
||||
|
||||
page_url = f"{jira_client.client_info()}/browse/{jira.key}"
|
||||
@@ -147,14 +159,18 @@ def fetch_jira_issues_batch(
|
||||
pass
|
||||
|
||||
metadata_dict = {}
|
||||
if jira.fields.priority:
|
||||
metadata_dict["priority"] = jira.fields.priority.name
|
||||
if jira.fields.status:
|
||||
metadata_dict["status"] = jira.fields.status.name
|
||||
if jira.fields.resolution:
|
||||
metadata_dict["resolution"] = jira.fields.resolution.name
|
||||
if jira.fields.labels:
|
||||
metadata_dict["label"] = jira.fields.labels
|
||||
priority = best_effort_get_field_from_issue(jira, "priority")
|
||||
if priority:
|
||||
metadata_dict["priority"] = priority.name
|
||||
status = best_effort_get_field_from_issue(jira, "status")
|
||||
if status:
|
||||
metadata_dict["status"] = status.name
|
||||
resolution = best_effort_get_field_from_issue(jira, "resolution")
|
||||
if resolution:
|
||||
metadata_dict["resolution"] = resolution.name
|
||||
labels = best_effort_get_field_from_issue(jira, "labels")
|
||||
if labels:
|
||||
metadata_dict["label"] = labels
|
||||
|
||||
doc_batch.append(
|
||||
Document(
|
||||
|
||||
@@ -11,6 +11,9 @@ from requests import Response
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
|
||||
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||
rate_limit_builder,
|
||||
)
|
||||
from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
@@ -58,63 +61,36 @@ class DiscourseConnector(PollConnector):
|
||||
self.category_id_map: dict[int, str] = {}
|
||||
|
||||
self.batch_size = batch_size
|
||||
|
||||
self.permissions: DiscoursePerms | None = None
|
||||
self.active_categories: set | None = None
|
||||
|
||||
@rate_limit_builder(max_calls=50, period=60)
|
||||
def _make_request(self, endpoint: str, params: dict | None = None) -> Response:
|
||||
if not self.permissions:
|
||||
raise ConnectorMissingCredentialError("Discourse")
|
||||
return discourse_request(endpoint, self.permissions, params)
|
||||
|
||||
def _get_categories_map(
|
||||
self,
|
||||
) -> None:
|
||||
assert self.permissions is not None
|
||||
categories_endpoint = urllib.parse.urljoin(self.base_url, "categories.json")
|
||||
response = discourse_request(
|
||||
response = self._make_request(
|
||||
endpoint=categories_endpoint,
|
||||
perms=self.permissions,
|
||||
params={"include_subcategories": True},
|
||||
)
|
||||
categories = response.json()["category_list"]["categories"]
|
||||
|
||||
self.category_id_map = {
|
||||
category["id"]: category["name"]
|
||||
for category in categories
|
||||
if not self.categories or category["name"].lower() in self.categories
|
||||
cat["id"]: cat["name"]
|
||||
for cat in categories
|
||||
if not self.categories or cat["name"].lower() in self.categories
|
||||
}
|
||||
|
||||
def _get_latest_topics(
|
||||
self, start: datetime | None, end: datetime | None
|
||||
) -> list[int]:
|
||||
assert self.permissions is not None
|
||||
topic_ids = []
|
||||
|
||||
valid_categories = set(self.category_id_map.keys())
|
||||
|
||||
latest_endpoint = urllib.parse.urljoin(self.base_url, "latest.json")
|
||||
response = discourse_request(endpoint=latest_endpoint, perms=self.permissions)
|
||||
topics = response.json()["topic_list"]["topics"]
|
||||
for topic in topics:
|
||||
last_time = topic.get("last_posted_at")
|
||||
if not last_time:
|
||||
continue
|
||||
last_time_dt = time_str_to_utc(last_time)
|
||||
|
||||
if start and start > last_time_dt:
|
||||
continue
|
||||
if end and end < last_time_dt:
|
||||
continue
|
||||
|
||||
if valid_categories and topic.get("category_id") not in valid_categories:
|
||||
continue
|
||||
|
||||
topic_ids.append(topic["id"])
|
||||
|
||||
return topic_ids
|
||||
self.active_categories = set(self.category_id_map)
|
||||
|
||||
def _get_doc_from_topic(self, topic_id: int) -> Document:
|
||||
assert self.permissions is not None
|
||||
topic_endpoint = urllib.parse.urljoin(self.base_url, f"t/{topic_id}.json")
|
||||
response = discourse_request(
|
||||
endpoint=topic_endpoint,
|
||||
perms=self.permissions,
|
||||
)
|
||||
response = self._make_request(endpoint=topic_endpoint)
|
||||
topic = response.json()
|
||||
|
||||
topic_url = urllib.parse.urljoin(self.base_url, f"t/{topic['slug']}")
|
||||
@@ -138,10 +114,16 @@ class DiscourseConnector(PollConnector):
|
||||
sections.append(
|
||||
Section(link=topic_url, text=parse_html_page_basic(post["cooked"]))
|
||||
)
|
||||
category_name = self.category_id_map.get(topic["category_id"])
|
||||
|
||||
metadata: dict[str, str | list[str]] = (
|
||||
{
|
||||
"category": category_name,
|
||||
}
|
||||
if category_name
|
||||
else {}
|
||||
)
|
||||
|
||||
metadata: dict[str, str | list[str]] = {
|
||||
"category": self.category_id_map[topic["category_id"]],
|
||||
}
|
||||
if topic.get("tags"):
|
||||
metadata["tags"] = topic["tags"]
|
||||
|
||||
@@ -157,26 +139,78 @@ class DiscourseConnector(PollConnector):
|
||||
)
|
||||
return doc
|
||||
|
||||
def _get_latest_topics(
|
||||
self, start: datetime | None, end: datetime | None, page: int
|
||||
) -> list[int]:
|
||||
assert self.permissions is not None
|
||||
topic_ids = []
|
||||
|
||||
if not self.categories:
|
||||
latest_endpoint = urllib.parse.urljoin(
|
||||
self.base_url, f"latest.json?page={page}"
|
||||
)
|
||||
response = self._make_request(endpoint=latest_endpoint)
|
||||
topics = response.json()["topic_list"]["topics"]
|
||||
|
||||
else:
|
||||
topics = []
|
||||
empty_categories = []
|
||||
|
||||
for category_id in self.category_id_map.keys():
|
||||
category_endpoint = urllib.parse.urljoin(
|
||||
self.base_url, f"c/{category_id}.json?page={page}&sys=latest"
|
||||
)
|
||||
response = self._make_request(endpoint=category_endpoint)
|
||||
new_topics = response.json()["topic_list"]["topics"]
|
||||
|
||||
if len(new_topics) == 0:
|
||||
empty_categories.append(category_id)
|
||||
topics.extend(new_topics)
|
||||
|
||||
for empty_category in empty_categories:
|
||||
self.category_id_map.pop(empty_category)
|
||||
|
||||
for topic in topics:
|
||||
last_time = topic.get("last_posted_at")
|
||||
if not last_time:
|
||||
continue
|
||||
|
||||
last_time_dt = time_str_to_utc(last_time)
|
||||
if (start and start > last_time_dt) or (end and end < last_time_dt):
|
||||
continue
|
||||
|
||||
topic_ids.append(topic["id"])
|
||||
if len(topic_ids) >= self.batch_size:
|
||||
break
|
||||
|
||||
return topic_ids
|
||||
|
||||
def _yield_discourse_documents(
|
||||
self, topic_ids: list[int]
|
||||
self,
|
||||
start: datetime,
|
||||
end: datetime,
|
||||
) -> GenerateDocumentsOutput:
|
||||
doc_batch: list[Document] = []
|
||||
for topic_id in topic_ids:
|
||||
doc_batch.append(self._get_doc_from_topic(topic_id))
|
||||
page = 1
|
||||
while topic_ids := self._get_latest_topics(start, end, page):
|
||||
doc_batch: list[Document] = []
|
||||
for topic_id in topic_ids:
|
||||
doc_batch.append(self._get_doc_from_topic(topic_id))
|
||||
if len(doc_batch) >= self.batch_size:
|
||||
yield doc_batch
|
||||
doc_batch = []
|
||||
|
||||
if len(doc_batch) >= self.batch_size:
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
doc_batch = []
|
||||
page += 1
|
||||
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
def load_credentials(
|
||||
self,
|
||||
credentials: dict[str, Any],
|
||||
) -> dict[str, Any] | None:
|
||||
self.permissions = DiscoursePerms(
|
||||
api_key=credentials["discourse_api_key"],
|
||||
api_username=credentials["discourse_api_username"],
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def poll_source(
|
||||
@@ -184,16 +218,13 @@ class DiscourseConnector(PollConnector):
|
||||
) -> GenerateDocumentsOutput:
|
||||
if self.permissions is None:
|
||||
raise ConnectorMissingCredentialError("Discourse")
|
||||
|
||||
start_datetime = datetime.utcfromtimestamp(start).replace(tzinfo=timezone.utc)
|
||||
end_datetime = datetime.utcfromtimestamp(end).replace(tzinfo=timezone.utc)
|
||||
|
||||
self._get_categories_map()
|
||||
|
||||
latest_topic_ids = self._get_latest_topics(
|
||||
start=start_datetime, end=end_datetime
|
||||
)
|
||||
|
||||
return self._yield_discourse_documents(latest_topic_ids)
|
||||
yield from self._yield_discourse_documents(start_datetime, end_datetime)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -209,7 +240,5 @@ if __name__ == "__main__":
|
||||
|
||||
current = time.time()
|
||||
one_year_ago = current - 24 * 60 * 60 * 360
|
||||
|
||||
latest_docs = connector.poll_source(one_year_ago, current)
|
||||
|
||||
print(next(latest_docs))
|
||||
|
||||
backend/danswer/connectors/dropbox/__init__.py (new file, 0 lines)
backend/danswer/connectors/dropbox/connector.py (new file, 155 lines)
@@ -0,0 +1,155 @@
|
||||
from datetime import timezone
|
||||
from io import BytesIO
|
||||
from typing import Any
|
||||
|
||||
from dropbox import Dropbox # type: ignore
|
||||
from dropbox.exceptions import ApiError # type:ignore
|
||||
from dropbox.files import FileMetadata # type:ignore
|
||||
from dropbox.files import FolderMetadata # type:ignore
|
||||
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from danswer.connectors.models import ConnectorMissingCredentialError
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.file_processing.extract_file_text import extract_file_text
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class DropboxConnector(LoadConnector, PollConnector):
|
||||
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
|
||||
self.batch_size = batch_size
|
||||
self.dropbox_client: Dropbox | None = None
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
self.dropbox_client = Dropbox(credentials["dropbox_access_token"])
|
||||
return None
|
||||
|
||||
def _download_file(self, path: str) -> bytes:
|
||||
"""Download a single file from Dropbox."""
|
||||
if self.dropbox_client is None:
|
||||
raise ConnectorMissingCredentialError("Dropbox")
|
||||
_, resp = self.dropbox_client.files_download(path)
|
||||
return resp.content
|
||||
|
||||
def _get_shared_link(self, path: str) -> str:
|
||||
"""Create a shared link for a file in Dropbox."""
|
||||
if self.dropbox_client is None:
|
||||
raise ConnectorMissingCredentialError("Dropbox")
|
||||
|
||||
try:
|
||||
# Check if a shared link already exists
|
||||
shared_links = self.dropbox_client.sharing_list_shared_links(path=path)
|
||||
if shared_links.links:
|
||||
return shared_links.links[0].url
|
||||
|
||||
link_metadata = (
|
||||
self.dropbox_client.sharing_create_shared_link_with_settings(path)
|
||||
)
|
||||
return link_metadata.url
|
||||
except ApiError as err:
|
||||
logger.exception(f"Failed to create a shared link for {path}: {err}")
|
||||
return ""
|
||||
|
||||
def _yield_files_recursive(
|
||||
self,
|
||||
path: str,
|
||||
start: SecondsSinceUnixEpoch | None,
|
||||
end: SecondsSinceUnixEpoch | None,
|
||||
) -> GenerateDocumentsOutput:
|
||||
"""Yield files in batches from a specified Dropbox folder, including subfolders."""
|
||||
if self.dropbox_client is None:
|
||||
raise ConnectorMissingCredentialError("Dropbox")
|
||||
|
||||
result = self.dropbox_client.files_list_folder(
|
||||
path,
|
||||
limit=self.batch_size,
|
||||
recursive=False,
|
||||
include_non_downloadable_files=False,
|
||||
)
|
||||
|
||||
while True:
|
||||
batch: list[Document] = []
|
||||
for entry in result.entries:
|
||||
if isinstance(entry, FileMetadata):
|
||||
modified_time = entry.client_modified
|
||||
if modified_time.tzinfo is None:
|
||||
# If no timezone info, assume it is UTC
|
||||
modified_time = modified_time.replace(tzinfo=timezone.utc)
|
||||
else:
|
||||
# If not in UTC, translate it
|
||||
modified_time = modified_time.astimezone(timezone.utc)
|
||||
|
||||
time_as_seconds = int(modified_time.timestamp())
|
||||
if start and time_as_seconds < start:
|
||||
continue
|
||||
if end and time_as_seconds > end:
|
||||
continue
|
||||
|
||||
downloaded_file = self._download_file(entry.path_display)
|
||||
link = self._get_shared_link(entry.path_display)
|
||||
try:
|
||||
text = extract_file_text(
|
||||
entry.name,
|
||||
BytesIO(downloaded_file),
|
||||
break_on_unprocessable=False,
|
||||
)
|
||||
batch.append(
|
||||
Document(
|
||||
id=f"doc:{entry.id}",
|
||||
sections=[Section(link=link, text=text)],
|
||||
source=DocumentSource.DROPBOX,
|
||||
semantic_identifier=entry.name,
|
||||
doc_updated_at=modified_time,
|
||||
metadata={"type": "article"},
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Error decoding file {entry.path_display} as utf-8 error occurred: {e}"
|
||||
)
|
||||
|
||||
elif isinstance(entry, FolderMetadata):
|
||||
yield from self._yield_files_recursive(entry.path_lower, start, end)
|
||||
|
||||
if batch:
|
||||
yield batch
|
||||
|
||||
if not result.has_more:
|
||||
break
|
||||
|
||||
result = self.dropbox_client.files_list_folder_continue(result.cursor)
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
return self.poll_source(None, None)
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None
|
||||
) -> GenerateDocumentsOutput:
|
||||
if self.dropbox_client is None:
|
||||
raise ConnectorMissingCredentialError("Dropbox")
|
||||
|
||||
for batch in self._yield_files_recursive("", start, end):
|
||||
yield batch
|
||||
|
||||
return None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
|
||||
connector = DropboxConnector()
|
||||
connector.load_credentials(
|
||||
{
|
||||
"dropbox_access_token": os.environ["DROPBOX_ACCESS_TOKEN"],
|
||||
}
|
||||
)
|
||||
document_batches = connector.load_from_state()
|
||||
print(next(document_batches))
|
||||
@@ -1,13 +1,18 @@
|
||||
from typing import Any
|
||||
from typing import Type
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.axero.connector import AxeroConnector
|
||||
from danswer.connectors.blob.connector import BlobStorageConnector
|
||||
from danswer.connectors.bookstack.connector import BookstackConnector
|
||||
from danswer.connectors.clickup.connector import ClickupConnector
|
||||
from danswer.connectors.confluence.connector import ConfluenceConnector
|
||||
from danswer.connectors.danswer_jira.connector import JiraConnector
|
||||
from danswer.connectors.discourse.connector import DiscourseConnector
|
||||
from danswer.connectors.document360.connector import Document360Connector
|
||||
from danswer.connectors.dropbox.connector import DropboxConnector
|
||||
from danswer.connectors.file.connector import LocalFileConnector
|
||||
from danswer.connectors.github.connector import GithubConnector
|
||||
from danswer.connectors.gitlab.connector import GitlabConnector
|
||||
@@ -23,17 +28,23 @@ from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.linear.connector import LinearConnector
|
||||
from danswer.connectors.loopio.connector import LoopioConnector
|
||||
from danswer.connectors.mediawiki.wiki import MediaWikiConnector
|
||||
from danswer.connectors.models import InputType
|
||||
from danswer.connectors.notion.connector import NotionConnector
|
||||
from danswer.connectors.productboard.connector import ProductboardConnector
|
||||
from danswer.connectors.requesttracker.connector import RequestTrackerConnector
|
||||
from danswer.connectors.salesforce.connector import SalesforceConnector
|
||||
from danswer.connectors.sharepoint.connector import SharepointConnector
|
||||
from danswer.connectors.slab.connector import SlabConnector
|
||||
from danswer.connectors.slack.connector import SlackPollConnector
|
||||
from danswer.connectors.slack.load_connector import SlackLoadConnector
|
||||
from danswer.connectors.teams.connector import TeamsConnector
|
||||
from danswer.connectors.web.connector import WebConnector
|
||||
from danswer.connectors.wikipedia.connector import WikipediaConnector
|
||||
from danswer.connectors.zendesk.connector import ZendeskConnector
|
||||
from danswer.connectors.zulip.connector import ZulipConnector
|
||||
from danswer.db.credentials import backend_update_credential_json
|
||||
from danswer.db.models import Credential
|
||||
|
||||
|
||||
class ConnectorMissingException(Exception):
|
||||
@@ -71,9 +82,19 @@ def identify_connector_class(
|
||||
DocumentSource.GOOGLE_SITES: GoogleSitesConnector,
|
||||
DocumentSource.ZENDESK: ZendeskConnector,
|
||||
DocumentSource.LOOPIO: LoopioConnector,
|
||||
DocumentSource.DROPBOX: DropboxConnector,
|
||||
DocumentSource.SHAREPOINT: SharepointConnector,
|
||||
DocumentSource.TEAMS: TeamsConnector,
|
||||
DocumentSource.SALESFORCE: SalesforceConnector,
|
||||
DocumentSource.DISCOURSE: DiscourseConnector,
|
||||
DocumentSource.AXERO: AxeroConnector,
|
||||
DocumentSource.CLICKUP: ClickupConnector,
|
||||
DocumentSource.MEDIAWIKI: MediaWikiConnector,
|
||||
DocumentSource.WIKIPEDIA: WikipediaConnector,
|
||||
DocumentSource.S3: BlobStorageConnector,
|
||||
DocumentSource.R2: BlobStorageConnector,
|
||||
DocumentSource.GOOGLE_CLOUD_STORAGE: BlobStorageConnector,
|
||||
DocumentSource.OCI_STORAGE: BlobStorageConnector,
|
||||
}
|
||||
connector_by_source = connector_map.get(source, {})
|
||||
|
||||
@@ -99,7 +120,6 @@ def identify_connector_class(
|
||||
raise ConnectorMissingException(
|
||||
f"Connector for source={source} does not accept input_type={input_type}"
|
||||
)
|
||||
|
||||
return connector
|
||||
|
||||
|
||||
@@ -107,10 +127,14 @@ def instantiate_connector(
|
||||
source: DocumentSource,
|
||||
input_type: InputType,
|
||||
connector_specific_config: dict[str, Any],
|
||||
credentials: dict[str, Any],
|
||||
) -> tuple[BaseConnector, dict[str, Any] | None]:
|
||||
credential: Credential,
|
||||
db_session: Session,
|
||||
) -> BaseConnector:
|
||||
connector_class = identify_connector_class(source, input_type)
|
||||
connector = connector_class(**connector_specific_config)
|
||||
new_credentials = connector.load_credentials(credentials)
|
||||
new_credentials = connector.load_credentials(credential.credential_json)
|
||||
|
||||
return connector, new_credentials
|
||||
if new_credentials is not None:
|
||||
backend_update_credential_json(credential, new_credentials, db_session)
|
||||
|
||||
return connector
|
||||
|
||||
@@ -69,7 +69,9 @@ def _process_file(
|
||||
|
||||
if is_text_file_extension(file_name):
|
||||
encoding = detect_encoding(file)
|
||||
file_content_raw, file_metadata = read_text_file(file, encoding=encoding)
|
||||
file_content_raw, file_metadata = read_text_file(
|
||||
file, encoding=encoding, ignore_danswer_metadata=False
|
||||
)
|
||||
|
||||
# Using the PDF reader function directly to pass in password cleanly
|
||||
elif extension == ".pdf":
|
||||
@@ -83,8 +85,18 @@ def _process_file(
|
||||
|
||||
all_metadata = {**metadata, **file_metadata} if metadata else file_metadata
|
||||
|
||||
# add a prefix to avoid conflicts with other connectors
|
||||
doc_id = f"FILE_CONNECTOR__{file_name}"
|
||||
if metadata:
|
||||
doc_id = metadata.get("document_id") or doc_id
|
||||
|
||||
# If this is set, we will show this in the UI as the "name" of the file
|
||||
file_display_name_override = all_metadata.get("file_display_name")
|
||||
file_display_name = all_metadata.get("file_display_name") or os.path.basename(
|
||||
file_name
|
||||
)
|
||||
title = (
|
||||
all_metadata["title"] or "" if "title" in all_metadata else file_display_name
|
||||
)
|
||||
|
||||
time_updated = all_metadata.get("time_updated", datetime.now(timezone.utc))
|
||||
if isinstance(time_updated, str):
|
||||
@@ -99,6 +111,7 @@ def _process_file(
|
||||
for k, v in all_metadata.items()
|
||||
if k
|
||||
not in [
|
||||
"document_id",
|
||||
"time_updated",
|
||||
"doc_updated_at",
|
||||
"link",
|
||||
@@ -106,6 +119,7 @@ def _process_file(
|
||||
"secondary_owners",
|
||||
"filename",
|
||||
"file_display_name",
|
||||
"title",
|
||||
]
|
||||
}
|
||||
|
||||
@@ -124,13 +138,13 @@ def _process_file(
|
||||
|
||||
return [
|
||||
Document(
|
||||
id=f"FILE_CONNECTOR__{file_name}", # add a prefix to avoid conflicts with other connectors
|
||||
id=doc_id,
|
||||
sections=[
|
||||
Section(link=all_metadata.get("link"), text=file_content_raw.strip())
|
||||
],
|
||||
source=DocumentSource.FILE,
|
||||
semantic_identifier=file_display_name_override
|
||||
or os.path.basename(file_name),
|
||||
semantic_identifier=file_display_name,
|
||||
title=title,
|
||||
doc_updated_at=final_time_updated,
|
||||
primary_owners=p_owners,
|
||||
secondary_owners=s_owners,
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import fnmatch
|
||||
import itertools
|
||||
from collections import deque
|
||||
from collections.abc import Iterable
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
@@ -6,7 +8,10 @@ from datetime import timezone
|
||||
from typing import Any
|
||||
|
||||
import gitlab
|
||||
import pytz
|
||||
from gitlab.v4.objects import Project
|
||||
|
||||
from danswer.configs.app_configs import GITLAB_CONNECTOR_INCLUDE_CODE_FILES
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
@@ -19,7 +24,13 @@ from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
# List of directories/Files to exclude
|
||||
exclude_patterns = [
|
||||
"logs",
|
||||
".github/",
|
||||
".gitlab/",
|
||||
".pre-commit-config.yaml",
|
||||
]
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@@ -72,6 +83,37 @@ def _convert_issue_to_document(issue: Any) -> Document:
|
||||
return doc
|
||||
|
||||
|
||||
def _convert_code_to_document(
|
||||
project: Project, file: Any, url: str, projectName: str, projectOwner: str
|
||||
) -> Document:
|
||||
file_content_obj = project.files.get(
|
||||
file_path=file["path"], ref="master"
|
||||
) # Replace 'master' with your branch name if needed
|
||||
try:
|
||||
file_content = file_content_obj.decode().decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
file_content = file_content_obj.decode().decode("latin-1")
|
||||
|
||||
file_url = f"{url}/{projectOwner}/{projectName}/-/blob/master/{file['path']}" # Construct the file URL
|
||||
doc = Document(
|
||||
id=file["id"],
|
||||
sections=[Section(link=file_url, text=file_content)],
|
||||
source=DocumentSource.GITLAB,
|
||||
semantic_identifier=file["name"],
|
||||
doc_updated_at=datetime.now().replace(
|
||||
tzinfo=timezone.utc
|
||||
), # Use current time as updated_at
|
||||
primary_owners=[], # Fill this as needed
|
||||
metadata={"type": "CodeFile"},
|
||||
)
|
||||
return doc
|
||||
|
||||
|
||||
def _should_exclude(path: str) -> bool:
|
||||
"""Check if a path matches any of the exclude patterns."""
|
||||
return any(fnmatch.fnmatch(path, pattern) for pattern in exclude_patterns)
|
||||
|
||||
|
||||
class GitlabConnector(LoadConnector, PollConnector):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -81,6 +123,7 @@ class GitlabConnector(LoadConnector, PollConnector):
|
||||
state_filter: str = "all",
|
||||
include_mrs: bool = True,
|
||||
include_issues: bool = True,
|
||||
include_code_files: bool = GITLAB_CONNECTOR_INCLUDE_CODE_FILES,
|
||||
) -> None:
|
||||
self.project_owner = project_owner
|
||||
self.project_name = project_name
|
||||
@@ -88,6 +131,7 @@ class GitlabConnector(LoadConnector, PollConnector):
|
||||
self.state_filter = state_filter
|
||||
self.include_mrs = include_mrs
|
||||
self.include_issues = include_issues
|
||||
self.include_code_files = include_code_files
|
||||
self.gitlab_client: gitlab.Gitlab | None = None
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
@@ -101,45 +145,80 @@ class GitlabConnector(LoadConnector, PollConnector):
|
||||
) -> GenerateDocumentsOutput:
|
||||
if self.gitlab_client is None:
|
||||
raise ConnectorMissingCredentialError("Gitlab")
|
||||
project = self.gitlab_client.projects.get(
|
||||
project: gitlab.Project = self.gitlab_client.projects.get(
|
||||
f"{self.project_owner}/{self.project_name}"
|
||||
)
|
||||
|
||||
# Fetch code files
|
||||
if self.include_code_files:
|
||||
# Fetching using BFS as project.report_tree with recursion causing slow load
|
||||
queue = deque([""]) # Start with the root directory
|
||||
while queue:
|
||||
current_path = queue.popleft()
|
||||
files = project.repository_tree(path=current_path, all=True)
|
||||
for file_batch in _batch_gitlab_objects(files, self.batch_size):
|
||||
code_doc_batch: list[Document] = []
|
||||
for file in file_batch:
|
||||
if _should_exclude(file["path"]):
|
||||
continue
|
||||
|
||||
if file["type"] == "blob":
|
||||
code_doc_batch.append(
|
||||
_convert_code_to_document(
|
||||
project,
|
||||
file,
|
||||
self.gitlab_client.url,
|
||||
self.project_name,
|
||||
self.project_owner,
|
||||
)
|
||||
)
|
||||
elif file["type"] == "tree":
|
||||
queue.append(file["path"])
|
||||
|
||||
if code_doc_batch:
|
||||
yield code_doc_batch
|
||||
|
||||
if self.include_mrs:
|
||||
merge_requests = project.mergerequests.list(
|
||||
state=self.state_filter, order_by="updated_at", sort="desc"
|
||||
)
|
||||
|
||||
for mr_batch in _batch_gitlab_objects(merge_requests, self.batch_size):
|
||||
doc_batch: list[Document] = []
|
||||
mr_doc_batch: list[Document] = []
|
||||
for mr in mr_batch:
|
||||
mr.updated_at = datetime.strptime(
|
||||
mr.updated_at, "%Y-%m-%dT%H:%M:%S.%fZ"
|
||||
mr.updated_at, "%Y-%m-%dT%H:%M:%S.%f%z"
|
||||
)
|
||||
if start is not None and mr.updated_at < start:
|
||||
yield doc_batch
|
||||
if start is not None and mr.updated_at < start.replace(
|
||||
tzinfo=pytz.UTC
|
||||
):
|
||||
yield mr_doc_batch
|
||||
return
|
||||
if end is not None and mr.updated_at > end:
|
||||
if end is not None and mr.updated_at > end.replace(tzinfo=pytz.UTC):
|
||||
continue
|
||||
doc_batch.append(_convert_merge_request_to_document(mr))
|
||||
yield doc_batch
|
||||
mr_doc_batch.append(_convert_merge_request_to_document(mr))
|
||||
yield mr_doc_batch
|
||||
|
||||
if self.include_issues:
|
||||
issues = project.issues.list(state=self.state_filter)
|
||||
|
||||
for issue_batch in _batch_gitlab_objects(issues, self.batch_size):
|
||||
doc_batch = []
|
||||
issue_doc_batch: list[Document] = []
|
||||
for issue in issue_batch:
|
||||
issue.updated_at = datetime.strptime(
|
||||
issue.updated_at, "%Y-%m-%dT%H:%M:%S.%fZ"
|
||||
issue.updated_at, "%Y-%m-%dT%H:%M:%S.%f%z"
|
||||
)
|
||||
if start is not None and issue.updated_at < start:
|
||||
yield doc_batch
|
||||
return
|
||||
if end is not None and issue.updated_at > end:
|
||||
continue
|
||||
doc_batch.append(_convert_issue_to_document(issue))
|
||||
yield doc_batch
|
||||
if start is not None:
|
||||
start = start.replace(tzinfo=pytz.UTC)
|
||||
if issue.updated_at < start:
|
||||
yield issue_doc_batch
|
||||
return
|
||||
if end is not None:
|
||||
end = end.replace(tzinfo=pytz.UTC)
|
||||
if issue.updated_at > end:
|
||||
continue
|
||||
issue_doc_batch.append(_convert_issue_to_document(issue))
|
||||
yield issue_doc_batch
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
return self._fetch_from_gitlab()
|
||||
@@ -163,11 +242,12 @@ if __name__ == "__main__":
|
||||
state_filter="all",
|
||||
include_mrs=True,
|
||||
include_issues=True,
|
||||
include_code_files=GITLAB_CONNECTOR_INCLUDE_CODE_FILES,
|
||||
)
|
||||
|
||||
connector.load_credentials(
|
||||
{
|
||||
"github_access_token": os.environ["GITLAB_ACCESS_TOKEN"],
|
||||
"gitlab_access_token": os.environ["GITLAB_ACCESS_TOKEN"],
|
||||
"gitlab_url": os.environ["GITLAB_URL"],
|
||||
}
|
||||
)
|
||||
|
||||
@@ -3,7 +3,8 @@ from typing import Any
|
||||
from typing import cast
|
||||
from typing import Dict
|
||||
|
||||
from google.auth.credentials import Credentials # type: ignore
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
|
||||
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
|
||||
from googleapiclient import discovery # type: ignore
|
||||
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
@@ -36,7 +37,7 @@ logger = setup_logger()
|
||||
class GmailConnector(LoadConnector, PollConnector):
|
||||
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
|
||||
self.batch_size = batch_size
|
||||
self.creds: Credentials | None = None
|
||||
self.creds: OAuthCredentials | ServiceAccountCredentials | None = None
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, str] | None:
|
||||
"""Checks for two different types of credentials.
|
||||
@@ -45,7 +46,7 @@ class GmailConnector(LoadConnector, PollConnector):
|
||||
(2) A credential which holds a service account key JSON file, which
|
||||
can then be used to impersonate any user in the workspace.
|
||||
"""
|
||||
creds = None
|
||||
creds: OAuthCredentials | ServiceAccountCredentials | None = None
|
||||
new_creds_dict = None
|
||||
if DB_CREDENTIALS_DICT_TOKEN_KEY in credentials:
|
||||
access_token_json_str = cast(
|
||||
@@ -74,7 +75,7 @@ class GmailConnector(LoadConnector, PollConnector):
|
||||
str | None, credentials.get(DB_CREDENTIALS_DICT_DELEGATED_USER_KEY)
|
||||
)
|
||||
if delegated_user_email:
|
||||
creds = creds.with_subject(delegated_user_email) if creds else None
|
||||
creds = creds.with_subject(delegated_user_email) if creds else None # type: ignore
|
||||
|
||||
if creds is None:
|
||||
raise PermissionError(
|
||||
|
||||
@@ -11,6 +11,7 @@ from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.app_configs import WEB_DOMAIN
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.gmail.constants import CRED_KEY
|
||||
from danswer.connectors.gmail.constants import (
|
||||
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY,
|
||||
@@ -146,6 +147,7 @@ def build_service_account_creds(
|
||||
credential_dict[DB_CREDENTIALS_DICT_DELEGATED_USER_KEY] = delegated_user_email
|
||||
|
||||
return CredentialBase(
|
||||
source=DocumentSource.GMAIL,
|
||||
credential_json=credential_dict,
|
||||
admin_public=True,
|
||||
)
|
||||
|
||||
@@ -198,7 +198,10 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
f"Indexing Gong call from {call_time_str.split('T', 1)[0]}: {call_title}"
|
||||
)
|
||||
|
||||
call_parties = call_details["parties"]
|
||||
call_parties = cast(list[dict] | None, call_details.get("parties"))
|
||||
if call_parties is None:
|
||||
logger.error(f"Couldn't get parties for Call ID: {call_id}")
|
||||
call_parties = []
|
||||
|
||||
id_to_name_map = self._parse_parties(call_parties)
|
||||
|
||||
|
||||
@@ -8,7 +8,8 @@ from itertools import chain
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from google.auth.credentials import Credentials # type: ignore
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
|
||||
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
|
||||
from googleapiclient import discovery # type: ignore
|
||||
from googleapiclient.errors import HttpError # type: ignore
|
||||
|
||||
@@ -41,6 +42,7 @@ from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.file_processing.extract_file_text import docx_to_text
|
||||
from danswer.file_processing.extract_file_text import pdf_to_text
|
||||
from danswer.file_processing.extract_file_text import pptx_to_text
|
||||
from danswer.utils.batching import batch_generator
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
@@ -56,6 +58,10 @@ class GDriveMimeType(str, Enum):
|
||||
SPREADSHEET = "application/vnd.google-apps.spreadsheet"
|
||||
PDF = "application/pdf"
|
||||
WORD_DOC = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
|
||||
PPT = "application/vnd.google-apps.presentation"
|
||||
POWERPOINT = (
|
||||
"application/vnd.openxmlformats-officedocument.presentationml.presentation"
|
||||
)
|
||||
|
||||
|
||||
GoogleDriveFileType = dict[str, Any]
|
||||
@@ -324,6 +330,12 @@ def extract_text(file: dict[str, str], service: discovery.Resource) -> str:
|
||||
elif mime_type == GDriveMimeType.PDF.value:
|
||||
response = service.files().get_media(fileId=file["id"]).execute()
|
||||
return pdf_to_text(file=io.BytesIO(response))
|
||||
elif mime_type == GDriveMimeType.POWERPOINT.value:
|
||||
response = service.files().get_media(fileId=file["id"]).execute()
|
||||
return pptx_to_text(file=io.BytesIO(response))
|
||||
elif mime_type == GDriveMimeType.PPT.value:
|
||||
response = service.files().get_media(fileId=file["id"]).execute()
|
||||
return pptx_to_text(file=io.BytesIO(response))
|
||||
|
||||
return UNSUPPORTED_FILE_TYPE_CONTENT
|
||||
|
||||
@@ -346,7 +358,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector):
|
||||
self.follow_shortcuts = follow_shortcuts
|
||||
self.only_org_public = only_org_public
|
||||
self.continue_on_failure = continue_on_failure
|
||||
self.creds: Credentials | None = None
|
||||
self.creds: OAuthCredentials | ServiceAccountCredentials | None = None
|
||||
|
||||
@staticmethod
|
||||
def _process_folder_paths(
|
||||
@@ -387,7 +399,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector):
|
||||
(2) A credential which holds a service account key JSON file, which
|
||||
can then be used to impersonate any user in the workspace.
|
||||
"""
|
||||
creds = None
|
||||
creds: OAuthCredentials | ServiceAccountCredentials | None = None
|
||||
new_creds_dict = None
|
||||
if DB_CREDENTIALS_DICT_TOKEN_KEY in credentials:
|
||||
access_token_json_str = cast(
|
||||
@@ -416,7 +428,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector):
|
||||
str | None, credentials.get(DB_CREDENTIALS_DICT_DELEGATED_USER_KEY)
|
||||
)
|
||||
if delegated_user_email:
|
||||
creds = creds.with_subject(delegated_user_email) if creds else None
|
||||
creds = creds.with_subject(delegated_user_email) if creds else None # type: ignore
|
||||
|
||||
if creds is None:
|
||||
raise PermissionError(
|
||||
@@ -461,6 +473,11 @@ class GoogleDriveConnector(LoadConnector, PollConnector):
|
||||
doc_batch = []
|
||||
for file in files_batch:
|
||||
try:
|
||||
# Skip files that are shortcuts
|
||||
if file.get("mimeType") == DRIVE_SHORTCUT_TYPE:
|
||||
logger.info("Ignoring Drive Shortcut Filetype")
|
||||
continue
|
||||
|
||||
if self.only_org_public:
|
||||
if "permissions" not in file:
|
||||
continue
|
||||
|
||||
@@ -11,6 +11,7 @@ from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.app_configs import WEB_DOMAIN
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.google_drive.constants import CRED_KEY
|
||||
from danswer.connectors.google_drive.constants import (
|
||||
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY,
|
||||
@@ -118,6 +119,7 @@ def update_credential_access_tokens(
|
||||
|
||||
|
||||
def build_service_account_creds(
|
||||
source: DocumentSource,
|
||||
delegated_user_email: str | None = None,
|
||||
) -> CredentialBase:
|
||||
service_account_key = get_service_account_key()
|
||||
@@ -131,6 +133,7 @@ def build_service_account_creds(
|
||||
return CredentialBase(
|
||||
credential_json=credential_dict,
|
||||
admin_public=True,
|
||||
source=DocumentSource.GOOGLE_DRIVE,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -50,6 +50,12 @@ class PollConnector(BaseConnector):
        raise NotImplementedError


class IdConnector(BaseConnector):
    @abc.abstractmethod
    def retrieve_all_source_ids(self) -> set[str]:
        raise NotImplementedError


# Event driven
class EventConnector(BaseConnector):
    @abc.abstractmethod

backend/danswer/connectors/mediawiki/__init__.py (new file, 0 lines)
backend/danswer/connectors/mediawiki/family.py (new file, 166 lines)
@@ -0,0 +1,166 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import builtins
|
||||
import functools
|
||||
import itertools
|
||||
from typing import Any
|
||||
from unittest import mock
|
||||
from urllib.parse import urlparse
|
||||
from urllib.parse import urlunparse
|
||||
|
||||
from pywikibot import family # type: ignore[import-untyped]
|
||||
from pywikibot import pagegenerators # type: ignore[import-untyped]
|
||||
from pywikibot.scripts import generate_family_file # type: ignore[import-untyped]
|
||||
from pywikibot.scripts.generate_user_files import pywikibot # type: ignore[import-untyped]
|
||||
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@mock.patch.object(
|
||||
builtins, "print", lambda *args: logger.info("\t".join(map(str, args)))
|
||||
)
|
||||
class FamilyFileGeneratorInMemory(generate_family_file.FamilyFileGenerator):
|
||||
"""A subclass of FamilyFileGenerator that writes the family file to memory instead of to disk."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
url: str,
|
||||
name: str,
|
||||
dointerwiki: str | bool = True,
|
||||
verify: str | bool = True,
|
||||
):
|
||||
"""Initialize the FamilyFileGeneratorInMemory."""
|
||||
|
||||
url_parse = urlparse(url, "https")
|
||||
if not url_parse.netloc and url_parse.path:
|
||||
url = urlunparse(
|
||||
(url_parse.scheme, url_parse.path, url_parse.netloc, *url_parse[3:])
|
||||
)
|
||||
else:
|
||||
url = urlunparse(url_parse)
|
||||
assert isinstance(url, str)
|
||||
|
||||
if any(x not in generate_family_file.NAME_CHARACTERS for x in name):
|
||||
raise ValueError(
|
||||
'ERROR: Name of family "{}" must be ASCII letters and digits [a-zA-Z0-9]',
|
||||
name,
|
||||
)
|
||||
|
||||
if isinstance(dointerwiki, bool):
|
||||
dointerwiki = "Y" if dointerwiki else "N"
|
||||
assert isinstance(dointerwiki, str)
|
||||
|
||||
if isinstance(verify, bool):
|
||||
verify = "Y" if verify else "N"
|
||||
assert isinstance(verify, str)
|
||||
|
||||
super().__init__(url, name, dointerwiki, verify)
|
||||
self.family_definition: type[family.Family] | None = None
|
||||
|
||||
def get_params(self) -> bool:
|
||||
"""Get the parameters for the family class definition.
|
||||
|
||||
This override prevents the method from prompting the user for input (which would be impossible in this context).
|
||||
We do all the input validation in the constructor.
|
||||
"""
|
||||
return True
|
||||
|
||||
def writefile(self, verify: Any) -> None:
|
||||
"""Write the family file.
|
||||
|
||||
This overrides the method in the parent class to write the family definition to memory instead of to disk.
|
||||
|
||||
Args:
|
||||
verify: unused argument necessary to match the signature of the method in the parent class.
|
||||
"""
|
||||
code_hostname_pairs = {
|
||||
f"{k}": f"{urlparse(w.server).netloc}" for k, w in self.wikis.items()
|
||||
}
|
||||
|
||||
code_path_pairs = {f"{k}": f"{w.scriptpath}" for k, w in self.wikis.items()}
|
||||
|
||||
code_protocol_pairs = {
|
||||
f"{k}": f"{urlparse(w.server).scheme}" for k, w in self.wikis.items()
|
||||
}
|
||||
|
||||
class Family(family.Family): # noqa: D101
|
||||
"""The family definition for the wiki."""
|
||||
|
||||
name = "%(name)s"
|
||||
langs = code_hostname_pairs
|
||||
|
||||
def scriptpath(self, code: str) -> str:
|
||||
return code_path_pairs[code]
|
||||
|
||||
def protocol(self, code: str) -> str:
|
||||
return code_protocol_pairs[code]
|
||||
|
||||
self.family_definition = Family
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def generate_family_class(url: str, name: str) -> type[family.Family]:
|
||||
"""Generate a family file for a given URL and name.
|
||||
|
||||
Args:
|
||||
url: The URL of the wiki.
|
||||
name: The short name of the wiki (customizable by the user).
|
||||
|
||||
Returns:
|
||||
The family definition.
|
||||
|
||||
Raises:
|
||||
ValueError: If the family definition was not generated.
|
||||
"""
|
||||
|
||||
generator = FamilyFileGeneratorInMemory(url, name, "Y", "Y")
|
||||
generator.run()
|
||||
if generator.family_definition is None:
|
||||
raise ValueError("Family definition was not generated.")
|
||||
return generator.family_definition
|
||||
|
||||
|
||||
def family_class_dispatch(url: str, name: str) -> type[family.Family]:
|
||||
"""Find or generate a family class for a given URL and name.
|
||||
|
||||
Args:
|
||||
url: The URL of the wiki.
|
||||
name: The short name of the wiki (customizable by the user).
|
||||
|
||||
"""
|
||||
if "wikipedia" in url:
|
||||
import pywikibot.families.wikipedia_family # type: ignore[import-untyped]
|
||||
|
||||
return pywikibot.families.wikipedia_family.Family
|
||||
# TODO: Support additional families pre-defined in `pywikibot.families.*_family.py` files
|
||||
return generate_family_class(url, name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
url = "fallout.fandom.com/wiki/Fallout_Wiki"
|
||||
name = "falloutfandom"
|
||||
|
||||
categories: list[str] = []
|
||||
pages = ["Fallout: New Vegas"]
|
||||
recursion_depth = 1
|
||||
family_type = generate_family_class(url, name)
|
||||
|
||||
site = pywikibot.Site(fam=family_type(), code="en")
|
||||
categories = [
|
||||
pywikibot.Category(site, f"Category:{category.replace(' ', '_')}")
|
||||
for category in categories
|
||||
]
|
||||
pages = [pywikibot.Page(site, page) for page in pages]
|
||||
all_pages = itertools.chain(
|
||||
pages,
|
||||
*[
|
||||
pagegenerators.CategorizedPageGenerator(category, recurse=recursion_depth)
|
||||
for category in categories
|
||||
],
|
||||
)
|
||||
for page in all_pages:
|
||||
print(page.title())
|
||||
print(page.text[:1000])
|
||||
backend/danswer/connectors/mediawiki/wiki.py (new file, 220 lines)
@@ -0,0 +1,220 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import itertools
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from typing import ClassVar
|
||||
|
||||
import pywikibot.time # type: ignore[import-untyped]
|
||||
from pywikibot import pagegenerators # type: ignore[import-untyped]
|
||||
from pywikibot import textlib # type: ignore[import-untyped]
|
||||
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from danswer.connectors.mediawiki.family import family_class_dispatch
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
|
||||
|
||||
def pywikibot_timestamp_to_utc_datetime(
|
||||
timestamp: pywikibot.time.Timestamp,
|
||||
) -> datetime.datetime:
|
||||
"""Convert a pywikibot timestamp to a datetime object in UTC.
|
||||
|
||||
Args:
|
||||
timestamp: The pywikibot timestamp to convert.
|
||||
|
||||
Returns:
|
||||
A datetime object in UTC.
|
||||
"""
|
||||
return datetime.datetime.astimezone(timestamp, tz=datetime.timezone.utc)
|
||||
|
||||
|
||||
def get_doc_from_page(
|
||||
page: pywikibot.Page, site: pywikibot.Site | None, source_type: DocumentSource
|
||||
) -> Document:
|
||||
"""Generate Danswer Document from a MediaWiki page object.
|
||||
|
||||
Args:
|
||||
page: Page from a MediaWiki site.
|
||||
site: MediaWiki site (used to parse the sections of the page using the site template, if available).
|
||||
source_type: Source of the document.
|
||||
|
||||
Returns:
|
||||
Generated document.
|
||||
"""
|
||||
page_text = page.text
|
||||
sections_extracted: textlib.Content = textlib.extract_sections(page_text, site)
|
||||
|
||||
sections = [
|
||||
Section(
|
||||
link=f"{page.full_url()}#" + section.heading.replace(" ", "_"),
|
||||
text=section.title + section.content,
|
||||
)
|
||||
for section in sections_extracted.sections
|
||||
]
|
||||
sections.append(
|
||||
Section(
|
||||
link=page.full_url(),
|
||||
text=sections_extracted.header,
|
||||
)
|
||||
)
|
||||
|
||||
return Document(
|
||||
source=source_type,
|
||||
title=page.title(),
|
||||
doc_updated_at=pywikibot_timestamp_to_utc_datetime(
|
||||
page.latest_revision.timestamp
|
||||
),
|
||||
sections=sections,
|
||||
semantic_identifier=page.title(),
|
||||
metadata={"categories": [category.title() for category in page.categories()]},
|
||||
id=page.pageid,
|
||||
)
|
||||
|
||||
|
||||
class MediaWikiConnector(LoadConnector, PollConnector):
|
||||
"""A connector for MediaWiki wikis.
|
||||
|
||||
Args:
|
||||
hostname: The hostname of the wiki.
|
||||
categories: The categories to include in the index.
|
||||
pages: The pages to include in the index.
|
||||
recurse_depth: The depth to recurse into categories. -1 means unbounded recursion.
|
||||
language_code: The language code of the wiki.
|
||||
batch_size: The batch size for loading documents.
|
||||
|
||||
Raises:
|
||||
ValueError: If `recurse_depth` is not an integer greater than or equal to -1.
|
||||
"""
|
||||
|
||||
document_source_type: ClassVar[DocumentSource] = DocumentSource.MEDIAWIKI
|
||||
"""DocumentSource type for all documents generated by instances of this class. Can be overridden for connectors
|
||||
tailored for specific sites."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hostname: str,
|
||||
categories: list[str],
|
||||
pages: list[str],
|
||||
recurse_depth: int,
|
||||
language_code: str = "en",
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
) -> None:
|
||||
if recurse_depth < -1:
|
||||
raise ValueError(
|
||||
f"recurse_depth must be an integer greater than or equal to -1. Got {recurse_depth} instead."
|
||||
)
|
||||
# -1 means infinite recursion, which `pywikibot` will only do with `True`
|
||||
self.recurse_depth: bool | int = True if recurse_depth == -1 else recurse_depth
|
||||
|
||||
self.batch_size = batch_size
|
||||
|
||||
# short names can only have ascii letters and digits
|
||||
|
||||
self.family = family_class_dispatch(hostname, "Wikipedia Connector")()
|
||||
self.site = pywikibot.Site(fam=self.family, code=language_code)
|
||||
self.categories = [
|
||||
pywikibot.Category(self.site, f"Category:{category.replace(' ', '_')}")
|
||||
for category in categories
|
||||
]
|
||||
self.pages = [pywikibot.Page(self.site, page) for page in pages]
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""Load credentials for a MediaWiki site.
|
||||
|
||||
Note:
|
||||
For most read-only operations, MediaWiki API credentials are not necessary.
|
||||
This method can be overridden in the event that a particular MediaWiki site
|
||||
requires credentials.
|
||||
"""
|
||||
return None
|
||||
|
||||
def _get_doc_batch(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> Generator[list[Document], None, None]:
|
||||
"""Request batches of pages from a MediaWiki site.
|
||||
|
||||
Args:
|
||||
start: The beginning of the time period of pages to request.
|
||||
end: The end of the time period of pages to request.
|
||||
|
||||
Yields:
|
||||
Lists of Documents containing each parsed page in a batch.
|
||||
"""
|
||||
doc_batch: list[Document] = []
|
||||
|
||||
# Pywikibot can handle batching for us, including only loading page contents when we finally request them.
|
||||
category_pages = [
|
||||
pagegenerators.PreloadingGenerator(
|
||||
pagegenerators.EdittimeFilterPageGenerator(
|
||||
pagegenerators.CategorizedPageGenerator(
|
||||
category, recurse=self.recurse_depth
|
||||
),
|
||||
last_edit_start=datetime.datetime.fromtimestamp(start)
|
||||
if start
|
||||
else None,
|
||||
last_edit_end=datetime.datetime.fromtimestamp(end) if end else None,
|
||||
),
|
||||
groupsize=self.batch_size,
|
||||
)
|
||||
for category in self.categories
|
||||
]
|
||||
|
||||
# Since we can specify both individual pages and categories, we need to iterate over all of them.
|
||||
all_pages = itertools.chain(self.pages, *category_pages)
|
||||
for page in all_pages:
|
||||
doc_batch.append(
|
||||
get_doc_from_page(page, self.site, self.document_source_type)
|
||||
)
|
||||
if len(doc_batch) >= self.batch_size:
|
||||
yield doc_batch
|
||||
doc_batch = []
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
"""Load all documents from the source.
|
||||
|
||||
Returns:
|
||||
A generator of documents.
|
||||
"""
|
||||
return self.poll_source(None, None)
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None
|
||||
) -> GenerateDocumentsOutput:
|
||||
"""Poll the source for new documents.
|
||||
|
||||
Args:
|
||||
start: The start of the time range to poll.
|
||||
end: The end of the time range to poll.
|
||||
|
||||
Returns:
|
||||
A generator of documents.
|
||||
"""
|
||||
return self._get_doc_batch(start, end)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
HOSTNAME = "fallout.fandom.com"
|
||||
test_connector = MediaWikiConnector(
|
||||
hostname=HOSTNAME,
|
||||
categories=["Fallout:_New_Vegas_factions"],
|
||||
pages=["Fallout: New Vegas"],
|
||||
recurse_depth=1,
|
||||
)
|
||||
|
||||
all_docs = list(test_connector.load_from_state())
|
||||
print("All docs", all_docs)
|
||||
current = datetime.datetime.now().timestamp()
|
||||
one_day_ago = current - 30 * 24 * 60 * 60 # 30 days
|
||||
latest_docs = list(test_connector.poll_source(one_day_ago, current))
|
||||
print("Latest docs", latest_docs)
|
||||
@@ -6,6 +6,7 @@ from pydantic import BaseModel

from danswer.configs.constants import DocumentSource
from danswer.configs.constants import INDEX_SEPARATOR
from danswer.configs.constants import RETURN_SEPARATOR
from danswer.utils.text_processing import make_url_compatible


@@ -13,6 +14,7 @@ class InputType(str, Enum):
    LOAD_STATE = "load_state" # e.g. loading a current full state or a save state, such as from a file
    POLL = "poll" # e.g. calling an API to get all documents in the last hour
    EVENT = "event" # e.g. registered an endpoint as a listener, and processing connector events
    PRUNE = "prune"


class ConnectorMissingCredentialError(PermissionError):
@@ -112,11 +114,18 @@ class DocumentBase(BaseModel):
    title: str | None = None
    from_ingestion_api: bool = False

    def get_title_for_document_index(self) -> str | None:
    def get_title_for_document_index(
        self,
    ) -> str | None:
        # If title is explicitly empty, return a None here for embedding purposes
        if self.title == "":
            return None
        return self.semantic_identifier if self.title is None else self.title
        replace_chars = set(RETURN_SEPARATOR)
        title = self.semantic_identifier if self.title is None else self.title
        for char in replace_chars:
            title = title.replace(char, " ")
        title = title.strip()
        return title

    def get_metadata_str_attributes(self) -> list[str] | None:
        if not self.metadata:
@@ -368,7 +368,7 @@ class NotionConnector(LoadConnector, PollConnector):
            compare_time = time.mktime(
                time.strptime(page[filter_field], "%Y-%m-%dT%H:%M:%S.000Z")
            )
            if compare_time <= end or compare_time > start:
            if compare_time > start and compare_time <= end:
                filtered_pages += [NotionPage(**page)]
        return filtered_pages


@@ -207,7 +207,7 @@ class ProductboardConnector(PollConnector):
        ):
            return True
        else:
            logger.error(f"Unable to find updated_at for document '{document.id}'")
            logger.debug(f"Unable to find updated_at for document '{document.id}'")

        return False

backend/danswer/connectors/salesforce/__init__.py (new file, 0 lines)
backend/danswer/connectors/salesforce/connector.py (new file, 274 lines)
@@ -0,0 +1,274 @@
import os
from collections.abc import Iterator
from datetime import datetime
from datetime import timezone
from typing import Any

from simple_salesforce import Salesforce
from simple_salesforce import SFType

from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import IdConnector
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.models import BasicExpertInfo
from danswer.connectors.models import ConnectorMissingCredentialError
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.connectors.salesforce.utils import extract_dict_text
from danswer.utils.logger import setup_logger


DEFAULT_PARENT_OBJECT_TYPES = ["Account"]
MAX_QUERY_LENGTH = 10000  # max query length is 20,000 characters
ID_PREFIX = "SALESFORCE_"

logger = setup_logger()


class SalesforceConnector(LoadConnector, PollConnector, IdConnector):
    def __init__(
        self,
        batch_size: int = INDEX_BATCH_SIZE,
        requested_objects: list[str] = [],
    ) -> None:
        self.batch_size = batch_size
        self.sf_client: Salesforce | None = None
        self.parent_object_list = (
            [obj.capitalize() for obj in requested_objects]
            if requested_objects
            else DEFAULT_PARENT_OBJECT_TYPES
        )

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
        self.sf_client = Salesforce(
            username=credentials["sf_username"],
            password=credentials["sf_password"],
            security_token=credentials["sf_security_token"],
        )

        return None

    def _get_sf_type_object_json(self, type_name: str) -> Any:
        if self.sf_client is None:
            raise ConnectorMissingCredentialError("Salesforce")
        sf_object = SFType(
            type_name, self.sf_client.session_id, self.sf_client.sf_instance
        )
        return sf_object.describe()

    def _get_name_from_id(self, id: str) -> str:
        if self.sf_client is None:
            raise ConnectorMissingCredentialError("Salesforce")
        try:
            user_object_info = self.sf_client.query(
                f"SELECT Name FROM User WHERE Id = '{id}'"
            )
            name = user_object_info.get("Records", [{}])[0].get("Name", "Null User")
            return name
        except Exception:
            logger.warning(f"Couldnt find name for object id: {id}")
            return "Null User"

    def _convert_object_instance_to_document(
        self, object_dict: dict[str, Any]
    ) -> Document:
        if self.sf_client is None:
            raise ConnectorMissingCredentialError("Salesforce")

        salesforce_id = object_dict["Id"]
        danswer_salesforce_id = f"{ID_PREFIX}{salesforce_id}"
        extracted_link = f"https://{self.sf_client.sf_instance}/{salesforce_id}"
        extracted_doc_updated_at = time_str_to_utc(object_dict["LastModifiedDate"])
        extracted_object_text = extract_dict_text(object_dict)
        extracted_semantic_identifier = object_dict.get("Name", "Unknown Object")
        extracted_primary_owners = [
            BasicExpertInfo(
                display_name=self._get_name_from_id(object_dict["LastModifiedById"])
            )
        ]

        doc = Document(
            id=danswer_salesforce_id,
            sections=[Section(link=extracted_link, text=extracted_object_text)],
            source=DocumentSource.SALESFORCE,
            semantic_identifier=extracted_semantic_identifier,
            doc_updated_at=extracted_doc_updated_at,
            primary_owners=extracted_primary_owners,
            metadata={},
        )
        return doc

    def _is_valid_child_object(self, child_relationship: dict) -> bool:
        if self.sf_client is None:
            raise ConnectorMissingCredentialError("Salesforce")

        if not child_relationship["childSObject"]:
            return False
        if not child_relationship["relationshipName"]:
            return False

        sf_type = child_relationship["childSObject"]
        object_description = self._get_sf_type_object_json(sf_type)
        if not object_description["queryable"]:
            return False

        try:
            query = f"SELECT Count() FROM {sf_type} LIMIT 1"
            result = self.sf_client.query(query)
            if result["totalSize"] == 0:
                return False
        except Exception as e:
            logger.warning(f"Object type {sf_type} doesn't support query: {e}")
            return False

        if child_relationship["field"]:
            if child_relationship["field"] == "RelatedToId":
                return False
        else:
            return False

        return True

    def _get_all_children_of_sf_type(self, sf_type: str) -> list[dict]:
        if self.sf_client is None:
            raise ConnectorMissingCredentialError("Salesforce")

        object_description = self._get_sf_type_object_json(sf_type)

        children_objects: list[dict] = []
        for child_relationship in object_description["childRelationships"]:
            if self._is_valid_child_object(child_relationship):
                children_objects.append(
                    {
                        "relationship_name": child_relationship["relationshipName"],
                        "object_type": child_relationship["childSObject"],
                    }
                )
        return children_objects

    def _get_all_fields_for_sf_type(self, sf_type: str) -> list[str]:
        if self.sf_client is None:
            raise ConnectorMissingCredentialError("Salesforce")

        object_description = self._get_sf_type_object_json(sf_type)

        fields = [
            field.get("name")
            for field in object_description["fields"]
            if field.get("type", "base64") != "base64"
        ]

        return fields

    def _generate_query_per_parent_type(self, parent_sf_type: str) -> Iterator[str]:
        """
        This function takes in an object_type and generates query(s) designed to grab
        information associated to objects of that type.
        It does that by getting all the fields of the parent object type.
        Then it gets all the child objects of that object type and all the fields of
        those children as well.
        """
        parent_fields = self._get_all_fields_for_sf_type(parent_sf_type)
        child_sf_types = self._get_all_children_of_sf_type(parent_sf_type)

        query = f"SELECT {', '.join(parent_fields)}"
        for child_object_dict in child_sf_types:
            fields = self._get_all_fields_for_sf_type(child_object_dict["object_type"])
            query_addition = f", \n(SELECT {', '.join(fields)} FROM {child_object_dict['relationship_name']})"

            if len(query_addition) + len(query) > MAX_QUERY_LENGTH:
                query += f"\n FROM {parent_sf_type}"
                yield query
                query = "SELECT Id" + query_addition
            else:
                query += query_addition

        query += f"\n FROM {parent_sf_type}"

        yield query

    def _fetch_from_salesforce(
        self,
        start: datetime | None = None,
        end: datetime | None = None,
    ) -> GenerateDocumentsOutput:
        if self.sf_client is None:
            raise ConnectorMissingCredentialError("Salesforce")

        doc_batch: list[Document] = []
        for parent_object_type in self.parent_object_list:
            logger.debug(f"Processing: {parent_object_type}")

            query_results: dict = {}
            for query in self._generate_query_per_parent_type(parent_object_type):
                if start is not None and end is not None:
                    if start and start.tzinfo is None:
                        start = start.replace(tzinfo=timezone.utc)
                    if end and end.tzinfo is None:
                        end = end.replace(tzinfo=timezone.utc)
                    query += f" WHERE LastModifiedDate > {start.isoformat()} AND LastModifiedDate < {end.isoformat()}"

                query_result = self.sf_client.query_all(query)

                for record_dict in query_result["records"]:
                    query_results.setdefault(record_dict["Id"], {}).update(record_dict)

            logger.info(
                f"Number of {parent_object_type} Objects processed: {len(query_results)}"
            )

            for combined_object_dict in query_results.values():
                doc_batch.append(
                    self._convert_object_instance_to_document(combined_object_dict)
                )

                if len(doc_batch) > self.batch_size:
                    yield doc_batch
                    doc_batch = []

        yield doc_batch

    def load_from_state(self) -> GenerateDocumentsOutput:
        return self._fetch_from_salesforce()

    def poll_source(
        self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
    ) -> GenerateDocumentsOutput:
        if self.sf_client is None:
            raise ConnectorMissingCredentialError("Salesforce")
        start_datetime = datetime.utcfromtimestamp(start)
        end_datetime = datetime.utcfromtimestamp(end)
        return self._fetch_from_salesforce(start=start_datetime, end=end_datetime)

    def retrieve_all_source_ids(self) -> set[str]:
        if self.sf_client is None:
            raise ConnectorMissingCredentialError("Salesforce")
        all_retrieved_ids: set[str] = set()
        for parent_object_type in self.parent_object_list:
            query = f"SELECT Id FROM {parent_object_type}"
            query_result = self.sf_client.query_all(query)
            all_retrieved_ids.update(
                f"{ID_PREFIX}{instance_dict.get('Id', '')}"
                for instance_dict in query_result["records"]
            )

        return all_retrieved_ids


if __name__ == "__main__":
    connector = SalesforceConnector(
        requested_objects=os.environ["REQUESTED_OBJECTS"].split(",")
    )

    connector.load_credentials(
        {
            "sf_username": os.environ["SF_USERNAME"],
            "sf_password": os.environ["SF_PASSWORD"],
            "sf_security_token": os.environ["SF_SECURITY_TOKEN"],
        }
    )
    document_batches = connector.load_from_state()
    print(next(document_batches))
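The __main__ block above exercises load_from_state; the polling path can be driven the same way. A hedged usage sketch (not part of the diff): the requested object type and the 24-hour window are arbitrary examples, and credentials are read from the same environment variables used above.

import os
import time

from danswer.connectors.salesforce.connector import SalesforceConnector

# Illustrative: poll the last 24 hours of Account changes.
connector = SalesforceConnector(requested_objects=["Account"])
connector.load_credentials(
    {
        "sf_username": os.environ["SF_USERNAME"],
        "sf_password": os.environ["SF_PASSWORD"],
        "sf_security_token": os.environ["SF_SECURITY_TOKEN"],
    }
)

now = time.time()
for doc_batch in connector.poll_source(start=now - 86400, end=now):
    print(f"Fetched {len(doc_batch)} Salesforce documents")

One inconsistency worth noting: _get_name_from_id reads the query result from a capitalised "Records" key, while the rest of the file reads results from the lowercase "records" key that simple_salesforce returns, so the owner lookup always falls back to "Null User" as written.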
backend/danswer/connectors/salesforce/utils.py (new file, 66 lines)
@@ -0,0 +1,66 @@
import re
from typing import Union

SF_JSON_FILTER = r"Id$|Date$|stamp$|url$"


def _clean_salesforce_dict(data: Union[dict, list]) -> Union[dict, list]:
    if isinstance(data, dict):
        if "records" in data.keys():
            data = data["records"]
    if isinstance(data, dict):
        if "attributes" in data.keys():
            if isinstance(data["attributes"], dict):
                data.update(data.pop("attributes"))

    if isinstance(data, dict):
        filtered_dict = {}
        for key, value in data.items():
            if not re.search(SF_JSON_FILTER, key, re.IGNORECASE):
                if "__c" in key:  # remove the custom object indicator for display
                    key = key[:-3]
                if isinstance(value, (dict, list)):
                    filtered_value = _clean_salesforce_dict(value)
                    if filtered_value:  # Only add non-empty dictionaries or lists
                        filtered_dict[key] = filtered_value
                elif value is not None:
                    filtered_dict[key] = value
        return filtered_dict
    elif isinstance(data, list):
        filtered_list = []
        for item in data:
            if isinstance(item, (dict, list)):
                filtered_item = _clean_salesforce_dict(item)
                if filtered_item:  # Only add non-empty dictionaries or lists
                    filtered_list.append(filtered_item)
            elif item is not None:
                filtered_list.append(filtered_item)
        return filtered_list
    else:
        return data


def _json_to_natural_language(data: Union[dict, list], indent: int = 0) -> str:
    result = []
    indent_str = " " * indent

    if isinstance(data, dict):
        for key, value in data.items():
            if isinstance(value, (dict, list)):
                result.append(f"{indent_str}{key}:")
                result.append(_json_to_natural_language(value, indent + 2))
            else:
                result.append(f"{indent_str}{key}: {value}")
    elif isinstance(data, list):
        for item in data:
            result.append(_json_to_natural_language(item, indent))
    else:
        result.append(f"{indent_str}{data}")

    return "\n".join(result)


def extract_dict_text(raw_dict: dict) -> str:
    processed_dict = _clean_salesforce_dict(raw_dict)
    natural_language_dict = _json_to_natural_language(processed_dict)
    return natural_language_dict
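To make the helper concrete, here is a rough input/output example for extract_dict_text; the record fields and values are invented, but the behaviour follows the functions above (keys matching Id/Date/stamp/url are dropped, "__c" suffixes are trimmed, and nested "records"/"attributes" wrappers are flattened).

from danswer.connectors.salesforce.utils import extract_dict_text

# Illustrative only: the field names and values below are made up.
raw = {
    "attributes": {"type": "Account"},
    "Id": "001000000000001",                # dropped by the Id$ filter
    "Name": "Acme Corp",
    "Website__c": "https://acme.example",   # "__c" suffix trimmed to "Website"
    "Cases": {
        "records": [
            {"Subject": "Login issue", "Status": "Closed", "CaseNumber": "00042"}
        ]
    },
}

print(extract_dict_text(raw))
# Name: Acme Corp
# Website: https://acme.example
# Cases:
#   Subject: Login issue
#   Status: Closed
#   CaseNumber: 00042
# type: Account

Also note that in the list branch of _clean_salesforce_dict, the "elif item is not None" case appends filtered_item, which is only assigned in the dict/list branch above it, so plain scalar list items either raise a NameError or reuse a stale value as written.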
Some files were not shown because too many files have changed in this diff.