mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-16 23:35:46 +00:00
Compare commits
656 Commits
v0.9.2
...
cloud_debu
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
09e6bd3c9c | ||
|
|
c1803cdd56 | ||
|
|
a5b9c76012 | ||
|
|
e9b10e8b41 | ||
|
|
a0fa4adb60 | ||
|
|
ca9ba925bd | ||
|
|
833cc5c97c | ||
|
|
23ecf654ed | ||
|
|
ddc6a6d2b3 | ||
|
|
571c8ece32 | ||
|
|
884bdb4b01 | ||
|
|
b3ecf0d59f | ||
|
|
f56fda27c9 | ||
|
|
b1e4d4ea8d | ||
|
|
8db6d49fe5 | ||
|
|
28598694b1 | ||
|
|
b5d0df90b9 | ||
|
|
48be6338ec | ||
|
|
ed9014f03d | ||
|
|
2dd51230ed | ||
|
|
8b249cbe63 | ||
|
|
6b50f86cd2 | ||
|
|
bd2805b6df | ||
|
|
2847ab003e | ||
|
|
1df6a506ec | ||
|
|
f1541d1fbe | ||
|
|
dd0c4b64df | ||
|
|
788b3015bc | ||
|
|
cbbf10f450 | ||
|
|
d954914a0a | ||
|
|
bee74ac360 | ||
|
|
29ef64272a | ||
|
|
01bf6ee4b7 | ||
|
|
0502417cbe | ||
|
|
d0483dd269 | ||
|
|
eefa872d60 | ||
|
|
3f3d4da611 | ||
|
|
469068052e | ||
|
|
9032b05606 | ||
|
|
334bc6be8c | ||
|
|
814f97c2c7 | ||
|
|
4f5a2b47c4 | ||
|
|
f545508268 | ||
|
|
590986ec65 | ||
|
|
531bab5409 | ||
|
|
29c44007c4 | ||
|
|
d388643a04 | ||
|
|
8a422683e3 | ||
|
|
ddc0230d68 | ||
|
|
6711e91dbf | ||
|
|
cff2346db5 | ||
|
|
8d3fad1f12 | ||
|
|
0c3dab8e8d | ||
|
|
47735e2044 | ||
|
|
1eeab8c773 | ||
|
|
e9b41bddc9 | ||
|
|
73a86b9019 | ||
|
|
12c426c87b | ||
|
|
06aeab6d59 | ||
|
|
9b7e67004c | ||
|
|
626ce74aa3 | ||
|
|
cec63465eb | ||
|
|
5f4b31d322 | ||
|
|
ab5e515a5a | ||
|
|
699a02902a | ||
|
|
c85157f734 | ||
|
|
824844bf84 | ||
|
|
a6ab8a8da4 | ||
|
|
40719eb542 | ||
|
|
e8c72f9e82 | ||
|
|
0ba77963c4 | ||
|
|
86f2892349 | ||
|
|
64f0ad8b26 | ||
|
|
616e997dad | ||
|
|
614bd378bb | ||
|
|
7064c3d06f | ||
|
|
3bb9e4bff6 | ||
|
|
3fec7a6a30 | ||
|
|
a01a9b9a99 | ||
|
|
21ec5ed795 | ||
|
|
54dcbfa288 | ||
|
|
c69b7fc941 | ||
|
|
6722e88a7b | ||
|
|
5b5e1eb7c7 | ||
|
|
87d97d13d5 | ||
|
|
4ae3b48938 | ||
|
|
dee1a0ecd7 | ||
|
|
ca172f3306 | ||
|
|
e5d0587efa | ||
|
|
a9516202fe | ||
|
|
d23fca96c4 | ||
|
|
a45724c899 | ||
|
|
34e250407a | ||
|
|
046c0fbe3e | ||
|
|
76595facef | ||
|
|
af2d548766 | ||
|
|
7c29b1e028 | ||
|
|
a52c821e78 | ||
|
|
0770a587f1 | ||
|
|
748b79b0ef | ||
|
|
9cacb373ef | ||
|
|
21967d4b6f | ||
|
|
f5d638161b | ||
|
|
0b5013b47d | ||
|
|
1b846fbf06 | ||
|
|
cae8a131a2 | ||
|
|
72b4e8e9fe | ||
|
|
c04e2f14d9 | ||
|
|
b40a12d5d7 | ||
|
|
5e7d454ebe | ||
|
|
238509c536 | ||
|
|
d7f8cf8f18 | ||
|
|
5d810d373e | ||
|
|
9455576078 | ||
|
|
71421bb782 | ||
|
|
b88cb388b7 | ||
|
|
639986001f | ||
|
|
e7a7e78969 | ||
|
|
e255ff7d23 | ||
|
|
1be2502112 | ||
|
|
f2bedb8fdd | ||
|
|
637404f482 | ||
|
|
daae146920 | ||
|
|
d95959fb41 | ||
|
|
c667d28e7a | ||
|
|
9e0b482f47 | ||
|
|
fa84eb657f | ||
|
|
264df3441b | ||
|
|
b9bad8b7a0 | ||
|
|
600ebb6432 | ||
|
|
09fe8ea868 | ||
|
|
ad6be03b4d | ||
|
|
65d2511216 | ||
|
|
113bf19c65 | ||
|
|
6026536110 | ||
|
|
056b671cd4 | ||
|
|
8d83ae2ee8 | ||
|
|
ca988f5c5f | ||
|
|
4e4214b82c | ||
|
|
fe83f676df | ||
|
|
6d6e12119b | ||
|
|
1f2b7cb9c8 | ||
|
|
878a189011 | ||
|
|
48c10271c2 | ||
|
|
c6a79d847e | ||
|
|
1bc3f8b96f | ||
|
|
7f6a6944d6 | ||
|
|
06f4146597 | ||
|
|
7ea73d5a5a | ||
|
|
30dfe6dcb4 | ||
|
|
dc5d5dfe05 | ||
|
|
0746e0be5b | ||
|
|
970320bd49 | ||
|
|
4a7bd5578e | ||
|
|
874b098a4b | ||
|
|
ce18b63eea | ||
|
|
7a919c3589 | ||
|
|
631bac4432 | ||
|
|
53428f6e9c | ||
|
|
53b3dcbace | ||
|
|
7a3c06c2d2 | ||
|
|
7a0d823c89 | ||
|
|
db69e445d6 | ||
|
|
18e63889b7 | ||
|
|
738e60c8ed | ||
|
|
8aec873e66 | ||
|
|
7c57dde8ab | ||
|
|
f30adab853 | ||
|
|
601687a522 | ||
|
|
350cf407c9 | ||
|
|
32ec4efc7a | ||
|
|
7c6981e052 | ||
|
|
c50cd20156 | ||
|
|
14772dee71 | ||
|
|
c81e704c95 | ||
|
|
3266ef6321 | ||
|
|
c89b98b4f2 | ||
|
|
e70e0ab859 | ||
|
|
69b6e9321e | ||
|
|
7e53af18b6 | ||
|
|
b9eb1ca2ba | ||
|
|
91d44c83d2 | ||
|
|
4dbc6bb4d1 | ||
|
|
4b6a4c6bbf | ||
|
|
fd1999454a | ||
|
|
0a35422d1d | ||
|
|
69b99056b2 | ||
|
|
2a55696545 | ||
|
|
ef9942b751 | ||
|
|
993acec5e9 | ||
|
|
b01a1b509a | ||
|
|
4f994124ef | ||
|
|
14863bd457 | ||
|
|
aa1c4c635a | ||
|
|
13f6e8a6b4 | ||
|
|
66f47d294c | ||
|
|
0a685bda7d | ||
|
|
23dc8b5dad | ||
|
|
cd5f2293ad | ||
|
|
6c2269e565 | ||
|
|
46315cddf1 | ||
|
|
5f28a1b0e4 | ||
|
|
9e9b7ed61d | ||
|
|
3fb2bfefec | ||
|
|
7c618c9d17 | ||
|
|
03e2789392 | ||
|
|
2783fa08a3 | ||
|
|
edeaee93a2 | ||
|
|
5385bae100 | ||
|
|
813445ab59 | ||
|
|
af814823c8 | ||
|
|
607f61eaeb | ||
|
|
de66f7adb2 | ||
|
|
3432d932d1 | ||
|
|
9bd0cb9eb5 | ||
|
|
f12eb4a5cf | ||
|
|
16863de0aa | ||
|
|
63d1eefee5 | ||
|
|
e338677896 | ||
|
|
7be80c4af9 | ||
|
|
7f1e4a02bf | ||
|
|
5be7d27285 | ||
|
|
fd84b7a768 | ||
|
|
36941ae663 | ||
|
|
212353ed4a | ||
|
|
eb8708f770 | ||
|
|
ac448956e9 | ||
|
|
634a0b9398 | ||
|
|
09d3e47c03 | ||
|
|
9c0cc94f15 | ||
|
|
07dfde2209 | ||
|
|
28e2b78b2e | ||
|
|
0553062ac6 | ||
|
|
284e375ba3 | ||
|
|
1f2f7d0ac2 | ||
|
|
2ecc28b57d | ||
|
|
77cf9b3539 | ||
|
|
076ce2ebd0 | ||
|
|
b625ee32a7 | ||
|
|
c32b93fcc3 | ||
|
|
1c8476072e | ||
|
|
7573416ca1 | ||
|
|
86d8666481 | ||
|
|
8abcde91d4 | ||
|
|
3466451d51 | ||
|
|
413891f143 | ||
|
|
7a0a4d4b79 | ||
|
|
a3439605a5 | ||
|
|
694e79f5e1 | ||
|
|
5dfafc8612 | ||
|
|
62a4aa10db | ||
|
|
a357cdc4c9 | ||
|
|
84615abfdd | ||
|
|
8ae6b1960b | ||
|
|
d9b87bbbc2 | ||
|
|
a0065b01af | ||
|
|
c5306148a3 | ||
|
|
1e17934de4 | ||
|
|
93add96ccc | ||
|
|
3a466a4b08 | ||
|
|
85cbd9caed | ||
|
|
9dc23bf3e7 | ||
|
|
e32809f7ca | ||
|
|
3e58f9f8ab | ||
|
|
2381c8d498 | ||
|
|
c6dadb24dc | ||
|
|
5dc07d4178 | ||
|
|
129c8f8faf | ||
|
|
67bfcabbc5 | ||
|
|
9819aa977a | ||
|
|
8d5b8a4028 | ||
|
|
682319d2e9 | ||
|
|
fe1400aa36 | ||
|
|
e3573b2bc1 | ||
|
|
35b5c44cc7 | ||
|
|
5eddc89b5a | ||
|
|
9a492ceb6d | ||
|
|
3c54ae9de9 | ||
|
|
13f08f3ebb | ||
|
|
bd9f15854f | ||
|
|
366aa2a8ea | ||
|
|
deee237c7e | ||
|
|
100b4a0d16 | ||
|
|
70207b4b39 | ||
|
|
50826b6bef | ||
|
|
3f648cbc31 | ||
|
|
c875a4774f | ||
|
|
049091eb01 | ||
|
|
3dac24542b | ||
|
|
194dcb593d | ||
|
|
bf291d0c0a | ||
|
|
8309f4a802 | ||
|
|
0ff2565125 | ||
|
|
e89dcd7f84 | ||
|
|
645e7e828e | ||
|
|
2a54f14195 | ||
|
|
9209fc804b | ||
|
|
b712877701 | ||
|
|
e6df32dcc3 | ||
|
|
eb81258a23 | ||
|
|
487ef4acc0 | ||
|
|
9b7cc83eae | ||
|
|
ce3124f9e4 | ||
|
|
e69303e309 | ||
|
|
6e698ac84a | ||
|
|
d69180aeb8 | ||
|
|
aa37051be9 | ||
|
|
a7d95661b3 | ||
|
|
33ee899408 | ||
|
|
954b5b2a56 | ||
|
|
521425a4f2 | ||
|
|
618bc02d54 | ||
|
|
b7de74fdf8 | ||
|
|
6e83fe3a39 | ||
|
|
259fc049b7 | ||
|
|
7015e6f2ab | ||
|
|
24be13c015 | ||
|
|
ddff7ecc3f | ||
|
|
97932dc44b | ||
|
|
637b6d9e75 | ||
|
|
54dc1ac917 | ||
|
|
21d5cc43f8 | ||
|
|
7c841051ed | ||
|
|
6e91964924 | ||
|
|
facf1d55a0 | ||
|
|
d68f8d6fbc | ||
|
|
65a205d488 | ||
|
|
485f3f72fa | ||
|
|
dcbea883ae | ||
|
|
a50a3944b3 | ||
|
|
60471b6a73 | ||
|
|
d703e694ce | ||
|
|
6066042fef | ||
|
|
eb0e20b9e4 | ||
|
|
490a68773b | ||
|
|
227aff1e47 | ||
|
|
6e29d1944c | ||
|
|
22189f02c6 | ||
|
|
fdc4811fce | ||
|
|
021d0cf314 | ||
|
|
942e47db29 | ||
|
|
f4a020b599 | ||
|
|
5166649eae | ||
|
|
ba805f766f | ||
|
|
9d57f34c34 | ||
|
|
cc2f584321 | ||
|
|
a1b95df3b8 | ||
|
|
9272d6ebfe | ||
|
|
4fb65dcf73 | ||
|
|
2bbc5d5d07 | ||
|
|
950b1c38f2 | ||
|
|
99fbfba32f | ||
|
|
0a59efe64a | ||
|
|
cf5d394d39 | ||
|
|
f6d8f5ca89 | ||
|
|
1fb4cdfcc3 | ||
|
|
ac51469bcb | ||
|
|
c25f164e28 | ||
|
|
813720905b | ||
|
|
0c45488ac6 | ||
|
|
95d9b33c1a | ||
|
|
55919f596c | ||
|
|
1d0fb6d012 | ||
|
|
2b1dbde829 | ||
|
|
2758ffd9d5 | ||
|
|
07a1b49b4f | ||
|
|
43d8daa5bc | ||
|
|
faeb9f09f0 | ||
|
|
25f5c12750 | ||
|
|
2d81710ccc | ||
|
|
187a7d2da2 | ||
|
|
4b152aa3a7 | ||
|
|
06f937cf93 | ||
|
|
5a24ed2947 | ||
|
|
2372e6a5a5 | ||
|
|
3eef4e3992 | ||
|
|
467ce4e3f3 | ||
|
|
ee4b334a0a | ||
|
|
4087292001 | ||
|
|
da6ed5b2b3 | ||
|
|
864ac2ac5c | ||
|
|
12cb77c80e | ||
|
|
583cd14bf4 | ||
|
|
001fcb3359 | ||
|
|
7ff18e0a93 | ||
|
|
9ac256e925 | ||
|
|
08600db41d | ||
|
|
6bf06ac7f7 | ||
|
|
5b06b53a3e | ||
|
|
afce57b29f | ||
|
|
257dbecd1d | ||
|
|
bd6baf39c3 | ||
|
|
b2c55ebd71 | ||
|
|
dea7a8f697 | ||
|
|
ddae2346ec | ||
|
|
9032fb4467 | ||
|
|
b6ecbbcf45 | ||
|
|
1d8e662b79 | ||
|
|
2cb33b1fb4 | ||
|
|
2cd1e6be00 | ||
|
|
8e55566f66 | ||
|
|
bafb95d920 | ||
|
|
c6e8bf2d28 | ||
|
|
c2d04f591d | ||
|
|
56c3a5ff5b | ||
|
|
fac2b100a1 | ||
|
|
51b79f688a | ||
|
|
a7002dfa1d | ||
|
|
93d0104d3c | ||
|
|
46e5ffa3ae | ||
|
|
d4f38bba8b | ||
|
|
19d6b63fd3 | ||
|
|
938d5788b6 | ||
|
|
70f703cc0f | ||
|
|
8bcf80aa76 | ||
|
|
5f5cc9a724 | ||
|
|
e4bb14d4e1 | ||
|
|
5d9b8364ab | ||
|
|
83c299ebc8 | ||
|
|
6b4143cc30 | ||
|
|
6e8c88ed71 | ||
|
|
d652cb3141 | ||
|
|
5e444d43f9 | ||
|
|
2e49027beb | ||
|
|
d7bcd32d9a | ||
|
|
4a6b8db65f | ||
|
|
6f440d126a | ||
|
|
013292a0e3 | ||
|
|
a1ae22ef4a | ||
|
|
40beda30a4 | ||
|
|
d3062cacea | ||
|
|
678ed23853 | ||
|
|
ea2da63cf2 | ||
|
|
4fc8a35220 | ||
|
|
f981106111 | ||
|
|
5439c33313 | ||
|
|
5e050f8305 | ||
|
|
12c82de78f | ||
|
|
645402c71a | ||
|
|
772313236f | ||
|
|
ecf4923a3a | ||
|
|
d66b81a902 | ||
|
|
753293cefb | ||
|
|
6d543f3d4f | ||
|
|
ccdc09e2d4 | ||
|
|
4a23c8702d | ||
|
|
dc2dfeb5b8 | ||
|
|
71d4fb98d3 | ||
|
|
b34f5862d7 | ||
|
|
0b08bf4e3f | ||
|
|
add87fa1b4 | ||
|
|
787fdf2e38 | ||
|
|
4499c630b3 | ||
|
|
e3be318781 | ||
|
|
231ab3fb5d | ||
|
|
ff9d7141a9 | ||
|
|
dba2d67cdb | ||
|
|
1a7d627949 | ||
|
|
f318e302c5 | ||
|
|
7384ca8768 | ||
|
|
73ee709801 | ||
|
|
53d2d333ab | ||
|
|
5be457e321 | ||
|
|
8223dc763d | ||
|
|
ea406c55cd | ||
|
|
ea80cdce02 | ||
|
|
40a0f71960 | ||
|
|
fcb94f1173 | ||
|
|
cc40f0d27b | ||
|
|
75dd103238 | ||
|
|
aafcf7af55 | ||
|
|
1201ed5ac0 | ||
|
|
a60613ec11 | ||
|
|
5640230f5b | ||
|
|
11d849b553 | ||
|
|
2eefb3c15f | ||
|
|
678ba41321 | ||
|
|
a40082c5da | ||
|
|
e5af4681d3 | ||
|
|
e05846db9f | ||
|
|
195e2c335d | ||
|
|
1dec69bb82 | ||
|
|
1d89fea73e | ||
|
|
075e4f18bc | ||
|
|
52bd1ad8ef | ||
|
|
5062075b8d | ||
|
|
e46facb765 | ||
|
|
f84e75cee7 | ||
|
|
b2d8e10339 | ||
|
|
d8ad3e73bf | ||
|
|
e2c4c07c34 | ||
|
|
7856718db8 | ||
|
|
3d9cc769d9 | ||
|
|
20e8c2287a | ||
|
|
57e5264df6 | ||
|
|
4c417b5e3e | ||
|
|
9270782c49 | ||
|
|
1a31f1e773 | ||
|
|
e28ba4b55b | ||
|
|
7ddfabed62 | ||
|
|
c7018f7a6c | ||
|
|
0fb6baef2b | ||
|
|
23988f8c49 | ||
|
|
1187849afe | ||
|
|
001801dee0 | ||
|
|
4a9966148d | ||
|
|
85c56f9942 | ||
|
|
07d76b2954 | ||
|
|
2a6c032883 | ||
|
|
e8dfed959e | ||
|
|
1f2be542f0 | ||
|
|
7dc06bfbe5 | ||
|
|
6f8e7abcbb | ||
|
|
18dcdd680d | ||
|
|
ad3df42b52 | ||
|
|
6568c7805a | ||
|
|
fa88c1dba8 | ||
|
|
7ea484aee2 | ||
|
|
dc7b367816 | ||
|
|
aea261d49e | ||
|
|
f27071cbc5 | ||
|
|
31a518a9d1 | ||
|
|
01463442ba | ||
|
|
53e916552b | ||
|
|
179dc418e0 | ||
|
|
a1bfa7847a | ||
|
|
e5494f9742 | ||
|
|
da3c5e3711 | ||
|
|
e5d84cae1b | ||
|
|
0c2cc7499f | ||
|
|
1261d859ac | ||
|
|
088551a4ef | ||
|
|
aa0f307cc7 | ||
|
|
e6bef573ba | ||
|
|
f6f9112b76 | ||
|
|
accdd580d7 | ||
|
|
4bcd65ed92 | ||
|
|
80f8d7a486 | ||
|
|
e8c28e79c9 | ||
|
|
b4bc6d994d | ||
|
|
ccc68c5c34 | ||
|
|
848d86b886 | ||
|
|
c0ab86bac2 | ||
|
|
8c2138a6ef | ||
|
|
9def9f0dba | ||
|
|
8023cafb2b | ||
|
|
5e01d6befb | ||
|
|
94edcac36e | ||
|
|
0ed77aa8a7 | ||
|
|
9b147ae437 | ||
|
|
bd63119684 | ||
|
|
a348caa9b1 | ||
|
|
76415aff41 | ||
|
|
84d551eda4 | ||
|
|
4ca38201d1 | ||
|
|
4a47e9a841 | ||
|
|
d7a30b01d2 | ||
|
|
9c0f927e16 | ||
|
|
55b9111410 | ||
|
|
07a4e112a4 | ||
|
|
b9781c43fb | ||
|
|
eaa8ae7399 | ||
|
|
a931494866 | ||
|
|
863f00f015 | ||
|
|
eae1dad0fa | ||
|
|
10b5b55658 | ||
|
|
b49a9ab171 | ||
|
|
9f50417109 | ||
|
|
94b4dc1656 | ||
|
|
4bce143d6e | ||
|
|
33eabf1b25 | ||
|
|
da979e5745 | ||
|
|
705b825580 | ||
|
|
32b595dfe1 | ||
|
|
2b9a751b96 | ||
|
|
1b6b134722 | ||
|
|
245adc4d3d | ||
|
|
4ad35d76b0 | ||
|
|
cc1e1c178b | ||
|
|
87b5975091 | ||
|
|
0545fb4443 | ||
|
|
b9fb657d81 | ||
|
|
14e75bbd24 | ||
|
|
3eb67baf5b | ||
|
|
8b72264535 | ||
|
|
786a46cbd0 | ||
|
|
7abbfa37bb | ||
|
|
85b56e39c9 | ||
|
|
143da5bc0d | ||
|
|
a1680fac2f | ||
|
|
5703ea47d2 | ||
|
|
9105f95d13 | ||
|
|
eccec6ab7c | ||
|
|
914da2e4cb | ||
|
|
e031576c87 | ||
|
|
bae794706c | ||
|
|
8f236a1288 | ||
|
|
6e9b6a1075 | ||
|
|
e4779c29a7 | ||
|
|
802086ee57 | ||
|
|
c516f3541c | ||
|
|
45d852a9db | ||
|
|
cee68106ef | ||
|
|
a24b465663 | ||
|
|
7ab0063dc6 | ||
|
|
dd2551040f | ||
|
|
f745ca1e03 | ||
|
|
eaaa135f90 | ||
|
|
457e7992a4 | ||
|
|
2fb1d06fbf | ||
|
|
8f9d4335ce | ||
|
|
ee1cb084ac | ||
|
|
2c77ad2aab | ||
|
|
f7d77a3c76 | ||
|
|
8b220d2dba | ||
|
|
6913efef90 | ||
|
|
12cbbe6cee | ||
|
|
55de519364 | ||
|
|
36134021c5 | ||
|
|
5b78299880 | ||
|
|
59364aadd7 | ||
|
|
e12785d277 | ||
|
|
7906d9edc8 | ||
|
|
6e54c97326 | ||
|
|
61424de531 | ||
|
|
4c2cf8b132 | ||
|
|
b169f78699 | ||
|
|
e48086b1c2 | ||
|
|
6b8ecb3a4b | ||
|
|
deb66a88aa | ||
|
|
90bd535c48 | ||
|
|
0de487064a | ||
|
|
114326d11a | ||
|
|
389c7b72db | ||
|
|
28ad01a51a | ||
|
|
0c102ebb5c | ||
|
|
5063b944ec | ||
|
|
15afe4dc78 | ||
|
|
a159779d39 | ||
|
|
44ebe3ae31 | ||
|
|
938a65628d | ||
|
|
5d390b65eb | ||
|
|
33974fc12c | ||
|
|
db0779dd02 | ||
|
|
f3fb7c572e | ||
|
|
0a0215ceee | ||
|
|
1a9921f63e | ||
|
|
a385234c0e | ||
|
|
65573210f1 | ||
|
|
c148fa5bfa | ||
|
|
11372aac8f | ||
|
|
f23a89ccfd | ||
|
|
e022e77b6d | ||
|
|
02cc211e91 | ||
|
|
bfe963988e | ||
|
|
0e6c2f0b51 |
30
.github/pull_request_template.md
vendored
30
.github/pull_request_template.md
vendored
@@ -6,20 +6,24 @@
|
||||
[Describe the tests you ran to verify your changes]
|
||||
|
||||
|
||||
## Accepted Risk
|
||||
[Any know risks or failure modes to point out to reviewers]
|
||||
## Accepted Risk (provide if relevant)
|
||||
N/A
|
||||
|
||||
|
||||
## Related Issue(s)
|
||||
[If applicable, link to the issue(s) this PR addresses]
|
||||
## Related Issue(s) (provide if relevant)
|
||||
N/A
|
||||
|
||||
|
||||
## Checklist:
|
||||
- [ ] All of the automated tests pass
|
||||
- [ ] All PR comments are addressed and marked resolved
|
||||
- [ ] If there are migrations, they have been rebased to latest main
|
||||
- [ ] If there are new dependencies, they are added to the requirements
|
||||
- [ ] If there are new environment variables, they are added to all of the deployment methods
|
||||
- [ ] If there are new APIs that don't require auth, they are added to PUBLIC_ENDPOINT_SPECS
|
||||
- [ ] Docker images build and basic functionalities work
|
||||
- [ ] Author has done a final read through of the PR right before merge
|
||||
## Mental Checklist:
|
||||
- All of the automated tests pass
|
||||
- All PR comments are addressed and marked resolved
|
||||
- If there are migrations, they have been rebased to latest main
|
||||
- If there are new dependencies, they are added to the requirements
|
||||
- If there are new environment variables, they are added to all of the deployment methods
|
||||
- If there are new APIs that don't require auth, they are added to PUBLIC_ENDPOINT_SPECS
|
||||
- Docker images build and basic functionalities work
|
||||
- Author has done a final read through of the PR right before merge
|
||||
|
||||
## Backporting (check the box to trigger backport action)
|
||||
Note: You have to check that the action passes, otherwise resolve the conflicts manually and tag the patches.
|
||||
- [ ] This PR should be backported (make sure to check that the backport attempt succeeds)
|
||||
|
||||
@@ -3,61 +3,61 @@ name: Build and Push Backend Image on Tag
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- '*'
|
||||
- "*"
|
||||
|
||||
env:
|
||||
REGISTRY_IMAGE: danswer/danswer-backend
|
||||
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-backend-cloud' || 'onyxdotapp/onyx-backend' }}
|
||||
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
|
||||
|
||||
|
||||
jobs:
|
||||
build-and-push:
|
||||
# TODO: investigate a matrix build like the web container
|
||||
# TODO: investigate a matrix build like the web container
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
|
||||
runs-on: [runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}"]
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Install build-essential
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y build-essential
|
||||
|
||||
- name: Backend Image Docker Build and Push
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
platforms: linux/amd64,linux/arm64
|
||||
push: true
|
||||
tags: |
|
||||
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
|
||||
build-args: |
|
||||
DANSWER_VERSION=${{ github.ref_name }}
|
||||
- name: Install build-essential
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y build-essential
|
||||
|
||||
# trivy has their own rate limiting issues causing this action to flake
|
||||
# we worked around it by hardcoding to different db repos in env
|
||||
# can re-enable when they figure it out
|
||||
# https://github.com/aquasecurity/trivy/discussions/7538
|
||||
# https://github.com/aquasecurity/trivy-action/issues/389
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@master
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
|
||||
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
|
||||
with:
|
||||
# To run locally: trivy image --severity HIGH,CRITICAL danswer/danswer-backend
|
||||
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
severity: 'CRITICAL,HIGH'
|
||||
trivyignores: ./backend/.trivyignore
|
||||
- name: Backend Image Docker Build and Push
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
platforms: linux/amd64,linux/arm64
|
||||
push: true
|
||||
tags: |
|
||||
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
|
||||
build-args: |
|
||||
ONYX_VERSION=${{ github.ref_name }}
|
||||
|
||||
# trivy has their own rate limiting issues causing this action to flake
|
||||
# we worked around it by hardcoding to different db repos in env
|
||||
# can re-enable when they figure it out
|
||||
# https://github.com/aquasecurity/trivy/discussions/7538
|
||||
# https://github.com/aquasecurity/trivy-action/issues/389
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@master
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
|
||||
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
|
||||
with:
|
||||
# To run locally: trivy image --severity HIGH,CRITICAL onyxdotapp/onyx-backend
|
||||
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
severity: "CRITICAL,HIGH"
|
||||
trivyignores: ./backend/.trivyignore
|
||||
|
||||
137
.github/workflows/docker-build-push-cloud-web-container-on-tag.yml
vendored
Normal file
137
.github/workflows/docker-build-push-cloud-web-container-on-tag.yml
vendored
Normal file
@@ -0,0 +1,137 @@
|
||||
name: Build and Push Cloud Web Image on Tag
|
||||
# Identical to the web container build, but with correct image tag and build args
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- "*"
|
||||
|
||||
env:
|
||||
REGISTRY_IMAGE: onyxdotapp/onyx-web-server-cloud
|
||||
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=${{ matrix.platform == 'linux/amd64' && '8cpu-linux-x64' || '8cpu-linux-arm64' }}
|
||||
- run-id=${{ github.run_id }}
|
||||
- tag=platform-${{ matrix.platform }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
platform:
|
||||
- linux/amd64
|
||||
- linux/arm64
|
||||
|
||||
steps:
|
||||
- name: Prepare
|
||||
run: |
|
||||
platform=${{ matrix.platform }}
|
||||
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ env.REGISTRY_IMAGE }}
|
||||
tags: |
|
||||
type=raw,value=${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
type=raw,value=${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push by digest
|
||||
id: build
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: ./web
|
||||
file: ./web/Dockerfile
|
||||
platforms: ${{ matrix.platform }}
|
||||
push: true
|
||||
build-args: |
|
||||
ONYX_VERSION=${{ github.ref_name }}
|
||||
NEXT_PUBLIC_CLOUD_ENABLED=true
|
||||
NEXT_PUBLIC_POSTHOG_KEY=${{ secrets.POSTHOG_KEY }}
|
||||
NEXT_PUBLIC_POSTHOG_HOST=${{ secrets.POSTHOG_HOST }}
|
||||
NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
|
||||
NEXT_PUBLIC_GTM_ENABLED=true
|
||||
# needed due to weird interactions with the builds for different platforms
|
||||
no-cache: true
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
|
||||
- name: Export digest
|
||||
run: |
|
||||
mkdir -p /tmp/digests
|
||||
digest="${{ steps.build.outputs.digest }}"
|
||||
touch "/tmp/digests/${digest#sha256:}"
|
||||
|
||||
- name: Upload digest
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: digests-${{ env.PLATFORM_PAIR }}
|
||||
path: /tmp/digests/*
|
||||
if-no-files-found: error
|
||||
retention-days: 1
|
||||
|
||||
merge:
|
||||
runs-on: ubuntu-latest
|
||||
needs:
|
||||
- build
|
||||
steps:
|
||||
- name: Download digests
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
path: /tmp/digests
|
||||
pattern: digests-*
|
||||
merge-multiple: true
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ env.REGISTRY_IMAGE }}
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Create manifest list and push
|
||||
working-directory: /tmp/digests
|
||||
run: |
|
||||
docker buildx imagetools create $(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \
|
||||
$(printf '${{ env.REGISTRY_IMAGE }}@sha256:%s ' *)
|
||||
|
||||
- name: Inspect image
|
||||
run: |
|
||||
docker buildx imagetools inspect ${{ env.REGISTRY_IMAGE }}:${{ steps.meta.outputs.version }}
|
||||
|
||||
# trivy has their own rate limiting issues causing this action to flake
|
||||
# we worked around it by hardcoding to different db repos in env
|
||||
# can re-enable when they figure it out
|
||||
# https://github.com/aquasecurity/trivy/discussions/7538
|
||||
# https://github.com/aquasecurity/trivy-action/issues/389
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@master
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
|
||||
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
|
||||
with:
|
||||
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
severity: "CRITICAL,HIGH"
|
||||
@@ -3,53 +3,121 @@ name: Build and Push Model Server Image on Tag
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- '*'
|
||||
- "*"
|
||||
|
||||
env:
|
||||
REGISTRY_IMAGE: danswer/danswer-model-server
|
||||
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-model-server-cloud' || 'onyxdotapp/onyx-model-server' }}
|
||||
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
|
||||
|
||||
DOCKER_BUILDKIT: 1
|
||||
BUILDKIT_PROGRESS: plain
|
||||
|
||||
jobs:
|
||||
build-and-push:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
|
||||
|
||||
build-amd64:
|
||||
runs-on:
|
||||
[runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}-amd64"]
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
- name: System Info
|
||||
run: |
|
||||
df -h
|
||||
free -h
|
||||
docker system prune -af --volumes
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
with:
|
||||
driver-opts: |
|
||||
image=moby/buildkit:latest
|
||||
network=host
|
||||
|
||||
- name: Model Server Image Docker Build and Push
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile.model_server
|
||||
platforms: linux/amd64,linux/arm64
|
||||
push: true
|
||||
tags: |
|
||||
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
|
||||
build-args: |
|
||||
DANSWER_VERSION=${{ github.ref_name }}
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
# trivy has their own rate limiting issues causing this action to flake
|
||||
# we worked around it by hardcoding to different db repos in env
|
||||
# can re-enable when they figure it out
|
||||
# https://github.com/aquasecurity/trivy/discussions/7538
|
||||
# https://github.com/aquasecurity/trivy-action/issues/389
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@master
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
|
||||
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
|
||||
with:
|
||||
image-ref: docker.io/danswer/danswer-model-server:${{ github.ref_name }}
|
||||
severity: 'CRITICAL,HIGH'
|
||||
- name: Build and Push AMD64
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile.model_server
|
||||
platforms: linux/amd64
|
||||
push: true
|
||||
tags: ${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-amd64
|
||||
build-args: |
|
||||
DANSWER_VERSION=${{ github.ref_name }}
|
||||
outputs: type=registry
|
||||
provenance: false
|
||||
|
||||
build-arm64:
|
||||
runs-on:
|
||||
[runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}-arm64"]
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: System Info
|
||||
run: |
|
||||
df -h
|
||||
free -h
|
||||
docker system prune -af --volumes
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
with:
|
||||
driver-opts: |
|
||||
image=moby/buildkit:latest
|
||||
network=host
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and Push ARM64
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile.model_server
|
||||
platforms: linux/arm64
|
||||
push: true
|
||||
tags: ${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-arm64
|
||||
build-args: |
|
||||
DANSWER_VERSION=${{ github.ref_name }}
|
||||
outputs: type=registry
|
||||
provenance: false
|
||||
|
||||
merge-and-scan:
|
||||
needs: [build-amd64, build-arm64]
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Create and Push Multi-arch Manifest
|
||||
run: |
|
||||
docker buildx create --use
|
||||
docker buildx imagetools create -t ${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }} \
|
||||
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-amd64 \
|
||||
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-arm64
|
||||
if [[ "${{ env.LATEST_TAG }}" == "true" ]]; then
|
||||
docker buildx imagetools create -t ${{ env.REGISTRY_IMAGE }}:latest \
|
||||
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-amd64 \
|
||||
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-arm64
|
||||
fi
|
||||
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@master
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
|
||||
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
|
||||
with:
|
||||
image-ref: docker.io/onyxdotapp/onyx-model-server:${{ github.ref_name }}
|
||||
severity: "CRITICAL,HIGH"
|
||||
timeout: "10m"
|
||||
|
||||
@@ -3,12 +3,12 @@ name: Build and Push Web Image on Tag
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- '*'
|
||||
- "*"
|
||||
|
||||
env:
|
||||
REGISTRY_IMAGE: danswer/danswer-web-server
|
||||
REGISTRY_IMAGE: onyxdotapp/onyx-web-server
|
||||
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
|
||||
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on:
|
||||
@@ -27,11 +27,11 @@ jobs:
|
||||
- name: Prepare
|
||||
run: |
|
||||
platform=${{ matrix.platform }}
|
||||
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
|
||||
|
||||
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
@@ -40,16 +40,16 @@ jobs:
|
||||
tags: |
|
||||
type=raw,value=${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
type=raw,value=${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
|
||||
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
|
||||
- name: Build and push by digest
|
||||
id: build
|
||||
uses: docker/build-push-action@v5
|
||||
@@ -59,18 +59,18 @@ jobs:
|
||||
platforms: ${{ matrix.platform }}
|
||||
push: true
|
||||
build-args: |
|
||||
DANSWER_VERSION=${{ github.ref_name }}
|
||||
# needed due to weird interactions with the builds for different platforms
|
||||
ONYX_VERSION=${{ github.ref_name }}
|
||||
# needed due to weird interactions with the builds for different platforms
|
||||
no-cache: true
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
|
||||
|
||||
- name: Export digest
|
||||
run: |
|
||||
mkdir -p /tmp/digests
|
||||
digest="${{ steps.build.outputs.digest }}"
|
||||
touch "/tmp/digests/${digest#sha256:}"
|
||||
|
||||
touch "/tmp/digests/${digest#sha256:}"
|
||||
|
||||
- name: Upload digest
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
@@ -90,42 +90,42 @@ jobs:
|
||||
path: /tmp/digests
|
||||
pattern: digests-*
|
||||
merge-multiple: true
|
||||
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ env.REGISTRY_IMAGE }}
|
||||
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
|
||||
- name: Create manifest list and push
|
||||
working-directory: /tmp/digests
|
||||
run: |
|
||||
docker buildx imagetools create $(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \
|
||||
$(printf '${{ env.REGISTRY_IMAGE }}@sha256:%s ' *)
|
||||
|
||||
$(printf '${{ env.REGISTRY_IMAGE }}@sha256:%s ' *)
|
||||
|
||||
- name: Inspect image
|
||||
run: |
|
||||
docker buildx imagetools inspect ${{ env.REGISTRY_IMAGE }}:${{ steps.meta.outputs.version }}
|
||||
|
||||
# trivy has their own rate limiting issues causing this action to flake
|
||||
# we worked around it by hardcoding to different db repos in env
|
||||
# can re-enable when they figure it out
|
||||
# https://github.com/aquasecurity/trivy/discussions/7538
|
||||
# https://github.com/aquasecurity/trivy-action/issues/389
|
||||
# trivy has their own rate limiting issues causing this action to flake
|
||||
# we worked around it by hardcoding to different db repos in env
|
||||
# can re-enable when they figure it out
|
||||
# https://github.com/aquasecurity/trivy/discussions/7538
|
||||
# https://github.com/aquasecurity/trivy-action/issues/389
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@master
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
|
||||
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
|
||||
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
|
||||
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
|
||||
with:
|
||||
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
severity: 'CRITICAL,HIGH'
|
||||
severity: "CRITICAL,HIGH"
|
||||
|
||||
34
.github/workflows/docker-tag-latest.yml
vendored
34
.github/workflows/docker-tag-latest.yml
vendored
@@ -7,31 +7,31 @@ on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
version:
|
||||
description: 'The version (ie v0.0.1) to tag as latest'
|
||||
description: "The version (ie v0.0.1) to tag as latest"
|
||||
required: true
|
||||
|
||||
jobs:
|
||||
tag:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
# use a lower powered instance since this just does i/o to docker hub
|
||||
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}"]
|
||||
runs-on: [runs-on, runner=2cpu-linux-x64, "run-id=${{ github.run_id }}"]
|
||||
steps:
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v1
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v1
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v1
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v1
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Enable Docker CLI experimental features
|
||||
run: echo "DOCKER_CLI_EXPERIMENTAL=enabled" >> $GITHUB_ENV
|
||||
- name: Enable Docker CLI experimental features
|
||||
run: echo "DOCKER_CLI_EXPERIMENTAL=enabled" >> $GITHUB_ENV
|
||||
|
||||
- name: Pull, Tag and Push Web Server Image
|
||||
run: |
|
||||
docker buildx imagetools create -t danswer/danswer-web-server:latest danswer/danswer-web-server:${{ github.event.inputs.version }}
|
||||
- name: Pull, Tag and Push Web Server Image
|
||||
run: |
|
||||
docker buildx imagetools create -t onyxdotapp/onyx-web-server:latest onyxdotapp/onyx-web-server:${{ github.event.inputs.version }}
|
||||
|
||||
- name: Pull, Tag and Push API Server Image
|
||||
run: |
|
||||
docker buildx imagetools create -t danswer/danswer-backend:latest danswer/danswer-backend:${{ github.event.inputs.version }}
|
||||
- name: Pull, Tag and Push API Server Image
|
||||
run: |
|
||||
docker buildx imagetools create -t onyxdotapp/onyx-backend:latest onyxdotapp/onyx-backend:${{ github.event.inputs.version }}
|
||||
|
||||
27
.github/workflows/hotfix-release-branches.yml
vendored
27
.github/workflows/hotfix-release-branches.yml
vendored
@@ -8,43 +8,42 @@ on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
hotfix_commit:
|
||||
description: 'Hotfix commit hash'
|
||||
description: "Hotfix commit hash"
|
||||
required: true
|
||||
hotfix_suffix:
|
||||
description: 'Hotfix branch suffix (e.g. hotfix/v0.8-{suffix})'
|
||||
description: "Hotfix branch suffix (e.g. hotfix/v0.8-{suffix})"
|
||||
required: true
|
||||
release_branch_pattern:
|
||||
description: 'Release branch pattern (regex)'
|
||||
description: "Release branch pattern (regex)"
|
||||
required: true
|
||||
default: 'release/.*'
|
||||
default: "release/.*"
|
||||
auto_merge:
|
||||
description: 'Automatically merge the hotfix PRs'
|
||||
description: "Automatically merge the hotfix PRs"
|
||||
required: true
|
||||
type: choice
|
||||
default: 'true'
|
||||
default: "true"
|
||||
options:
|
||||
- true
|
||||
- false
|
||||
|
||||
|
||||
jobs:
|
||||
hotfix_release_branches:
|
||||
permissions: write-all
|
||||
# See https://runs-on.com/runners/linux/
|
||||
# use a lower powered instance since this just does i/o to docker hub
|
||||
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}"]
|
||||
runs-on: [runs-on, runner=2cpu-linux-x64, "run-id=${{ github.run_id }}"]
|
||||
steps:
|
||||
|
||||
# needs RKUO_DEPLOY_KEY for write access to merge PR's
|
||||
- name: Checkout Repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ssh-key: "${{ secrets.RKUO_DEPLOY_KEY }}"
|
||||
fetch-depth: 0
|
||||
|
||||
|
||||
- name: Set up Git user
|
||||
run: |
|
||||
git config user.name "Richard Kuo [bot]"
|
||||
git config user.email "rkuo[bot]@danswer.ai"
|
||||
git config user.email "rkuo[bot]@onyx.app"
|
||||
|
||||
- name: Fetch All Branches
|
||||
run: |
|
||||
@@ -62,10 +61,10 @@ jobs:
|
||||
echo "No release branches found matching pattern '${{ github.event.inputs.release_branch_pattern }}'."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
echo "Found release branches:"
|
||||
echo "$BRANCHES"
|
||||
|
||||
|
||||
# Join the branches into a single line separated by commas
|
||||
BRANCHES_JOINED=$(echo "$BRANCHES" | tr '\n' ',' | sed 's/,$//')
|
||||
|
||||
@@ -169,4 +168,4 @@ jobs:
|
||||
echo "Failed to merge pull request #$PR_NUMBER."
|
||||
fi
|
||||
fi
|
||||
done
|
||||
done
|
||||
|
||||
23
.github/workflows/nightly-close-stale-issues.yml
vendored
Normal file
23
.github/workflows/nightly-close-stale-issues.yml
vendored
Normal file
@@ -0,0 +1,23 @@
|
||||
name: 'Nightly - Close stale issues and PRs'
|
||||
on:
|
||||
schedule:
|
||||
- cron: '0 11 * * *' # Runs every day at 3 AM PST / 4 AM PDT / 11 AM UTC
|
||||
|
||||
permissions:
|
||||
# contents: write # only for delete-branch option
|
||||
issues: write
|
||||
pull-requests: write
|
||||
|
||||
jobs:
|
||||
stale:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/stale@v9
|
||||
with:
|
||||
stale-issue-message: 'This issue is stale because it has been open 75 days with no activity. Remove stale label or comment or this will be closed in 15 days.'
|
||||
stale-pr-message: 'This PR is stale because it has been open 75 days with no activity. Remove stale label or comment or this will be closed in 15 days.'
|
||||
close-issue-message: 'This issue was closed because it has been stalled for 90 days with no activity.'
|
||||
close-pr-message: 'This PR was closed because it has been stalled for 90 days with no activity.'
|
||||
days-before-stale: 75
|
||||
# days-before-close: 90 # uncomment after we test stale behavior
|
||||
|
||||
76
.github/workflows/nightly-scan-licenses.yml
vendored
Normal file
76
.github/workflows/nightly-scan-licenses.yml
vendored
Normal file
@@ -0,0 +1,76 @@
|
||||
# Scan for problematic software licenses
|
||||
|
||||
# trivy has their own rate limiting issues causing this action to flake
|
||||
# we worked around it by hardcoding to different db repos in env
|
||||
# can re-enable when they figure it out
|
||||
# https://github.com/aquasecurity/trivy/discussions/7538
|
||||
# https://github.com/aquasecurity/trivy-action/issues/389
|
||||
|
||||
name: 'Nightly - Scan licenses'
|
||||
on:
|
||||
# schedule:
|
||||
# - cron: '0 14 * * *' # Runs every day at 6 AM PST / 7 AM PDT / 2 PM UTC
|
||||
workflow_dispatch: # Allows manual triggering
|
||||
|
||||
permissions:
|
||||
actions: read
|
||||
contents: read
|
||||
security-events: write
|
||||
|
||||
jobs:
|
||||
scan-licenses:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}"]
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.11'
|
||||
cache: 'pip'
|
||||
cache-dependency-path: |
|
||||
backend/requirements/default.txt
|
||||
backend/requirements/dev.txt
|
||||
backend/requirements/model_server.txt
|
||||
|
||||
- name: Get explicit and transitive dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
|
||||
pip freeze > requirements-all.txt
|
||||
|
||||
- name: Check python
|
||||
id: license_check_report
|
||||
uses: pilosus/action-pip-license-checker@v2
|
||||
with:
|
||||
requirements: 'requirements-all.txt'
|
||||
fail: 'Copyleft'
|
||||
exclude: '(?i)^(pylint|aio[-_]*).*'
|
||||
|
||||
- name: Print report
|
||||
if: ${{ always() }}
|
||||
run: echo "${{ steps.license_check_report.outputs.report }}"
|
||||
|
||||
- name: Install npm dependencies
|
||||
working-directory: ./web
|
||||
run: npm ci
|
||||
|
||||
- name: Run Trivy vulnerability scanner in repo mode
|
||||
uses: aquasecurity/trivy-action@0.28.0
|
||||
with:
|
||||
scan-type: fs
|
||||
scanners: license
|
||||
format: table
|
||||
# format: sarif
|
||||
# output: trivy-results.sarif
|
||||
severity: HIGH,CRITICAL
|
||||
|
||||
# - name: Upload Trivy scan results to GitHub Security tab
|
||||
# uses: github/codeql-action/upload-sarif@v3
|
||||
# with:
|
||||
# sarif_file: trivy-results.sarif
|
||||
124
.github/workflows/pr-backport-autotrigger.yml
vendored
Normal file
124
.github/workflows/pr-backport-autotrigger.yml
vendored
Normal file
@@ -0,0 +1,124 @@
|
||||
name: Backport on Merge
|
||||
|
||||
# Note this workflow does not trigger the builds, be sure to manually tag the branches to trigger the builds
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [closed] # Later we check for merge so only PRs that go in can get backported
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
actions: write
|
||||
|
||||
jobs:
|
||||
backport:
|
||||
if: github.event.pull_request.merged == true
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.YUHONG_GH_ACTIONS }}
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ssh-key: "${{ secrets.RKUO_DEPLOY_KEY }}"
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Git user
|
||||
run: |
|
||||
git config user.name "Richard Kuo [bot]"
|
||||
git config user.email "rkuo[bot]@onyx.app"
|
||||
git fetch --prune
|
||||
|
||||
- name: Check for Backport Checkbox
|
||||
id: checkbox-check
|
||||
run: |
|
||||
PR_BODY="${{ github.event.pull_request.body }}"
|
||||
if [[ "$PR_BODY" == *"[x] This PR should be backported"* ]]; then
|
||||
echo "backport=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "backport=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- name: List and sort release branches
|
||||
id: list-branches
|
||||
run: |
|
||||
git fetch --all --tags
|
||||
BRANCHES=$(git for-each-ref --format='%(refname:short)' refs/remotes/origin/release/* | sed 's|origin/release/||' | sort -Vr)
|
||||
BETA=$(echo "$BRANCHES" | head -n 1)
|
||||
STABLE=$(echo "$BRANCHES" | head -n 2 | tail -n 1)
|
||||
echo "beta=release/$BETA" >> $GITHUB_OUTPUT
|
||||
echo "stable=release/$STABLE" >> $GITHUB_OUTPUT
|
||||
# Fetch latest tags for beta and stable
|
||||
LATEST_BETA_TAG=$(git tag -l "v[0-9]*.[0-9]*.[0-9]*-beta.[0-9]*" | grep -E "^v[0-9]+\.[0-9]+\.[0-9]+-beta\.[0-9]+$" | grep -v -- "-cloud" | sort -Vr | head -n 1)
|
||||
LATEST_STABLE_TAG=$(git tag -l "v[0-9]*.[0-9]*.[0-9]*" | grep -E "^v[0-9]+\.[0-9]+\.[0-9]+$" | sort -Vr | head -n 1)
|
||||
|
||||
# Handle case where no beta tags exist
|
||||
if [[ -z "$LATEST_BETA_TAG" ]]; then
|
||||
NEW_BETA_TAG="v1.0.0-beta.1"
|
||||
else
|
||||
NEW_BETA_TAG=$(echo $LATEST_BETA_TAG | awk -F '[.-]' '{print $1 "." $2 "." $3 "-beta." ($NF+1)}')
|
||||
fi
|
||||
|
||||
# Increment latest stable tag
|
||||
NEW_STABLE_TAG=$(echo $LATEST_STABLE_TAG | awk -F '.' '{print $1 "." $2 "." ($3+1)}')
|
||||
echo "latest_beta_tag=$LATEST_BETA_TAG" >> $GITHUB_OUTPUT
|
||||
echo "latest_stable_tag=$LATEST_STABLE_TAG" >> $GITHUB_OUTPUT
|
||||
echo "new_beta_tag=$NEW_BETA_TAG" >> $GITHUB_OUTPUT
|
||||
echo "new_stable_tag=$NEW_STABLE_TAG" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Echo branch and tag information
|
||||
run: |
|
||||
echo "Beta branch: ${{ steps.list-branches.outputs.beta }}"
|
||||
echo "Stable branch: ${{ steps.list-branches.outputs.stable }}"
|
||||
echo "Latest beta tag: ${{ steps.list-branches.outputs.latest_beta_tag }}"
|
||||
echo "Latest stable tag: ${{ steps.list-branches.outputs.latest_stable_tag }}"
|
||||
echo "New beta tag: ${{ steps.list-branches.outputs.new_beta_tag }}"
|
||||
echo "New stable tag: ${{ steps.list-branches.outputs.new_stable_tag }}"
|
||||
|
||||
- name: Trigger Backport
|
||||
if: steps.checkbox-check.outputs.backport == 'true'
|
||||
run: |
|
||||
set -e
|
||||
echo "Backporting to beta ${{ steps.list-branches.outputs.beta }} and stable ${{ steps.list-branches.outputs.stable }}"
|
||||
|
||||
# Echo the merge commit SHA
|
||||
echo "Merge commit SHA: ${{ github.event.pull_request.merge_commit_sha }}"
|
||||
|
||||
# Fetch all history for all branches and tags
|
||||
git fetch --prune
|
||||
|
||||
# Reset and prepare the beta branch
|
||||
git checkout ${{ steps.list-branches.outputs.beta }}
|
||||
echo "Last 5 commits on beta branch:"
|
||||
git log -n 5 --pretty=format:"%H"
|
||||
echo "" # Newline for formatting
|
||||
|
||||
# Cherry-pick the merge commit from the merged PR
|
||||
git cherry-pick -m 1 ${{ github.event.pull_request.merge_commit_sha }} || {
|
||||
echo "Cherry-pick to beta failed due to conflicts."
|
||||
exit 1
|
||||
}
|
||||
|
||||
# Create new beta branch/tag
|
||||
git tag ${{ steps.list-branches.outputs.new_beta_tag }}
|
||||
# Push the changes and tag to the beta branch using PAT
|
||||
git push origin ${{ steps.list-branches.outputs.beta }}
|
||||
git push origin ${{ steps.list-branches.outputs.new_beta_tag }}
|
||||
|
||||
# Reset and prepare the stable branch
|
||||
git checkout ${{ steps.list-branches.outputs.stable }}
|
||||
echo "Last 5 commits on stable branch:"
|
||||
git log -n 5 --pretty=format:"%H"
|
||||
echo "" # Newline for formatting
|
||||
|
||||
# Cherry-pick the merge commit from the merged PR
|
||||
git cherry-pick -m 1 ${{ github.event.pull_request.merge_commit_sha }} || {
|
||||
echo "Cherry-pick to stable failed due to conflicts."
|
||||
exit 1
|
||||
}
|
||||
|
||||
# Create new stable branch/tag
|
||||
git tag ${{ steps.list-branches.outputs.new_stable_tag }}
|
||||
# Push the changes and tag to the stable branch using PAT
|
||||
git push origin ${{ steps.list-branches.outputs.stable }}
|
||||
git push origin ${{ steps.list-branches.outputs.new_stable_tag }}
|
||||
238
.github/workflows/pr-chromatic-tests.yml
vendored
Normal file
238
.github/workflows/pr-chromatic-tests.yml
vendored
Normal file
@@ -0,0 +1,238 @@
|
||||
name: Run Chromatic Tests
|
||||
concurrency:
|
||||
group: Run-Chromatic-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
on: push
|
||||
|
||||
env:
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
|
||||
|
||||
jobs:
|
||||
playwright-tests:
|
||||
name: Playwright Tests
|
||||
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=32cpu-linux-x64,
|
||||
disk=large,
|
||||
"run-id=${{ github.run_id }}",
|
||||
]
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
cache: "pip"
|
||||
cache-dependency-path: |
|
||||
backend/requirements/default.txt
|
||||
backend/requirements/dev.txt
|
||||
backend/requirements/model_server.txt
|
||||
- run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
|
||||
|
||||
- name: Setup node
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: 22
|
||||
|
||||
- name: Install node dependencies
|
||||
working-directory: ./web
|
||||
run: npm ci
|
||||
|
||||
- name: Install playwright browsers
|
||||
working-directory: ./web
|
||||
run: npx playwright install --with-deps
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
# tag every docker image with "test" so that we can spin up the correct set
|
||||
# of images during testing
|
||||
|
||||
# we use the runs-on cache for docker builds
|
||||
# in conjunction with runs-on runners, it has better speed and unlimited caching
|
||||
# https://runs-on.com/caching/s3-cache-for-github-actions/
|
||||
# https://runs-on.com/caching/docker/
|
||||
# https://github.com/moby/buildkit#s3-cache-experimental
|
||||
|
||||
# images are built and run locally for testing purposes. Not pushed.
|
||||
|
||||
- name: Build Web Docker image
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
context: ./web
|
||||
file: ./web/Dockerfile
|
||||
platforms: linux/amd64
|
||||
tags: onyxdotapp/onyx-web-server:test
|
||||
push: false
|
||||
load: true
|
||||
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/web-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
|
||||
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/web-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
|
||||
|
||||
- name: Build Backend Docker image
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
platforms: linux/amd64
|
||||
tags: onyxdotapp/onyx-backend:test
|
||||
push: false
|
||||
load: true
|
||||
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
|
||||
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
|
||||
|
||||
- name: Build Model Server Docker image
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile.model_server
|
||||
platforms: linux/amd64
|
||||
tags: onyxdotapp/onyx-model-server:test
|
||||
push: false
|
||||
load: true
|
||||
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
|
||||
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
|
||||
|
||||
- name: Start Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
|
||||
AUTH_TYPE=basic \
|
||||
GEN_AI_API_KEY=${{ secrets.OPENAI_API_KEY }} \
|
||||
REQUIRE_EMAIL_VERIFICATION=false \
|
||||
DISABLE_TELEMETRY=true \
|
||||
IMAGE_TAG=test \
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack up -d
|
||||
id: start_docker
|
||||
|
||||
- name: Wait for service to be ready
|
||||
run: |
|
||||
echo "Starting wait-for-service script..."
|
||||
|
||||
docker logs -f danswer-stack-api_server-1 &
|
||||
|
||||
start_time=$(date +%s)
|
||||
timeout=300 # 5 minutes in seconds
|
||||
|
||||
while true; do
|
||||
current_time=$(date +%s)
|
||||
elapsed_time=$((current_time - start_time))
|
||||
|
||||
if [ $elapsed_time -ge $timeout ]; then
|
||||
echo "Timeout reached. Service did not become ready in 5 minutes."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Use curl with error handling to ignore specific exit code 56
|
||||
response=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:8080/health || echo "curl_error")
|
||||
|
||||
if [ "$response" = "200" ]; then
|
||||
echo "Service is ready!"
|
||||
break
|
||||
elif [ "$response" = "curl_error" ]; then
|
||||
echo "Curl encountered an error, possibly exit code 56. Continuing to retry..."
|
||||
else
|
||||
echo "Service not ready yet (HTTP status $response). Retrying in 5 seconds..."
|
||||
fi
|
||||
|
||||
sleep 5
|
||||
done
|
||||
echo "Finished waiting for service."
|
||||
|
||||
- name: Run pytest playwright test init
|
||||
working-directory: ./backend
|
||||
env:
|
||||
PYTEST_IGNORE_SKIP: true
|
||||
run: pytest -s tests/integration/tests/playwright/test_playwright.py
|
||||
|
||||
- name: Run Playwright tests
|
||||
working-directory: ./web
|
||||
run: npx playwright test
|
||||
|
||||
- uses: actions/upload-artifact@v4
|
||||
if: always()
|
||||
with:
|
||||
# Chromatic automatically defaults to the test-results directory.
|
||||
# Replace with the path to your custom directory and adjust the CHROMATIC_ARCHIVE_LOCATION environment variable accordingly.
|
||||
name: test-results
|
||||
path: ./web/test-results
|
||||
retention-days: 30
|
||||
|
||||
# save before stopping the containers so the logs can be captured
|
||||
- name: Save Docker logs
|
||||
if: success() || failure()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack logs > docker-compose.log
|
||||
mv docker-compose.log ${{ github.workspace }}/docker-compose.log
|
||||
|
||||
- name: Upload logs
|
||||
if: success() || failure()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: docker-logs
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
|
||||
- name: Stop Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
|
||||
|
||||
chromatic-tests:
|
||||
name: Chromatic Tests
|
||||
|
||||
needs: playwright-tests
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=32cpu-linux-x64,
|
||||
disk=large,
|
||||
"run-id=${{ github.run_id }}",
|
||||
]
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup node
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: 22
|
||||
|
||||
- name: Install node dependencies
|
||||
working-directory: ./web
|
||||
run: npm ci
|
||||
|
||||
- name: Download Playwright test results
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: test-results
|
||||
path: ./web/test-results
|
||||
|
||||
- name: Run Chromatic
|
||||
uses: chromaui/action@latest
|
||||
with:
|
||||
playwright: true
|
||||
projectToken: ${{ secrets.CHROMATIC_PROJECT_TOKEN }}
|
||||
workingDir: ./web
|
||||
env:
|
||||
CHROMATIC_ARCHIVE_LOCATION: ./test-results
|
||||
72
.github/workflows/pr-helm-chart-testing.yml
vendored
Normal file
72
.github/workflows/pr-helm-chart-testing.yml
vendored
Normal file
@@ -0,0 +1,72 @@
|
||||
name: Helm - Lint and Test Charts
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
workflow_dispatch: # Allows manual triggering
|
||||
|
||||
jobs:
|
||||
helm-chart-check:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=8cpu-linux-x64,hdd=256,"run-id=${{ github.run_id }}"]
|
||||
|
||||
# fetch-depth 0 is required for helm/chart-testing-action
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Helm
|
||||
uses: azure/setup-helm@v4.2.0
|
||||
with:
|
||||
version: v3.14.4
|
||||
|
||||
- name: Set up chart-testing
|
||||
uses: helm/chart-testing-action@v2.6.1
|
||||
|
||||
# even though we specify chart-dirs in ct.yaml, it isn't used by ct for the list-changed command...
|
||||
- name: Run chart-testing (list-changed)
|
||||
id: list-changed
|
||||
run: |
|
||||
echo "default_branch: ${{ github.event.repository.default_branch }}"
|
||||
changed=$(ct list-changed --remote origin --target-branch ${{ github.event.repository.default_branch }} --chart-dirs deployment/helm/charts)
|
||||
echo "list-changed output: $changed"
|
||||
if [[ -n "$changed" ]]; then
|
||||
echo "changed=true" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
|
||||
# rkuo: I don't think we need python?
|
||||
# - name: Set up Python
|
||||
# uses: actions/setup-python@v5
|
||||
# with:
|
||||
# python-version: '3.11'
|
||||
# cache: 'pip'
|
||||
# cache-dependency-path: |
|
||||
# backend/requirements/default.txt
|
||||
# backend/requirements/dev.txt
|
||||
# backend/requirements/model_server.txt
|
||||
# - run: |
|
||||
# python -m pip install --upgrade pip
|
||||
# pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
|
||||
# pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
|
||||
# pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
|
||||
|
||||
# lint all charts if any changes were detected
|
||||
- name: Run chart-testing (lint)
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: ct lint --config ct.yaml --all
|
||||
# the following would lint only changed charts, but linting isn't expensive
|
||||
# run: ct lint --config ct.yaml --target-branch ${{ github.event.repository.default_branch }}
|
||||
|
||||
- name: Create kind cluster
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
uses: helm/kind-action@v1.10.0
|
||||
|
||||
- name: Run chart-testing (install)
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: ct install --all --helm-extra-set-args="--set=nginx.enabled=false" --debug --config ct.yaml
|
||||
# the following would install only changed charts, but we only have one chart so
|
||||
# don't worry about that for now
|
||||
# run: ct install --target-branch ${{ github.event.repository.default_branch }}
|
||||
@@ -1,68 +0,0 @@
|
||||
# This workflow is intentionally disabled while we're still working on it
|
||||
# It's close to ready, but a race condition needs to be fixed with
|
||||
# API server and Vespa startup, and it needs to have a way to build/test against
|
||||
# local containers
|
||||
|
||||
name: Helm - Lint and Test Charts
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
|
||||
jobs:
|
||||
lint-test:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=8cpu-linux-x64,hdd=256,"run-id=${{ github.run_id }}"]
|
||||
|
||||
# fetch-depth 0 is required for helm/chart-testing-action
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Helm
|
||||
uses: azure/setup-helm@v4.2.0
|
||||
with:
|
||||
version: v3.14.4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.11'
|
||||
cache: 'pip'
|
||||
cache-dependency-path: |
|
||||
backend/requirements/default.txt
|
||||
backend/requirements/dev.txt
|
||||
backend/requirements/model_server.txt
|
||||
- run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
|
||||
|
||||
- name: Set up chart-testing
|
||||
uses: helm/chart-testing-action@v2.6.1
|
||||
|
||||
- name: Run chart-testing (list-changed)
|
||||
id: list-changed
|
||||
run: |
|
||||
changed=$(ct list-changed --target-branch ${{ github.event.repository.default_branch }})
|
||||
if [[ -n "$changed" ]]; then
|
||||
echo "changed=true" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
|
||||
- name: Run chart-testing (lint)
|
||||
# if: steps.list-changed.outputs.changed == 'true'
|
||||
run: ct lint --all --config ct.yaml --target-branch ${{ github.event.repository.default_branch }}
|
||||
|
||||
- name: Create kind cluster
|
||||
# if: steps.list-changed.outputs.changed == 'true'
|
||||
uses: helm/kind-action@v1.10.0
|
||||
|
||||
- name: Run chart-testing (install)
|
||||
# if: steps.list-changed.outputs.changed == 'true'
|
||||
run: ct install --all --config ct.yaml
|
||||
# run: ct install --target-branch ${{ github.event.repository.default_branch }}
|
||||
|
||||
@@ -8,16 +8,19 @@ on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
- 'release/**'
|
||||
- "release/**"
|
||||
|
||||
env:
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
|
||||
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
|
||||
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
|
||||
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
|
||||
|
||||
jobs:
|
||||
integration-tests:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=8cpu-linux-x64,ram=16,"run-id=${{ github.run_id }}"]
|
||||
runs-on: [runs-on, runner=32cpu-linux-x64, "run-id=${{ github.run_id }}"]
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
@@ -33,21 +36,21 @@ jobs:
|
||||
|
||||
# tag every docker image with "test" so that we can spin up the correct set
|
||||
# of images during testing
|
||||
|
||||
|
||||
# We don't need to build the Web Docker image since it's not yet used
|
||||
# in the integration tests. We have a separate action to verify that it builds
|
||||
# in the integration tests. We have a separate action to verify that it builds
|
||||
# successfully.
|
||||
- name: Pull Web Docker image
|
||||
run: |
|
||||
docker pull danswer/danswer-web-server:latest
|
||||
docker tag danswer/danswer-web-server:latest danswer/danswer-web-server:test
|
||||
docker pull onyxdotapp/onyx-web-server:latest
|
||||
docker tag onyxdotapp/onyx-web-server:latest onyxdotapp/onyx-web-server:test
|
||||
|
||||
# we use the runs-on cache for docker builds
|
||||
# in conjunction with runs-on runners, it has better speed and unlimited caching
|
||||
# https://runs-on.com/caching/s3-cache-for-github-actions/
|
||||
# https://runs-on.com/caching/docker/
|
||||
# https://github.com/moby/buildkit#s3-cache-experimental
|
||||
|
||||
|
||||
# images are built and run locally for testing purposes. Not pushed.
|
||||
- name: Build Backend Docker image
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
@@ -55,7 +58,7 @@ jobs:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
platforms: linux/amd64
|
||||
tags: danswer/danswer-backend:test
|
||||
tags: onyxdotapp/onyx-backend:test
|
||||
push: false
|
||||
load: true
|
||||
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
|
||||
@@ -67,7 +70,7 @@ jobs:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile.model_server
|
||||
platforms: linux/amd64
|
||||
tags: danswer/danswer-model-server:test
|
||||
tags: onyxdotapp/onyx-model-server:test
|
||||
push: false
|
||||
load: true
|
||||
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
|
||||
@@ -79,12 +82,62 @@ jobs:
|
||||
context: ./backend
|
||||
file: ./backend/tests/integration/Dockerfile
|
||||
platforms: linux/amd64
|
||||
tags: danswer/danswer-integration:test
|
||||
tags: onyxdotapp/onyx-integration:test
|
||||
push: false
|
||||
load: true
|
||||
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
|
||||
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
|
||||
|
||||
# Start containers for multi-tenant tests
|
||||
- name: Start Docker containers for multi-tenant tests
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
|
||||
MULTI_TENANT=true \
|
||||
AUTH_TYPE=basic \
|
||||
REQUIRE_EMAIL_VERIFICATION=false \
|
||||
DISABLE_TELEMETRY=true \
|
||||
IMAGE_TAG=test \
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack up -d
|
||||
id: start_docker_multi_tenant
|
||||
|
||||
# In practice, `cloud` Auth type would require OAUTH credentials to be set.
|
||||
- name: Run Multi-Tenant Integration Tests
|
||||
run: |
|
||||
echo "Running integration tests..."
|
||||
docker run --rm --network danswer-stack_default \
|
||||
--name test-runner \
|
||||
-e POSTGRES_HOST=relational_db \
|
||||
-e POSTGRES_USER=postgres \
|
||||
-e POSTGRES_PASSWORD=password \
|
||||
-e POSTGRES_DB=postgres \
|
||||
-e VESPA_HOST=index \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
-e AUTH_TYPE=cloud \
|
||||
-e MULTI_TENANT=true \
|
||||
onyxdotapp/onyx-integration:test \
|
||||
/app/tests/integration/multitenant_tests
|
||||
continue-on-error: true
|
||||
id: run_multitenant_tests
|
||||
|
||||
- name: Check multi-tenant test results
|
||||
run: |
|
||||
if [ ${{ steps.run_tests.outcome }} == 'failure' ]; then
|
||||
echo "Integration tests failed. Exiting with error."
|
||||
exit 1
|
||||
else
|
||||
echo "All integration tests passed successfully."
|
||||
fi
|
||||
|
||||
- name: Stop multi-tenant Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
|
||||
|
||||
- name: Start Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
@@ -99,12 +152,12 @@ jobs:
|
||||
- name: Wait for service to be ready
|
||||
run: |
|
||||
echo "Starting wait-for-service script..."
|
||||
|
||||
|
||||
docker logs -f danswer-stack-api_server-1 &
|
||||
|
||||
start_time=$(date +%s)
|
||||
timeout=300 # 5 minutes in seconds
|
||||
|
||||
|
||||
while true; do
|
||||
current_time=$(date +%s)
|
||||
elapsed_time=$((current_time - start_time))
|
||||
@@ -130,7 +183,7 @@ jobs:
|
||||
done
|
||||
echo "Finished waiting for service."
|
||||
|
||||
- name: Run integration tests
|
||||
- name: Run Standard Integration Tests
|
||||
run: |
|
||||
echo "Running integration tests..."
|
||||
docker run --rm --network danswer-stack_default \
|
||||
@@ -144,8 +197,13 @@ jobs:
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
|
||||
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
|
||||
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
|
||||
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
danswer/danswer-integration:test
|
||||
onyxdotapp/onyx-integration:test \
|
||||
/app/tests/integration/tests \
|
||||
/app/tests/integration/connector_job_tests
|
||||
continue-on-error: true
|
||||
id: run_tests
|
||||
|
||||
@@ -158,13 +216,19 @@ jobs:
|
||||
echo "All integration tests passed successfully."
|
||||
fi
|
||||
|
||||
# save before stopping the containers so the logs can be captured
|
||||
- name: Save Docker logs
|
||||
if: success() || failure()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack logs > docker-compose.log
|
||||
mv docker-compose.log ${{ github.workspace }}/docker-compose.log
|
||||
|
||||
|
||||
- name: Stop Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
|
||||
|
||||
- name: Upload logs
|
||||
if: success() || failure()
|
||||
uses: actions/upload-artifact@v4
|
||||
@@ -18,6 +18,14 @@ env:
|
||||
# Jira
|
||||
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
|
||||
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
|
||||
# Google
|
||||
GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR: ${{ secrets.GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR }}
|
||||
GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR_TEST_USER_1: ${{ secrets.GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR_TEST_USER_1 }}
|
||||
GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR }}
|
||||
GOOGLE_GMAIL_SERVICE_ACCOUNT_JSON_STR: ${{ secrets.GOOGLE_GMAIL_SERVICE_ACCOUNT_JSON_STR }}
|
||||
GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR }}
|
||||
# Slab
|
||||
SLAB_BOT_TOKEN: ${{ secrets.SLAB_BOT_TOKEN }}
|
||||
|
||||
jobs:
|
||||
connectors-check:
|
||||
|
||||
2
.github/workflows/pr-python-model-tests.yml
vendored
2
.github/workflows/pr-python-model-tests.yml
vendored
@@ -15,7 +15,7 @@ env:
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
|
||||
jobs:
|
||||
connectors-check:
|
||||
model-check:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
|
||||
|
||||
|
||||
79
.github/workflows/tag-nightly.yml
vendored
79
.github/workflows/tag-nightly.yml
vendored
@@ -2,53 +2,52 @@ name: Nightly Tag Push
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: '0 10 * * *' # Runs every day at 2 AM PST / 3 AM PDT / 10 AM UTC
|
||||
- cron: "0 10 * * *" # Runs every day at 2 AM PST / 3 AM PDT / 10 AM UTC
|
||||
|
||||
permissions:
|
||||
contents: write # Allows pushing tags to the repository
|
||||
contents: write # Allows pushing tags to the repository
|
||||
|
||||
jobs:
|
||||
create-and-push-tag:
|
||||
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}"]
|
||||
runs-on: [runs-on, runner=2cpu-linux-x64, "run-id=${{ github.run_id }}"]
|
||||
|
||||
steps:
|
||||
# actions using GITHUB_TOKEN cannot trigger another workflow, but we do want this to trigger docker pushes
|
||||
# see https://github.com/orgs/community/discussions/27028#discussioncomment-3254367 for the workaround we
|
||||
# implement here which needs an actual user's deploy key
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ssh-key: "${{ secrets.RKUO_DEPLOY_KEY }}"
|
||||
# actions using GITHUB_TOKEN cannot trigger another workflow, but we do want this to trigger docker pushes
|
||||
# see https://github.com/orgs/community/discussions/27028#discussioncomment-3254367 for the workaround we
|
||||
# implement here which needs an actual user's deploy key
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ssh-key: "${{ secrets.RKUO_DEPLOY_KEY }}"
|
||||
|
||||
- name: Set up Git user
|
||||
run: |
|
||||
git config user.name "Richard Kuo [bot]"
|
||||
git config user.email "rkuo[bot]@danswer.ai"
|
||||
- name: Set up Git user
|
||||
run: |
|
||||
git config user.name "Richard Kuo [bot]"
|
||||
git config user.email "rkuo[bot]@onyx.app"
|
||||
|
||||
- name: Check for existing nightly tag
|
||||
id: check_tag
|
||||
run: |
|
||||
if git tag --points-at HEAD --list "nightly-latest*" | grep -q .; then
|
||||
echo "A tag starting with 'nightly-latest' already exists on HEAD."
|
||||
echo "tag_exists=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "No tag starting with 'nightly-latest' exists on HEAD."
|
||||
echo "tag_exists=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
# don't tag again if HEAD already has a nightly-latest tag on it
|
||||
- name: Create Nightly Tag
|
||||
if: steps.check_tag.outputs.tag_exists == 'false'
|
||||
env:
|
||||
DATE: ${{ github.run_id }}
|
||||
run: |
|
||||
TAG_NAME="nightly-latest-$(date +'%Y%m%d')"
|
||||
echo "Creating tag: $TAG_NAME"
|
||||
git tag $TAG_NAME
|
||||
- name: Check for existing nightly tag
|
||||
id: check_tag
|
||||
run: |
|
||||
if git tag --points-at HEAD --list "nightly-latest*" | grep -q .; then
|
||||
echo "A tag starting with 'nightly-latest' already exists on HEAD."
|
||||
echo "tag_exists=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "No tag starting with 'nightly-latest' exists on HEAD."
|
||||
echo "tag_exists=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- name: Push Tag
|
||||
if: steps.check_tag.outputs.tag_exists == 'false'
|
||||
run: |
|
||||
TAG_NAME="nightly-latest-$(date +'%Y%m%d')"
|
||||
git push origin $TAG_NAME
|
||||
|
||||
# don't tag again if HEAD already has a nightly-latest tag on it
|
||||
- name: Create Nightly Tag
|
||||
if: steps.check_tag.outputs.tag_exists == 'false'
|
||||
env:
|
||||
DATE: ${{ github.run_id }}
|
||||
run: |
|
||||
TAG_NAME="nightly-latest-$(date +'%Y%m%d')"
|
||||
echo "Creating tag: $TAG_NAME"
|
||||
git tag $TAG_NAME
|
||||
|
||||
- name: Push Tag
|
||||
if: steps.check_tag.outputs.tag_exists == 'false'
|
||||
run: |
|
||||
TAG_NAME="nightly-latest-$(date +'%Y%m%d')"
|
||||
git push origin $TAG_NAME
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -7,3 +7,4 @@
|
||||
.vscode/
|
||||
*.sw?
|
||||
/backend/tests/regression/answer_quality/search_test_config.yaml
|
||||
/web/test-results/
|
||||
288
.vscode/launch.template.jsonc
vendored
288
.vscode/launch.template.jsonc
vendored
@@ -6,19 +6,69 @@
|
||||
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
||||
"version": "0.2.0",
|
||||
"compounds": [
|
||||
{
|
||||
// Dummy entry used to label the group
|
||||
"name": "--- Compound ---",
|
||||
"configurations": [
|
||||
"--- Individual ---"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "1",
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Run All Danswer Services",
|
||||
"name": "Run All Onyx Services",
|
||||
"configurations": [
|
||||
"Web Server",
|
||||
"Model Server",
|
||||
"API Server",
|
||||
"Indexing",
|
||||
"Background Jobs",
|
||||
"Slack Bot"
|
||||
]
|
||||
}
|
||||
"Slack Bot",
|
||||
"Celery primary",
|
||||
"Celery light",
|
||||
"Celery heavy",
|
||||
"Celery indexing",
|
||||
"Celery beat",
|
||||
],
|
||||
"presentation": {
|
||||
"group": "1",
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Web / Model / API",
|
||||
"configurations": [
|
||||
"Web Server",
|
||||
"Model Server",
|
||||
"API Server",
|
||||
],
|
||||
"presentation": {
|
||||
"group": "1",
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Celery (all)",
|
||||
"configurations": [
|
||||
"Celery primary",
|
||||
"Celery light",
|
||||
"Celery heavy",
|
||||
"Celery indexing",
|
||||
"Celery beat"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "1",
|
||||
}
|
||||
}
|
||||
],
|
||||
"configurations": [
|
||||
{
|
||||
// Dummy entry used to label the group
|
||||
"name": "--- Individual ---",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
"order": 0
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Web Server",
|
||||
"type": "node",
|
||||
@@ -29,7 +79,11 @@
|
||||
"runtimeArgs": [
|
||||
"run", "dev"
|
||||
],
|
||||
"console": "integratedTerminal"
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
},
|
||||
"console": "integratedTerminal",
|
||||
"consoleTitle": "Web Server Console"
|
||||
},
|
||||
{
|
||||
"name": "Model Server",
|
||||
@@ -48,7 +102,11 @@
|
||||
"--reload",
|
||||
"--port",
|
||||
"9000"
|
||||
]
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
},
|
||||
"consoleTitle": "Model Server Console"
|
||||
},
|
||||
{
|
||||
"name": "API Server",
|
||||
@@ -64,18 +122,128 @@
|
||||
"PYTHONUNBUFFERED": "1"
|
||||
},
|
||||
"args": [
|
||||
"danswer.main:app",
|
||||
"onyx.main:app",
|
||||
"--reload",
|
||||
"--port",
|
||||
"8080"
|
||||
]
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
},
|
||||
"consoleTitle": "API Server Console"
|
||||
},
|
||||
// For the listener to access the Slack API,
|
||||
// DANSWER_BOT_SLACK_APP_TOKEN & DANSWER_BOT_SLACK_BOT_TOKEN need to be set in .env file located in the root of the project
|
||||
{
|
||||
"name": "Indexing",
|
||||
"consoleName": "Indexing",
|
||||
"name": "Slack Bot",
|
||||
"consoleName": "Slack Bot",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "danswer/background/update.py",
|
||||
"program": "onyx/onyxbot/slack/listener.py",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
},
|
||||
"consoleTitle": "Slack Bot Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery primary",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "INFO",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.primary",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=4",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=primary@%n",
|
||||
"-Q",
|
||||
"celery",
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
},
|
||||
"consoleTitle": "Celery primary Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery light",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "INFO",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.light",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=64",
|
||||
"--prefetch-multiplier=8",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=light@%n",
|
||||
"-Q",
|
||||
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert",
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
},
|
||||
"consoleTitle": "Celery light Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery heavy",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "INFO",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.heavy",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=4",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=heavy@%n",
|
||||
"-Q",
|
||||
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync",
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
},
|
||||
"consoleTitle": "Celery heavy Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery indexing",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
@@ -83,42 +251,46 @@
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
}
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.indexing",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=1",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=indexing@%n",
|
||||
"-Q",
|
||||
"connector_indexing",
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
},
|
||||
"consoleTitle": "Celery indexing Console"
|
||||
},
|
||||
// Celery and all async jobs, usually would include indexing as well but this is handled separately above for dev
|
||||
{
|
||||
"name": "Background Jobs",
|
||||
"consoleName": "Background Jobs",
|
||||
"name": "Celery beat",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "scripts/dev_run_background_jobs.py",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"--no-indexing"
|
||||
]
|
||||
},
|
||||
// For the listner to access the Slack API,
|
||||
// DANSWER_BOT_SLACK_APP_TOKEN & DANSWER_BOT_SLACK_BOT_TOKEN need to be set in .env file located in the root of the project
|
||||
{
|
||||
"name": "Slack Bot",
|
||||
"consoleName": "Slack Bot",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "danswer/danswerbot/slack/listener.py",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
}
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.beat",
|
||||
"beat",
|
||||
"--loglevel=INFO",
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
},
|
||||
"consoleTitle": "Celery beat Console"
|
||||
},
|
||||
{
|
||||
"name": "Pytest",
|
||||
@@ -136,9 +308,23 @@
|
||||
"args": [
|
||||
"-v"
|
||||
// Specify a sepcific module/test to run or provide nothing to run all tests
|
||||
//"tests/unit/danswer/llm/answering/test_prune_and_merge.py"
|
||||
]
|
||||
//"tests/unit/onyx/llm/answering/test_prune_and_merge.py"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
},
|
||||
"consoleTitle": "Pytest Console"
|
||||
},
|
||||
{
|
||||
// Dummy entry used to label the group
|
||||
"name": "--- Tasks ---",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"presentation": {
|
||||
"group": "3",
|
||||
"order": 0
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Clear and Restart External Volumes and Containers",
|
||||
"type": "node",
|
||||
@@ -147,7 +333,27 @@
|
||||
"runtimeArgs": ["${workspaceFolder}/backend/scripts/restart_containers.sh"],
|
||||
"cwd": "${workspaceFolder}",
|
||||
"console": "integratedTerminal",
|
||||
"stopOnEntry": true
|
||||
}
|
||||
"stopOnEntry": true,
|
||||
"presentation": {
|
||||
"group": "3",
|
||||
},
|
||||
},
|
||||
{
|
||||
// Celery jobs launched through a single background script (legacy)
|
||||
// Recommend using the "Celery (all)" compound launch instead.
|
||||
"name": "Background Jobs",
|
||||
"consoleName": "Background Jobs",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "scripts/dev_run_background_jobs.py",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
139
CONTRIBUTING.md
139
CONTRIBUTING.md
@@ -1,105 +1,113 @@
|
||||
<!-- DANSWER_METADATA={"link": "https://github.com/danswer-ai/danswer/blob/main/CONTRIBUTING.md"} -->
|
||||
<!-- DANSWER_METADATA={"link": "https://github.com/onyx-dot-app/onyx/blob/main/CONTRIBUTING.md"} -->
|
||||
|
||||
# Contributing to Danswer
|
||||
Hey there! We are so excited that you're interested in Danswer.
|
||||
# Contributing to Onyx
|
||||
|
||||
Hey there! We are so excited that you're interested in Onyx.
|
||||
|
||||
As an open source project in a rapidly changing space, we welcome all contributions.
|
||||
|
||||
|
||||
## 💃 Guidelines
|
||||
|
||||
### Contribution Opportunities
|
||||
The [GitHub Issues](https://github.com/danswer-ai/danswer/issues) page is a great place to start for contribution ideas.
|
||||
|
||||
The [GitHub Issues](https://github.com/onyx-dot-app/onyx/issues) page is a great place to start for contribution ideas.
|
||||
|
||||
Issues that have been explicitly approved by the maintainers (aligned with the direction of the project)
|
||||
will be marked with the `approved by maintainers` label.
|
||||
Issues marked `good first issue` are an especially great place to start.
|
||||
|
||||
**Connectors** to other tools are another great place to contribute. For details on how, refer to this
|
||||
[README.md](https://github.com/danswer-ai/danswer/blob/main/backend/danswer/connectors/README.md).
|
||||
[README.md](https://github.com/onyx-dot-app/onyx/blob/main/backend/onyx/connectors/README.md).
|
||||
|
||||
If you have a new/different contribution in mind, we'd love to hear about it!
|
||||
Your input is vital to making sure that Danswer moves in the right direction.
|
||||
Your input is vital to making sure that Onyx moves in the right direction.
|
||||
Before starting on implementation, please raise a GitHub issue.
|
||||
|
||||
And always feel free to message us (Chris Weaver / Yuhong Sun) on
|
||||
[Slack](https://join.slack.com/t/danswer/shared_invite/zt-2lcmqw703-071hBuZBfNEOGUsLa5PXvQ) /
|
||||
[Discord](https://discord.gg/TDJ59cGV2X) directly about anything at all.
|
||||
|
||||
And always feel free to message us (Chris Weaver / Yuhong Sun) on
|
||||
[Slack](https://join.slack.com/t/danswer/shared_invite/zt-1w76msxmd-HJHLe3KNFIAIzk_0dSOKaQ) /
|
||||
[Discord](https://discord.gg/TDJ59cGV2X) directly about anything at all.
|
||||
|
||||
### Contributing Code
|
||||
|
||||
To contribute to this project, please follow the
|
||||
["fork and pull request"](https://docs.github.com/en/get-started/quickstart/contributing-to-projects) workflow.
|
||||
When opening a pull request, mention related issues and feel free to tag relevant maintainers.
|
||||
|
||||
Before creating a pull request please make sure that the new changes conform to the formatting and linting requirements.
|
||||
See the [Formatting and Linting](#-formatting-and-linting) section for how to run these checks locally.
|
||||
|
||||
See the [Formatting and Linting](#formatting-and-linting) section for how to run these checks locally.
|
||||
|
||||
### Getting Help 🙋
|
||||
|
||||
Our goal is to make contributing as easy as possible. If you run into any issues please don't hesitate to reach out.
|
||||
That way we can help future contributors and users can avoid the same issue.
|
||||
|
||||
We also have support channels and generally interesting discussions on our
|
||||
[Slack](https://join.slack.com/t/danswer/shared_invite/zt-2afut44lv-Rw3kSWu6_OmdAXRpCv80DQ)
|
||||
and
|
||||
[Slack](https://join.slack.com/t/danswer/shared_invite/zt-1w76msxmd-HJHLe3KNFIAIzk_0dSOKaQ)
|
||||
and
|
||||
[Discord](https://discord.gg/TDJ59cGV2X).
|
||||
|
||||
We would love to see you there!
|
||||
|
||||
|
||||
## Get Started 🚀
|
||||
Danswer being a fully functional app, relies on some external software, specifically:
|
||||
|
||||
Onyx being a fully functional app, relies on some external software, specifically:
|
||||
|
||||
- [Postgres](https://www.postgresql.org/) (Relational DB)
|
||||
- [Vespa](https://vespa.ai/) (Vector DB/Search Engine)
|
||||
- [Redis](https://redis.io/) (Cache)
|
||||
- [Nginx](https://nginx.org/) (Not needed for development flows generally)
|
||||
|
||||
|
||||
> **Note:**
|
||||
> This guide provides instructions to build and run Danswer locally from source with Docker containers providing the above external software. We believe this combination is easier for
|
||||
> development purposes. If you prefer to use pre-built container images, we provide instructions on running the full Danswer stack within Docker below.
|
||||
|
||||
> This guide provides instructions to build and run Onyx locally from source with Docker containers providing the above external software. We believe this combination is easier for
|
||||
> development purposes. If you prefer to use pre-built container images, we provide instructions on running the full Onyx stack within Docker below.
|
||||
|
||||
### Local Set Up
|
||||
|
||||
Be sure to use Python version 3.11. For instructions on installing Python 3.11 on macOS, refer to the [CONTRIBUTING_MACOS.md](./CONTRIBUTING_MACOS.md) readme.
|
||||
|
||||
If using a lower version, modifications will have to be made to the code.
|
||||
If using a higher version, sometimes some libraries will not be available (i.e. we had problems with Tensorflow in the past with higher versions of python).
|
||||
|
||||
|
||||
#### Backend: Python requirements
|
||||
|
||||
Currently, we use pip and recommend creating a virtual environment.
|
||||
|
||||
For convenience here's a command for it:
|
||||
|
||||
```bash
|
||||
python -m venv .venv
|
||||
source .venv/bin/activate
|
||||
```
|
||||
|
||||
> **Note:**
|
||||
> This virtual environment MUST NOT be set up WITHIN the danswer directory if you plan on using mypy within certain IDEs.
|
||||
> For simplicity, we recommend setting up the virtual environment outside of the danswer directory.
|
||||
> This virtual environment MUST NOT be set up WITHIN the onyx directory if you plan on using mypy within certain IDEs.
|
||||
> For simplicity, we recommend setting up the virtual environment outside of the onyx directory.
|
||||
|
||||
_For Windows, activate the virtual environment using Command Prompt:_
|
||||
|
||||
```bash
|
||||
.venv\Scripts\activate
|
||||
```
|
||||
|
||||
If using PowerShell, the command slightly differs:
|
||||
|
||||
```powershell
|
||||
.venv\Scripts\Activate.ps1
|
||||
```
|
||||
|
||||
Install the required python dependencies:
|
||||
|
||||
```bash
|
||||
pip install -r danswer/backend/requirements/default.txt
|
||||
pip install -r danswer/backend/requirements/dev.txt
|
||||
pip install -r danswer/backend/requirements/ee.txt
|
||||
pip install -r danswer/backend/requirements/model_server.txt
|
||||
pip install -r onyx/backend/requirements/default.txt
|
||||
pip install -r onyx/backend/requirements/dev.txt
|
||||
pip install -r onyx/backend/requirements/ee.txt
|
||||
pip install -r onyx/backend/requirements/model_server.txt
|
||||
```
|
||||
|
||||
Install Playwright for Python (headless browser required by the Web Connector)
|
||||
|
||||
In the activated Python virtualenv, install Playwright for Python by running:
|
||||
|
||||
```bash
|
||||
playwright install
|
||||
```
|
||||
@@ -109,42 +117,50 @@ You may have to deactivate and reactivate your virtualenv for `playwright` to ap
|
||||
#### Frontend: Node dependencies
|
||||
|
||||
Install [Node.js and npm](https://docs.npmjs.com/downloading-and-installing-node-js-and-npm) for the frontend.
|
||||
Once the above is done, navigate to `danswer/web` run:
|
||||
Once the above is done, navigate to `onyx/web` run:
|
||||
|
||||
```bash
|
||||
npm i
|
||||
```
|
||||
|
||||
#### Docker containers for external software
|
||||
|
||||
You will need Docker installed to run these containers.
|
||||
|
||||
First navigate to `danswer/deployment/docker_compose`, then start up Postgres/Vespa/Redis with:
|
||||
First navigate to `onyx/deployment/docker_compose`, then start up Postgres/Vespa/Redis with:
|
||||
|
||||
```bash
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack up -d index relational_db cache
|
||||
docker compose -f docker-compose.dev.yml -p onyx-stack up -d index relational_db cache
|
||||
```
|
||||
|
||||
(index refers to Vespa, relational_db refers to Postgres, and cache refers to Redis)
|
||||
|
||||
#### Running Onyx locally
|
||||
|
||||
To start the frontend, navigate to `onyx/web` and run:
|
||||
|
||||
#### Running Danswer locally
|
||||
To start the frontend, navigate to `danswer/web` and run:
|
||||
```bash
|
||||
npm run dev
|
||||
```
|
||||
|
||||
Next, start the model server which runs the local NLP models.
|
||||
Navigate to `danswer/backend` and run:
|
||||
Navigate to `onyx/backend` and run:
|
||||
|
||||
```bash
|
||||
uvicorn model_server.main:app --reload --port 9000
|
||||
```
|
||||
|
||||
_For Windows (for compatibility with both PowerShell and Command Prompt):_
|
||||
|
||||
```bash
|
||||
powershell -Command "uvicorn model_server.main:app --reload --port 9000"
|
||||
```
|
||||
|
||||
The first time running Danswer, you will need to run the DB migrations for Postgres.
|
||||
The first time running Onyx, you will need to run the DB migrations for Postgres.
|
||||
After the first time, this is no longer required unless the DB models change.
|
||||
|
||||
Navigate to `danswer/backend` and with the venv active, run:
|
||||
Navigate to `onyx/backend` and with the venv active, run:
|
||||
|
||||
```bash
|
||||
alembic upgrade head
|
||||
```
|
||||
@@ -152,21 +168,24 @@ alembic upgrade head
|
||||
Next, start the task queue which orchestrates the background jobs.
|
||||
Jobs that take more time are run async from the API server.
|
||||
|
||||
Still in `danswer/backend`, run:
|
||||
Still in `onyx/backend`, run:
|
||||
|
||||
```bash
|
||||
python ./scripts/dev_run_background_jobs.py
|
||||
```
|
||||
|
||||
To run the backend API server, navigate back to `danswer/backend` and run:
|
||||
To run the backend API server, navigate back to `onyx/backend` and run:
|
||||
|
||||
```bash
|
||||
AUTH_TYPE=disabled uvicorn danswer.main:app --reload --port 8080
|
||||
AUTH_TYPE=disabled uvicorn onyx.main:app --reload --port 8080
|
||||
```
|
||||
|
||||
_For Windows (for compatibility with both PowerShell and Command Prompt):_
|
||||
|
||||
```bash
|
||||
powershell -Command "
|
||||
$env:AUTH_TYPE='disabled'
|
||||
uvicorn danswer.main:app --reload --port 8080
|
||||
uvicorn onyx.main:app --reload --port 8080
|
||||
"
|
||||
```
|
||||
|
||||
@@ -182,57 +201,61 @@ You should now have 4 servers running:
|
||||
- Model server
|
||||
- Background jobs
|
||||
|
||||
Now, visit `http://localhost:3000` in your browser. You should see the Danswer onboarding wizard where you can connect your external LLM provider to Danswer.
|
||||
Now, visit `http://localhost:3000` in your browser. You should see the Onyx onboarding wizard where you can connect your external LLM provider to Onyx.
|
||||
|
||||
You've successfully set up a local Danswer instance! 🏁
|
||||
You've successfully set up a local Onyx instance! 🏁
|
||||
|
||||
#### Running the Danswer application in a container
|
||||
#### Running the Onyx application in a container
|
||||
|
||||
You can run the full Danswer application stack from pre-built images including all external software dependencies.
|
||||
You can run the full Onyx application stack from pre-built images including all external software dependencies.
|
||||
|
||||
Navigate to `danswer/deployment/docker_compose` and run:
|
||||
Navigate to `onyx/deployment/docker_compose` and run:
|
||||
|
||||
```bash
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack up -d
|
||||
docker compose -f docker-compose.dev.yml -p onyx-stack up -d
|
||||
```
|
||||
|
||||
After Docker pulls and starts these containers, navigate to `http://localhost:3000` to use Danswer.
|
||||
After Docker pulls and starts these containers, navigate to `http://localhost:3000` to use Onyx.
|
||||
|
||||
If you want to make changes to Danswer and run those changes in Docker, you can also build a local version of the Danswer container images that incorporates your changes like so:
|
||||
If you want to make changes to Onyx and run those changes in Docker, you can also build a local version of the Onyx container images that incorporates your changes like so:
|
||||
|
||||
```bash
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack up -d --build
|
||||
docker compose -f docker-compose.dev.yml -p onyx-stack up -d --build
|
||||
```
|
||||
|
||||
### Formatting and Linting
|
||||
|
||||
#### Backend
|
||||
|
||||
For the backend, you'll need to setup pre-commit hooks (black / reorder-python-imports).
|
||||
First, install pre-commit (if you don't have it already) following the instructions
|
||||
[here](https://pre-commit.com/#installation).
|
||||
|
||||
With the virtual environment active, install the pre-commit library with:
|
||||
|
||||
```bash
|
||||
pip install pre-commit
|
||||
```
|
||||
|
||||
Then, from the `danswer/backend` directory, run:
|
||||
Then, from the `onyx/backend` directory, run:
|
||||
|
||||
```bash
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
Additionally, we use `mypy` for static type checking.
|
||||
Danswer is fully type-annotated, and we want to keep it that way!
|
||||
To run the mypy checks manually, run `python -m mypy .` from the `danswer/backend` directory.
|
||||
|
||||
Onyx is fully type-annotated, and we want to keep it that way!
|
||||
To run the mypy checks manually, run `python -m mypy .` from the `onyx/backend` directory.
|
||||
|
||||
#### Web
|
||||
We use `prettier` for formatting. The desired version (2.8.8) will be installed via a `npm i` from the `danswer/web` directory.
|
||||
To run the formatter, use `npx prettier --write .` from the `danswer/web` directory.
|
||||
|
||||
We use `prettier` for formatting. The desired version (2.8.8) will be installed via a `npm i` from the `onyx/web` directory.
|
||||
To run the formatter, use `npx prettier --write .` from the `onyx/web` directory.
|
||||
Please double check that prettier passes before creating a pull request.
|
||||
|
||||
|
||||
### Release Process
|
||||
Danswer loosely follows the SemVer versioning standard.
|
||||
|
||||
Onyx loosely follows the SemVer versioning standard.
|
||||
Major changes are released with a "minor" version bump. Currently we use patch release versions to indicate small feature changes.
|
||||
A set of Docker containers will be pushed automatically to DockerHub with every tag.
|
||||
You can see the containers [here](https://hub.docker.com/search?q=danswer%2F).
|
||||
You can see the containers [here](https://hub.docker.com/search?q=onyx%2F).
|
||||
|
||||
@@ -1,15 +1,19 @@
|
||||
## Some additional notes for Mac Users
|
||||
The base instructions to set up the development environment are located in [CONTRIBUTING.md](https://github.com/danswer-ai/danswer/blob/main/CONTRIBUTING.md).
|
||||
|
||||
The base instructions to set up the development environment are located in [CONTRIBUTING.md](https://github.com/onyx-dot-app/onyx/blob/main/CONTRIBUTING.md).
|
||||
|
||||
### Setting up Python
|
||||
|
||||
Ensure [Homebrew](https://brew.sh/) is already set up.
|
||||
|
||||
Then install python 3.11.
|
||||
|
||||
```bash
|
||||
brew install python@3.11
|
||||
```
|
||||
|
||||
Add python 3.11 to your path: add the following line to ~/.zshrc
|
||||
|
||||
```
|
||||
export PATH="$(brew --prefix)/opt/python@3.11/libexec/bin:$PATH"
|
||||
```
|
||||
@@ -17,15 +21,16 @@ export PATH="$(brew --prefix)/opt/python@3.11/libexec/bin:$PATH"
|
||||
> **Note:**
|
||||
> You will need to open a new terminal for the path change above to take effect.
|
||||
|
||||
|
||||
### Setting up Docker
|
||||
On macOS, you will need to install [Docker Desktop](https://www.docker.com/products/docker-desktop/) and
|
||||
|
||||
On macOS, you will need to install [Docker Desktop](https://www.docker.com/products/docker-desktop/) and
|
||||
ensure it is running before continuing with the docker commands.
|
||||
|
||||
|
||||
### Formatting and Linting
|
||||
|
||||
MacOS will likely require you to remove some quarantine attributes on some of the hooks for them to execute properly.
|
||||
After installing pre-commit, run the following command:
|
||||
|
||||
```bash
|
||||
sudo xattr -r -d com.apple.quarantine ~/.cache/pre-commit
|
||||
```
|
||||
```
|
||||
|
||||
6
LICENSE
6
LICENSE
@@ -2,9 +2,9 @@ Copyright (c) 2023-present DanswerAI, Inc.
|
||||
|
||||
Portions of this software are licensed as follows:
|
||||
|
||||
* All content that resides under "ee" directories of this repository, if that directory exists, is licensed under the license defined in "backend/ee/LICENSE". Specifically all content under "backend/ee" and "web/src/app/ee" is licensed under the license defined in "backend/ee/LICENSE".
|
||||
* All third party components incorporated into the Danswer Software are licensed under the original license provided by the owner of the applicable component.
|
||||
* Content outside of the above mentioned directories or restrictions above is available under the "MIT Expat" license as defined below.
|
||||
- All content that resides under "ee" directories of this repository, if that directory exists, is licensed under the license defined in "backend/ee/LICENSE". Specifically all content under "backend/ee" and "web/src/app/ee" is licensed under the license defined in "backend/ee/LICENSE".
|
||||
- All third party components incorporated into the Onyx Software are licensed under the original license provided by the owner of the applicable component.
|
||||
- Content outside of the above mentioned directories or restrictions above is available under the "MIT Expat" license as defined below.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
|
||||
172
README.md
172
README.md
@@ -1,129 +1,147 @@
|
||||
<!-- DANSWER_METADATA={"link": "https://github.com/danswer-ai/danswer/blob/main/README.md"} -->
|
||||
<!-- DANSWER_METADATA={"link": "https://github.com/onyx-dot-app/onyx/blob/main/README.md"} -->
|
||||
|
||||
<a name="readme-top"></a>
|
||||
|
||||
<h2 align="center">
|
||||
<a href="https://www.danswer.ai/"> <img width="50%" src="https://github.com/danswer-owners/danswer/blob/1fabd9372d66cd54238847197c33f091a724803b/DanswerWithName.png?raw=true)" /></a>
|
||||
<a href="https://www.onyx.app/"> <img width="50%" src="https://github.com/onyx-dot-app/onyx/blob/logo/LogoOnyx.png?raw=true)" /></a>
|
||||
</h2>
|
||||
|
||||
<p align="center">
|
||||
<p align="center">Open Source Gen-AI Chat + Unified Search.</p>
|
||||
<p align="center">Open Source Gen-AI + Enterprise Search.</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://docs.danswer.dev/" target="_blank">
|
||||
<a href="https://docs.onyx.app/" target="_blank">
|
||||
<img src="https://img.shields.io/badge/docs-view-blue" alt="Documentation">
|
||||
</a>
|
||||
<a href="https://join.slack.com/t/danswer/shared_invite/zt-2lcmqw703-071hBuZBfNEOGUsLa5PXvQ" target="_blank">
|
||||
<a href="https://join.slack.com/t/danswer/shared_invite/zt-1w76msxmd-HJHLe3KNFIAIzk_0dSOKaQ" target="_blank">
|
||||
<img src="https://img.shields.io/badge/slack-join-blue.svg?logo=slack" alt="Slack">
|
||||
</a>
|
||||
<a href="https://discord.gg/TDJ59cGV2X" target="_blank">
|
||||
<img src="https://img.shields.io/badge/discord-join-blue.svg?logo=discord&logoColor=white" alt="Discord">
|
||||
</a>
|
||||
<a href="https://github.com/danswer-ai/danswer/blob/main/README.md" target="_blank">
|
||||
<a href="https://github.com/onyx-dot-app/onyx/blob/main/README.md" target="_blank">
|
||||
<img src="https://img.shields.io/static/v1?label=license&message=MIT&color=blue" alt="License">
|
||||
</a>
|
||||
</p>
|
||||
|
||||
<strong>[Danswer](https://www.danswer.ai/)</strong> is the AI Assistant connected to your company's docs, apps, and people.
|
||||
Danswer provides a Chat interface and plugs into any LLM of your choice. Danswer can be deployed anywhere and for any
|
||||
scale - on a laptop, on-premise, or to cloud. Since you own the deployment, your user data and chats are fully in your
|
||||
own control. Danswer is MIT licensed and designed to be modular and easily extensible. The system also comes fully ready
|
||||
for production usage with user authentication, role management (admin/basic users), chat persistence, and a UI for
|
||||
configuring Personas (AI Assistants) and their Prompts.
|
||||
<strong>[Onyx](https://www.onyx.app/)</strong> (Formerly Danswer) is the AI Assistant connected to your company's docs, apps, and people.
|
||||
Onyx provides a Chat interface and plugs into any LLM of your choice. Onyx can be deployed anywhere and for any
|
||||
scale - on a laptop, on-premise, or to cloud. Since you own the deployment, your user data and chats are fully in your
|
||||
own control. Onyx is dual Licensed with most of it under MIT license and designed to be modular and easily extensible. The system also comes fully ready
|
||||
for production usage with user authentication, role management (admin/basic users), chat persistence, and a UI for
|
||||
configuring AI Assistants.
|
||||
|
||||
Danswer also serves as a Unified Search across all common workplace tools such as Slack, Google Drive, Confluence, etc.
|
||||
By combining LLMs and team specific knowledge, Danswer becomes a subject matter expert for the team. Imagine ChatGPT if
|
||||
Onyx also serves as a Enterprise Search across all common workplace tools such as Slack, Google Drive, Confluence, etc.
|
||||
By combining LLMs and team specific knowledge, Onyx becomes a subject matter expert for the team. Imagine ChatGPT if
|
||||
it had access to your team's unique knowledge! It enables questions such as "A customer wants feature X, is this already
|
||||
supported?" or "Where's the pull request for feature Y?"
|
||||
|
||||
<h3>Usage</h3>
|
||||
|
||||
Danswer Web App:
|
||||
Onyx Web App:
|
||||
|
||||
https://github.com/danswer-ai/danswer/assets/32520769/563be14c-9304-47b5-bf0a-9049c2b6f410
|
||||
https://github.com/onyx-dot-app/onyx/assets/32520769/563be14c-9304-47b5-bf0a-9049c2b6f410
|
||||
|
||||
Or, plug Onyx into your existing Slack workflows (more integrations to come 😁):
|
||||
|
||||
Or, plug Danswer into your existing Slack workflows (more integrations to come 😁):
|
||||
https://github.com/onyx-dot-app/onyx/assets/25087905/3e19739b-d178-4371-9a38-011430bdec1b
|
||||
|
||||
https://github.com/danswer-ai/danswer/assets/25087905/3e19739b-d178-4371-9a38-011430bdec1b
|
||||
|
||||
|
||||
For more details on the Admin UI to manage connectors and users, check out our
|
||||
For more details on the Admin UI to manage connectors and users, check out our
|
||||
<strong><a href="https://www.youtube.com/watch?v=geNzY1nbCnU">Full Video Demo</a></strong>!
|
||||
|
||||
## Deployment
|
||||
|
||||
Danswer can easily be run locally (even on a laptop) or deployed on a virtual machine with a single
|
||||
`docker compose` command. Checkout our [docs](https://docs.danswer.dev/quickstart) to learn more.
|
||||
Onyx can easily be run locally (even on a laptop) or deployed on a virtual machine with a single
|
||||
`docker compose` command. Checkout our [docs](https://docs.onyx.app/quickstart) to learn more.
|
||||
|
||||
We also have built-in support for deployment on Kubernetes. Files for that can be found [here](https://github.com/danswer-ai/danswer/tree/main/deployment/kubernetes).
|
||||
We also have built-in support for deployment on Kubernetes. Files for that can be found [here](https://github.com/onyx-dot-app/onyx/tree/main/deployment/kubernetes).
|
||||
|
||||
## 💃 Main Features
|
||||
|
||||
## 💃 Main Features
|
||||
* Chat UI with the ability to select documents to chat with.
|
||||
* Create custom AI Assistants with different prompts and backing knowledge sets.
|
||||
* Connect Danswer with LLM of your choice (self-host for a fully airgapped solution).
|
||||
* Document Search + AI Answers for natural language queries.
|
||||
* Connectors to all common workplace tools like Google Drive, Confluence, Slack, etc.
|
||||
* Slack integration to get answers and search results directly in Slack.
|
||||
|
||||
- Chat UI with the ability to select documents to chat with.
|
||||
- Create custom AI Assistants with different prompts and backing knowledge sets.
|
||||
- Connect Onyx with LLM of your choice (self-host for a fully airgapped solution).
|
||||
- Document Search + AI Answers for natural language queries.
|
||||
- Connectors to all common workplace tools like Google Drive, Confluence, Slack, etc.
|
||||
- Slack integration to get answers and search results directly in Slack.
|
||||
|
||||
## 🚧 Roadmap
|
||||
* Chat/Prompt sharing with specific teammates and user groups.
|
||||
* Multi-Model model support, chat with images, video etc.
|
||||
* Choosing between LLMs and parameters during chat session.
|
||||
* Tool calling and agent configurations options.
|
||||
* Organizational understanding and ability to locate and suggest experts from your team.
|
||||
|
||||
- Chat/Prompt sharing with specific teammates and user groups.
|
||||
- Multimodal model support, chat with images, video etc.
|
||||
- Choosing between LLMs and parameters during chat session.
|
||||
- Tool calling and agent configurations options.
|
||||
- Organizational understanding and ability to locate and suggest experts from your team.
|
||||
|
||||
## Other Notable Benefits of Danswer
|
||||
* User Authentication with document level access management.
|
||||
* Best in class Hybrid Search across all sources (BM-25 + prefix aware embedding models).
|
||||
* Admin Dashboard to configure connectors, document-sets, access, etc.
|
||||
* Custom deep learning models + learn from user feedback.
|
||||
* Easy deployment and ability to host Danswer anywhere of your choosing.
|
||||
## Other Notable Benefits of Onyx
|
||||
|
||||
- User Authentication with document level access management.
|
||||
- Best in class Hybrid Search across all sources (BM-25 + prefix aware embedding models).
|
||||
- Admin Dashboard to configure connectors, document-sets, access, etc.
|
||||
- Custom deep learning models + learn from user feedback.
|
||||
- Easy deployment and ability to host Onyx anywhere of your choosing.
|
||||
|
||||
## 🔌 Connectors
|
||||
|
||||
Efficiently pulls the latest changes from:
|
||||
* Slack
|
||||
* GitHub
|
||||
* Google Drive
|
||||
* Confluence
|
||||
* Jira
|
||||
* Zendesk
|
||||
* Gmail
|
||||
* Notion
|
||||
* Gong
|
||||
* Slab
|
||||
* Linear
|
||||
* Productboard
|
||||
* Guru
|
||||
* Bookstack
|
||||
* Document360
|
||||
* Sharepoint
|
||||
* Hubspot
|
||||
* Local Files
|
||||
* Websites
|
||||
* And more ...
|
||||
|
||||
- Slack
|
||||
- GitHub
|
||||
- Google Drive
|
||||
- Confluence
|
||||
- Jira
|
||||
- Zendesk
|
||||
- Gmail
|
||||
- Notion
|
||||
- Gong
|
||||
- Slab
|
||||
- Linear
|
||||
- Productboard
|
||||
- Guru
|
||||
- Bookstack
|
||||
- Document360
|
||||
- Sharepoint
|
||||
- Hubspot
|
||||
- Local Files
|
||||
- Websites
|
||||
- And more ...
|
||||
|
||||
## 📚 Editions
|
||||
|
||||
There are two editions of Danswer:
|
||||
There are two editions of Onyx:
|
||||
|
||||
* Danswer Community Edition (CE) is available freely under the MIT Expat license. This version has ALL the core features discussed above. This is the version of Danswer you will get if you follow the Deployment guide above.
|
||||
* Danswer Enterprise Edition (EE) includes extra features that are primarily useful for larger organizations. Specifically, this includes:
|
||||
* Single Sign-On (SSO), with support for both SAML and OIDC
|
||||
* Role-based access control
|
||||
* Document permission inheritance from connected sources
|
||||
* Usage analytics and query history accessible to admins
|
||||
* Whitelabeling
|
||||
* API key authentication
|
||||
* Encryption of secrets
|
||||
* Any many more! Checkout [our website](https://www.danswer.ai/) for the latest.
|
||||
- Onyx Community Edition (CE) is available freely under the MIT Expat license. This version has ALL the core features discussed above. This is the version of Onyx you will get if you follow the Deployment guide above.
|
||||
- Onyx Enterprise Edition (EE) includes extra features that are primarily useful for larger organizations. Specifically, this includes:
|
||||
- Single Sign-On (SSO), with support for both SAML and OIDC
|
||||
- Role-based access control
|
||||
- Document permission inheritance from connected sources
|
||||
- Usage analytics and query history accessible to admins
|
||||
- Whitelabeling
|
||||
- API key authentication
|
||||
- Encryption of secrets
|
||||
- Any many more! Checkout [our website](https://www.onyx.app/) for the latest.
|
||||
|
||||
To try the Danswer Enterprise Edition:
|
||||
To try the Onyx Enterprise Edition:
|
||||
|
||||
1. Checkout our [Cloud product](https://app.danswer.ai/signup).
|
||||
2. For self-hosting, contact us at [founders@danswer.ai](mailto:founders@danswer.ai) or book a call with us on our [Cal](https://cal.com/team/danswer/founders).
|
||||
1. Checkout our [Cloud product](https://cloud.onyx.app/signup).
|
||||
2. For self-hosting, contact us at [founders@onyx.app](mailto:founders@onyx.app) or book a call with us on our [Cal](https://cal.com/team/danswer/founders).
|
||||
|
||||
## 💡 Contributing
|
||||
|
||||
Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md) for more details.
|
||||
|
||||
## ⭐Star History
|
||||
|
||||
[](https://star-history.com/#onyx-dot-app/onyx&Date)
|
||||
|
||||
## ✨Contributors
|
||||
|
||||
<a href="https://github.com/onyx-dot-app/onyx/graphs/contributors">
|
||||
<img alt="contributors" src="https://contrib.rocks/image?repo=onyx-dot-app/onyx"/>
|
||||
</a>
|
||||
|
||||
<p align="right" style="font-size: 14px; color: #555; margin-top: 20px;">
|
||||
<a href="#readme-top" style="text-decoration: none; color: #007bff; font-weight: bold;">
|
||||
↑ Back to Top ↑
|
||||
</a>
|
||||
</p>
|
||||
|
||||
@@ -1,18 +1,19 @@
|
||||
FROM python:3.11.7-slim-bookworm
|
||||
|
||||
LABEL com.danswer.maintainer="founders@danswer.ai"
|
||||
LABEL com.danswer.description="This image is the web/frontend container of Danswer which \
|
||||
contains code for both the Community and Enterprise editions of Danswer. If you do not \
|
||||
LABEL com.danswer.maintainer="founders@onyx.app"
|
||||
LABEL com.danswer.description="This image is the web/frontend container of Onyx which \
|
||||
contains code for both the Community and Enterprise editions of Onyx. If you do not \
|
||||
have a contract or agreement with DanswerAI, you are not permitted to use the Enterprise \
|
||||
Edition features outside of personal development or testing purposes. Please reach out to \
|
||||
founders@danswer.ai for more information. Please visit https://github.com/danswer-ai/danswer"
|
||||
founders@onyx.app for more information. Please visit https://github.com/onyx-dot-app/onyx"
|
||||
|
||||
# Default DANSWER_VERSION, typically overriden during builds by GitHub Actions.
|
||||
ARG DANSWER_VERSION=0.3-dev
|
||||
ENV DANSWER_VERSION=${DANSWER_VERSION} \
|
||||
# Default ONYX_VERSION, typically overriden during builds by GitHub Actions.
|
||||
ARG ONYX_VERSION=0.8-dev
|
||||
ENV ONYX_VERSION=${ONYX_VERSION} \
|
||||
DANSWER_RUNNING_IN_DOCKER="true"
|
||||
|
||||
RUN echo "DANSWER_VERSION: ${DANSWER_VERSION}"
|
||||
|
||||
RUN echo "ONYX_VERSION: ${ONYX_VERSION}"
|
||||
# Install system dependencies
|
||||
# cmake needed for psycopg (postgres)
|
||||
# libpq-dev needed for psycopg (postgres)
|
||||
@@ -36,6 +37,8 @@ RUN apt-get update && \
|
||||
rm -rf /var/lib/apt/lists/* && \
|
||||
apt-get clean
|
||||
|
||||
|
||||
|
||||
# Install Python dependencies
|
||||
# Remove py which is pulled in by retry, py is not needed and is a CVE
|
||||
COPY ./requirements/default.txt /tmp/requirements.txt
|
||||
@@ -53,7 +56,7 @@ RUN pip install --no-cache-dir --upgrade \
|
||||
# Cleanup for CVEs and size reduction
|
||||
# https://github.com/tornadoweb/tornado/issues/3107
|
||||
# xserver-common and xvfb included by playwright installation but not needed after
|
||||
# perl-base is part of the base Python Debian image but not needed for Danswer functionality
|
||||
# perl-base is part of the base Python Debian image but not needed for Onyx functionality
|
||||
# perl-base could only be removed with --allow-remove-essential
|
||||
RUN apt-get update && \
|
||||
apt-get remove -y --allow-remove-essential \
|
||||
@@ -70,11 +73,11 @@ RUN apt-get update && \
|
||||
rm -rf /var/lib/apt/lists/* && \
|
||||
rm -f /usr/local/lib/python3.11/site-packages/tornado/test/test.key
|
||||
|
||||
|
||||
# Pre-downloading models for setups with limited egress
|
||||
RUN python -c "from tokenizers import Tokenizer; \
|
||||
Tokenizer.from_pretrained('nomic-ai/nomic-embed-text-v1')"
|
||||
|
||||
|
||||
# Pre-downloading NLTK for setups with limited egress
|
||||
RUN python -c "import nltk; \
|
||||
nltk.download('stopwords', quiet=True); \
|
||||
@@ -89,7 +92,7 @@ COPY ./ee /app/ee
|
||||
COPY supervisord.conf /etc/supervisor/conf.d/supervisord.conf
|
||||
|
||||
# Set up application files
|
||||
COPY ./danswer /app/danswer
|
||||
COPY ./onyx /app/onyx
|
||||
COPY ./shared_configs /app/shared_configs
|
||||
COPY ./alembic /app/alembic
|
||||
COPY ./alembic_tenants /app/alembic_tenants
|
||||
|
||||
@@ -1,18 +1,18 @@
|
||||
FROM python:3.11.7-slim-bookworm
|
||||
|
||||
LABEL com.danswer.maintainer="founders@danswer.ai"
|
||||
LABEL com.danswer.description="This image is for the Danswer model server which runs all of the \
|
||||
AI models for Danswer. This container and all the code is MIT Licensed and free for all to use. \
|
||||
You can find it at https://hub.docker.com/r/danswer/danswer-model-server. For more details, \
|
||||
visit https://github.com/danswer-ai/danswer."
|
||||
LABEL com.danswer.maintainer="founders@onyx.app"
|
||||
LABEL com.danswer.description="This image is for the Onyx model server which runs all of the \
|
||||
AI models for Onyx. This container and all the code is MIT Licensed and free for all to use. \
|
||||
You can find it at https://hub.docker.com/r/onyx/onyx-model-server. For more details, \
|
||||
visit https://github.com/onyx-dot-app/onyx."
|
||||
|
||||
# Default DANSWER_VERSION, typically overriden during builds by GitHub Actions.
|
||||
ARG DANSWER_VERSION=0.3-dev
|
||||
ENV DANSWER_VERSION=${DANSWER_VERSION} \
|
||||
# Default ONYX_VERSION, typically overriden during builds by GitHub Actions.
|
||||
ARG ONYX_VERSION=0.8-dev
|
||||
ENV ONYX_VERSION=${ONYX_VERSION} \
|
||||
DANSWER_RUNNING_IN_DOCKER="true"
|
||||
|
||||
|
||||
RUN echo "DANSWER_VERSION: ${DANSWER_VERSION}"
|
||||
RUN echo "ONYX_VERSION: ${ONYX_VERSION}"
|
||||
|
||||
COPY ./requirements/model_server.txt /tmp/requirements.txt
|
||||
RUN pip install --no-cache-dir --upgrade \
|
||||
@@ -20,11 +20,11 @@ RUN pip install --no-cache-dir --upgrade \
|
||||
--timeout 30 \
|
||||
-r /tmp/requirements.txt
|
||||
|
||||
RUN apt-get remove -y --allow-remove-essential perl-base && \
|
||||
RUN apt-get remove -y --allow-remove-essential perl-base && \
|
||||
apt-get autoremove -y
|
||||
|
||||
# Pre-downloading models for setups with limited egress
|
||||
# Download tokenizers, distilbert for the Danswer model
|
||||
# Download tokenizers, distilbert for the Onyx model
|
||||
# Download model weights
|
||||
# Run Nomic to pull in the custom architecture and have it cached locally
|
||||
RUN python -c "from transformers import AutoTokenizer; \
|
||||
@@ -38,18 +38,18 @@ from sentence_transformers import SentenceTransformer; \
|
||||
SentenceTransformer(model_name_or_path='nomic-ai/nomic-embed-text-v1', trust_remote_code=True);"
|
||||
|
||||
# In case the user has volumes mounted to /root/.cache/huggingface that they've downloaded while
|
||||
# running Danswer, don't overwrite it with the built in cache folder
|
||||
# running Onyx, don't overwrite it with the built in cache folder
|
||||
RUN mv /root/.cache/huggingface /root/.cache/temp_huggingface
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Utils used by model server
|
||||
COPY ./danswer/utils/logger.py /app/danswer/utils/logger.py
|
||||
COPY ./onyx/utils/logger.py /app/onyx/utils/logger.py
|
||||
|
||||
# Place to fetch version information
|
||||
COPY ./danswer/__init__.py /app/danswer/__init__.py
|
||||
COPY ./onyx/__init__.py /app/onyx/__init__.py
|
||||
|
||||
# Shared between Danswer Backend and Model Server
|
||||
# Shared between Onyx Backend and Model Server
|
||||
COPY ./shared_configs /app/shared_configs
|
||||
|
||||
# Model Server main code
|
||||
|
||||
@@ -1,19 +1,22 @@
|
||||
<!-- DANSWER_METADATA={"link": "https://github.com/danswer-ai/danswer/blob/main/backend/alembic/README.md"} -->
|
||||
<!-- DANSWER_METADATA={"link": "https://github.com/onyx-dot-app/onyx/blob/main/backend/alembic/README.md"} -->
|
||||
|
||||
# Alembic DB Migrations
|
||||
These files are for creating/updating the tables in the Relational DB (Postgres).
|
||||
Danswer migrations use a generic single-database configuration with an async dbapi.
|
||||
|
||||
## To generate new migrations:
|
||||
run from danswer/backend:
|
||||
These files are for creating/updating the tables in the Relational DB (Postgres).
|
||||
Onyx migrations use a generic single-database configuration with an async dbapi.
|
||||
|
||||
## To generate new migrations:
|
||||
|
||||
run from onyx/backend:
|
||||
`alembic revision --autogenerate -m <DESCRIPTION_OF_MIGRATION>`
|
||||
|
||||
More info can be found here: https://alembic.sqlalchemy.org/en/latest/autogenerate.html
|
||||
|
||||
## Running migrations
|
||||
|
||||
To run all un-applied migrations:
|
||||
`alembic upgrade head`
|
||||
|
||||
To undo migrations:
|
||||
`alembic downgrade -X`
|
||||
`alembic downgrade -X`
|
||||
where X is the number of migrations you want to undo from the current state
|
||||
|
||||
@@ -1,35 +1,70 @@
|
||||
from typing import Any
|
||||
from typing import Any, Literal
|
||||
from onyx.db.engine import get_iam_auth_token
|
||||
from onyx.configs.app_configs import USE_IAM_AUTH
|
||||
from onyx.configs.app_configs import POSTGRES_HOST
|
||||
from onyx.configs.app_configs import POSTGRES_PORT
|
||||
from onyx.configs.app_configs import POSTGRES_USER
|
||||
from onyx.configs.app_configs import AWS_REGION
|
||||
from onyx.db.engine import build_connection_string
|
||||
from onyx.db.engine import get_all_tenant_ids
|
||||
from sqlalchemy import event
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.engine.base import Connection
|
||||
import os
|
||||
import ssl
|
||||
import asyncio
|
||||
import logging
|
||||
from logging.config import fileConfig
|
||||
|
||||
from alembic import context
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.engine import Connection
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy.sql import text
|
||||
|
||||
from danswer.configs.app_configs import MULTI_TENANT
|
||||
from danswer.db.engine import build_connection_string
|
||||
from danswer.db.models import Base
|
||||
from sqlalchemy.sql.schema import SchemaItem
|
||||
from onyx.configs.constants import SSL_CERT_FILE
|
||||
from shared_configs.configs import MULTI_TENANT, POSTGRES_DEFAULT_SCHEMA
|
||||
from onyx.db.models import Base
|
||||
from celery.backends.database.session import ResultModelBase # type: ignore
|
||||
|
||||
# Alembic Config object
|
||||
config = context.config
|
||||
|
||||
# Interpret the config file for Python logging.
|
||||
if config.config_file_name is not None and config.attributes.get(
|
||||
"configure_logger", True
|
||||
):
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
# Add your model's MetaData object here
|
||||
# for 'autogenerate' support
|
||||
# from myapp import mymodel
|
||||
# target_metadata = mymodel.Base.metadata
|
||||
target_metadata = [Base.metadata, ResultModelBase.metadata]
|
||||
|
||||
EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def get_schema_options() -> tuple[str, bool]:
|
||||
ssl_context: ssl.SSLContext | None = None
|
||||
if USE_IAM_AUTH:
|
||||
if not os.path.exists(SSL_CERT_FILE):
|
||||
raise FileNotFoundError(f"Expected {SSL_CERT_FILE} when USE_IAM_AUTH is true.")
|
||||
ssl_context = ssl.create_default_context(cafile=SSL_CERT_FILE)
|
||||
|
||||
|
||||
def include_object(
|
||||
object: SchemaItem,
|
||||
name: str | None,
|
||||
type_: Literal[
|
||||
"schema",
|
||||
"table",
|
||||
"column",
|
||||
"index",
|
||||
"unique_constraint",
|
||||
"foreign_key_constraint",
|
||||
],
|
||||
reflected: bool,
|
||||
compare_to: SchemaItem | None,
|
||||
) -> bool:
|
||||
if type_ == "table" and name in EXCLUDE_TABLES:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_schema_options() -> tuple[str, bool, bool]:
|
||||
x_args_raw = context.get_x_argument()
|
||||
x_args = {}
|
||||
for arg in x_args_raw:
|
||||
@@ -37,63 +72,32 @@ def get_schema_options() -> tuple[str, bool]:
|
||||
if "=" in pair:
|
||||
key, value = pair.split("=", 1)
|
||||
x_args[key.strip()] = value.strip()
|
||||
schema_name = x_args.get("schema", "public")
|
||||
schema_name = x_args.get("schema", POSTGRES_DEFAULT_SCHEMA)
|
||||
create_schema = x_args.get("create_schema", "true").lower() == "true"
|
||||
return schema_name, create_schema
|
||||
upgrade_all_tenants = x_args.get("upgrade_all_tenants", "false").lower() == "true"
|
||||
|
||||
|
||||
EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
|
||||
|
||||
|
||||
def include_object(
|
||||
object: Any, name: str, type_: str, reflected: bool, compare_to: Any
|
||||
) -> bool:
|
||||
if type_ == "table" and name in EXCLUDE_TABLES:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
"""Run migrations in 'offline' mode.
|
||||
This configures the context with just a URL
|
||||
and not an Engine, though an Engine is acceptable
|
||||
here as well. By skipping the Engine creation
|
||||
we don't even need a DBAPI to be available.
|
||||
Calls to context.execute() here emit the given string to the
|
||||
script output.
|
||||
"""
|
||||
schema_name, _ = get_schema_options()
|
||||
url = build_connection_string()
|
||||
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
literal_binds=True,
|
||||
include_object=include_object,
|
||||
version_table_schema=schema_name,
|
||||
include_schemas=True,
|
||||
script_location=config.get_main_option("script_location"),
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def do_run_migrations(connection: Connection) -> None:
|
||||
schema_name, create_schema = get_schema_options()
|
||||
|
||||
if MULTI_TENANT and schema_name == "public":
|
||||
if (
|
||||
MULTI_TENANT
|
||||
and schema_name == POSTGRES_DEFAULT_SCHEMA
|
||||
and not upgrade_all_tenants
|
||||
):
|
||||
raise ValueError(
|
||||
"Cannot run default migrations in public schema when multi-tenancy is enabled. "
|
||||
"Please specify a tenant-specific schema."
|
||||
)
|
||||
|
||||
return schema_name, create_schema, upgrade_all_tenants
|
||||
|
||||
|
||||
def do_run_migrations(
|
||||
connection: Connection, schema_name: str, create_schema: bool
|
||||
) -> None:
|
||||
logger.info(f"About to migrate schema: {schema_name}")
|
||||
|
||||
if create_schema:
|
||||
connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema_name}"'))
|
||||
connection.execute(text("COMMIT"))
|
||||
|
||||
# Set search_path to the target schema
|
||||
connection.execute(text(f'SET search_path TO "{schema_name}"'))
|
||||
|
||||
context.configure(
|
||||
@@ -111,16 +115,118 @@ def do_run_migrations(connection: Connection) -> None:
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def provide_iam_token_for_alembic(
|
||||
dialect: Any, conn_rec: Any, cargs: Any, cparams: Any
|
||||
) -> None:
|
||||
if USE_IAM_AUTH:
|
||||
# Database connection settings
|
||||
region = AWS_REGION
|
||||
host = POSTGRES_HOST
|
||||
port = POSTGRES_PORT
|
||||
user = POSTGRES_USER
|
||||
|
||||
# Get IAM authentication token
|
||||
token = get_iam_auth_token(host, port, user, region)
|
||||
|
||||
# For Alembic / SQLAlchemy in this context, set SSL and password
|
||||
cparams["password"] = token
|
||||
cparams["ssl"] = ssl_context
|
||||
|
||||
|
||||
async def run_async_migrations() -> None:
|
||||
connectable = create_async_engine(
|
||||
schema_name, create_schema, upgrade_all_tenants = get_schema_options()
|
||||
|
||||
engine = create_async_engine(
|
||||
build_connection_string(),
|
||||
poolclass=pool.NullPool,
|
||||
)
|
||||
|
||||
async with connectable.connect() as connection:
|
||||
await connection.run_sync(do_run_migrations)
|
||||
if USE_IAM_AUTH:
|
||||
|
||||
await connectable.dispose()
|
||||
@event.listens_for(engine.sync_engine, "do_connect")
|
||||
def event_provide_iam_token_for_alembic(
|
||||
dialect: Any, conn_rec: Any, cargs: Any, cparams: Any
|
||||
) -> None:
|
||||
provide_iam_token_for_alembic(dialect, conn_rec, cargs, cparams)
|
||||
|
||||
if upgrade_all_tenants:
|
||||
tenant_schemas = get_all_tenant_ids()
|
||||
for schema in tenant_schemas:
|
||||
try:
|
||||
logger.info(f"Migrating schema: {schema}")
|
||||
async with engine.connect() as connection:
|
||||
await connection.run_sync(
|
||||
do_run_migrations,
|
||||
schema_name=schema,
|
||||
create_schema=create_schema,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error migrating schema {schema}: {e}")
|
||||
raise
|
||||
else:
|
||||
try:
|
||||
logger.info(f"Migrating schema: {schema_name}")
|
||||
async with engine.connect() as connection:
|
||||
await connection.run_sync(
|
||||
do_run_migrations,
|
||||
schema_name=schema_name,
|
||||
create_schema=create_schema,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error migrating schema {schema_name}: {e}")
|
||||
raise
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
schema_name, _, upgrade_all_tenants = get_schema_options()
|
||||
url = build_connection_string()
|
||||
|
||||
if upgrade_all_tenants:
|
||||
engine = create_async_engine(url)
|
||||
|
||||
if USE_IAM_AUTH:
|
||||
|
||||
@event.listens_for(engine.sync_engine, "do_connect")
|
||||
def event_provide_iam_token_for_alembic_offline(
|
||||
dialect: Any, conn_rec: Any, cargs: Any, cparams: Any
|
||||
) -> None:
|
||||
provide_iam_token_for_alembic(dialect, conn_rec, cargs, cparams)
|
||||
|
||||
tenant_schemas = get_all_tenant_ids()
|
||||
engine.sync_engine.dispose()
|
||||
|
||||
for schema in tenant_schemas:
|
||||
logger.info(f"Migrating schema: {schema}")
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
literal_binds=True,
|
||||
include_object=include_object,
|
||||
version_table_schema=schema,
|
||||
include_schemas=True,
|
||||
script_location=config.get_main_option("script_location"),
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
else:
|
||||
logger.info(f"Migrating schema: {schema_name}")
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
literal_binds=True,
|
||||
include_object=include_object,
|
||||
version_table_schema=schema_name,
|
||||
include_schemas=True,
|
||||
script_location=config.get_main_option("script_location"),
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
|
||||
|
||||
@@ -11,7 +11,7 @@ from sqlalchemy.sql import table
|
||||
from sqlalchemy.dialects import postgresql
|
||||
import json
|
||||
|
||||
from danswer.utils.encryption import encrypt_string_to_bytes
|
||||
from onyx.utils.encryption import encrypt_string_to_bytes
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "0a98909f2757"
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Introduce Danswer APIs
|
||||
"""Introduce Onyx APIs
|
||||
|
||||
Revision ID: 15326fcec57e
|
||||
Revises: 77d07dffae64
|
||||
@@ -8,7 +8,7 @@ Create Date: 2023-11-11 20:51:24.228999
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import DocumentSource
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "15326fcec57e"
|
||||
@@ -0,0 +1,59 @@
|
||||
"""display custom llm models
|
||||
|
||||
Revision ID: 177de57c21c9
|
||||
Revises: 4ee1287bd26a
|
||||
Create Date: 2024-11-21 11:49:04.488677
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from sqlalchemy import and_
|
||||
|
||||
revision = "177de57c21c9"
|
||||
down_revision = "4ee1287bd26a"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
llm_provider = sa.table(
|
||||
"llm_provider",
|
||||
sa.column("id", sa.Integer),
|
||||
sa.column("provider", sa.String),
|
||||
sa.column("model_names", postgresql.ARRAY(sa.String)),
|
||||
sa.column("display_model_names", postgresql.ARRAY(sa.String)),
|
||||
)
|
||||
|
||||
excluded_providers = ["openai", "bedrock", "anthropic", "azure"]
|
||||
|
||||
providers_to_update = sa.select(
|
||||
llm_provider.c.id,
|
||||
llm_provider.c.model_names,
|
||||
llm_provider.c.display_model_names,
|
||||
).where(
|
||||
and_(
|
||||
~llm_provider.c.provider.in_(excluded_providers),
|
||||
llm_provider.c.model_names.isnot(None),
|
||||
)
|
||||
)
|
||||
|
||||
results = conn.execute(providers_to_update).fetchall()
|
||||
|
||||
for provider_id, model_names, display_model_names in results:
|
||||
if display_model_names is None:
|
||||
display_model_names = []
|
||||
|
||||
combined_model_names = list(set(display_model_names + model_names))
|
||||
update_stmt = (
|
||||
llm_provider.update()
|
||||
.where(llm_provider.c.id == provider_id)
|
||||
.values(display_model_names=combined_model_names)
|
||||
)
|
||||
conn.execute(update_stmt)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
pass
|
||||
@@ -0,0 +1,26 @@
|
||||
"""add additional data to notifications
|
||||
|
||||
Revision ID: 1b10e1fda030
|
||||
Revises: 6756efa39ada
|
||||
Create Date: 2024-10-15 19:26:44.071259
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "1b10e1fda030"
|
||||
down_revision = "6756efa39ada"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"notification", sa.Column("additional_data", postgresql.JSONB(), nullable=True)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("notification", "additional_data")
|
||||
@@ -10,7 +10,7 @@ from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
from danswer.configs.chat_configs import NUM_POSTPROCESSED_RESULTS
|
||||
from onyx.configs.chat_configs import NUM_POSTPROCESSED_RESULTS
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "1f60f60c3401"
|
||||
|
||||
@@ -0,0 +1,68 @@
|
||||
"""default chosen assistants to none
|
||||
|
||||
Revision ID: 26b931506ecb
|
||||
Revises: 2daa494a0851
|
||||
Create Date: 2024-11-12 13:23:29.858995
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "26b931506ecb"
|
||||
down_revision = "2daa494a0851"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"user", sa.Column("chosen_assistants_new", postgresql.JSONB(), nullable=True)
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE "user"
|
||||
SET chosen_assistants_new =
|
||||
CASE
|
||||
WHEN chosen_assistants = '[-2, -1, 0]' THEN NULL
|
||||
ELSE chosen_assistants
|
||||
END
|
||||
"""
|
||||
)
|
||||
|
||||
op.drop_column("user", "chosen_assistants")
|
||||
|
||||
op.alter_column(
|
||||
"user", "chosen_assistants_new", new_column_name="chosen_assistants"
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"chosen_assistants_old",
|
||||
postgresql.JSONB(),
|
||||
nullable=False,
|
||||
server_default="[-2, -1, 0]",
|
||||
),
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE "user"
|
||||
SET chosen_assistants_old =
|
||||
CASE
|
||||
WHEN chosen_assistants IS NULL THEN '[-2, -1, 0]'::jsonb
|
||||
ELSE chosen_assistants
|
||||
END
|
||||
"""
|
||||
)
|
||||
|
||||
op.drop_column("user", "chosen_assistants")
|
||||
|
||||
op.alter_column(
|
||||
"user", "chosen_assistants_old", new_column_name="chosen_assistants"
|
||||
)
|
||||
30
backend/alembic/versions/2daa494a0851_add_group_sync_time.py
Normal file
30
backend/alembic/versions/2daa494a0851_add_group_sync_time.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""add-group-sync-time
|
||||
|
||||
Revision ID: 2daa494a0851
|
||||
Revises: c0fd6e4da83a
|
||||
Create Date: 2024-11-11 10:57:22.991157
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "2daa494a0851"
|
||||
down_revision = "c0fd6e4da83a"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"connector_credential_pair",
|
||||
sa.Column(
|
||||
"last_time_external_group_sync",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("connector_credential_pair", "last_time_external_group_sync")
|
||||
@@ -0,0 +1,50 @@
|
||||
"""single tool call per message
|
||||
|
||||
Revision ID: 33cb72ea4d80
|
||||
Revises: 5b29123cd710
|
||||
Create Date: 2024-11-01 12:51:01.535003
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "33cb72ea4d80"
|
||||
down_revision = "5b29123cd710"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Step 1: Delete extraneous ToolCall entries
|
||||
# Keep only the ToolCall with the smallest 'id' for each 'message_id'
|
||||
op.execute(
|
||||
sa.text(
|
||||
"""
|
||||
DELETE FROM tool_call
|
||||
WHERE id NOT IN (
|
||||
SELECT MIN(id)
|
||||
FROM tool_call
|
||||
WHERE message_id IS NOT NULL
|
||||
GROUP BY message_id
|
||||
);
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# Step 2: Add a unique constraint on message_id
|
||||
op.create_unique_constraint(
|
||||
constraint_name="uq_tool_call_message_id",
|
||||
table_name="tool_call",
|
||||
columns=["message_id"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Step 1: Drop the unique constraint on message_id
|
||||
op.drop_constraint(
|
||||
constraint_name="uq_tool_call_message_id",
|
||||
table_name="tool_call",
|
||||
type_="unique",
|
||||
)
|
||||
121
backend/alembic/versions/35e518e0ddf4_properly_cascade.py
Normal file
121
backend/alembic/versions/35e518e0ddf4_properly_cascade.py
Normal file
@@ -0,0 +1,121 @@
|
||||
"""properly_cascade
|
||||
|
||||
Revision ID: 35e518e0ddf4
|
||||
Revises: 91a0a4d62b14
|
||||
Create Date: 2024-09-20 21:24:04.891018
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "35e518e0ddf4"
|
||||
down_revision = "91a0a4d62b14"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Update chat_message foreign key constraint
|
||||
op.drop_constraint(
|
||||
"chat_message_chat_session_id_fkey", "chat_message", type_="foreignkey"
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"chat_message_chat_session_id_fkey",
|
||||
"chat_message",
|
||||
"chat_session",
|
||||
["chat_session_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
# Update chat_message__search_doc foreign key constraints
|
||||
op.drop_constraint(
|
||||
"chat_message__search_doc_chat_message_id_fkey",
|
||||
"chat_message__search_doc",
|
||||
type_="foreignkey",
|
||||
)
|
||||
op.drop_constraint(
|
||||
"chat_message__search_doc_search_doc_id_fkey",
|
||||
"chat_message__search_doc",
|
||||
type_="foreignkey",
|
||||
)
|
||||
|
||||
op.create_foreign_key(
|
||||
"chat_message__search_doc_chat_message_id_fkey",
|
||||
"chat_message__search_doc",
|
||||
"chat_message",
|
||||
["chat_message_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"chat_message__search_doc_search_doc_id_fkey",
|
||||
"chat_message__search_doc",
|
||||
"search_doc",
|
||||
["search_doc_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
# Add CASCADE delete for tool_call foreign key
|
||||
op.drop_constraint("tool_call_message_id_fkey", "tool_call", type_="foreignkey")
|
||||
op.create_foreign_key(
|
||||
"tool_call_message_id_fkey",
|
||||
"tool_call",
|
||||
"chat_message",
|
||||
["message_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Revert chat_message foreign key constraint
|
||||
op.drop_constraint(
|
||||
"chat_message_chat_session_id_fkey", "chat_message", type_="foreignkey"
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"chat_message_chat_session_id_fkey",
|
||||
"chat_message",
|
||||
"chat_session",
|
||||
["chat_session_id"],
|
||||
["id"],
|
||||
)
|
||||
|
||||
# Revert chat_message__search_doc foreign key constraints
|
||||
op.drop_constraint(
|
||||
"chat_message__search_doc_chat_message_id_fkey",
|
||||
"chat_message__search_doc",
|
||||
type_="foreignkey",
|
||||
)
|
||||
op.drop_constraint(
|
||||
"chat_message__search_doc_search_doc_id_fkey",
|
||||
"chat_message__search_doc",
|
||||
type_="foreignkey",
|
||||
)
|
||||
|
||||
op.create_foreign_key(
|
||||
"chat_message__search_doc_chat_message_id_fkey",
|
||||
"chat_message__search_doc",
|
||||
"chat_message",
|
||||
["chat_message_id"],
|
||||
["id"],
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"chat_message__search_doc_search_doc_id_fkey",
|
||||
"chat_message__search_doc",
|
||||
"search_doc",
|
||||
["search_doc_id"],
|
||||
["id"],
|
||||
)
|
||||
|
||||
# Revert tool_call foreign key constraint
|
||||
op.drop_constraint("tool_call_message_id_fkey", "tool_call", type_="foreignkey")
|
||||
op.create_foreign_key(
|
||||
"tool_call_message_id_fkey",
|
||||
"tool_call",
|
||||
"chat_message",
|
||||
["message_id"],
|
||||
["id"],
|
||||
)
|
||||
@@ -17,7 +17,7 @@ depends_on: None = None
|
||||
|
||||
def upgrade() -> None:
|
||||
# At this point, we directly changed some previous migrations,
|
||||
# https://github.com/danswer-ai/danswer/pull/637
|
||||
# https://github.com/onyx-dot-app/onyx/pull/637
|
||||
# Due to using Postgres native Enums, it caused some complications for first time users.
|
||||
# To remove those complications, all Enums are only handled application side moving forward.
|
||||
# This migration exists to ensure that existing users don't run into upgrade issues.
|
||||
|
||||
@@ -0,0 +1,45 @@
|
||||
"""add persona categories
|
||||
|
||||
Revision ID: 47e5bef3a1d7
|
||||
Revises: dfbe9e93d3c7
|
||||
Create Date: 2024-11-05 18:55:02.221064
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "47e5bef3a1d7"
|
||||
down_revision = "dfbe9e93d3c7"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create the persona_category table
|
||||
op.create_table(
|
||||
"persona_category",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("name", sa.String(), nullable=False),
|
||||
sa.Column("description", sa.String(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("name"),
|
||||
)
|
||||
|
||||
# Add category_id to persona table
|
||||
op.add_column("persona", sa.Column("category_id", sa.Integer(), nullable=True))
|
||||
op.create_foreign_key(
|
||||
"fk_persona_category",
|
||||
"persona",
|
||||
"persona_category",
|
||||
["category_id"],
|
||||
["id"],
|
||||
ondelete="SET NULL",
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_constraint("fk_persona_category", "persona", type_="foreignkey")
|
||||
op.drop_column("persona", "category_id")
|
||||
op.drop_table("persona_category")
|
||||
@@ -0,0 +1,280 @@
|
||||
"""add_multiple_slack_bot_support
|
||||
|
||||
Revision ID: 4ee1287bd26a
|
||||
Revises: 47e5bef3a1d7
|
||||
Create Date: 2024-11-06 13:15:53.302644
|
||||
|
||||
"""
|
||||
import logging
|
||||
from typing import cast
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.orm import Session
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
from onyx.db.models import SlackBot
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "4ee1287bd26a"
|
||||
down_revision = "47e5bef3a1d7"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
# Configure logging
|
||||
logger = logging.getLogger("alembic.runtime.migration")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
logger.info(f"{revision}: create_table: slack_bot")
|
||||
# Create new slack_bot table
|
||||
op.create_table(
|
||||
"slack_bot",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("name", sa.String(), nullable=False),
|
||||
sa.Column("enabled", sa.Boolean(), nullable=False, server_default="true"),
|
||||
sa.Column("bot_token", sa.LargeBinary(), nullable=False),
|
||||
sa.Column("app_token", sa.LargeBinary(), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("bot_token"),
|
||||
sa.UniqueConstraint("app_token"),
|
||||
)
|
||||
|
||||
# # Create new slack_channel_config table
|
||||
op.create_table(
|
||||
"slack_channel_config",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("slack_bot_id", sa.Integer(), nullable=True),
|
||||
sa.Column("persona_id", sa.Integer(), nullable=True),
|
||||
sa.Column("channel_config", postgresql.JSONB(), nullable=False),
|
||||
sa.Column("response_type", sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
"enable_auto_filters", sa.Boolean(), nullable=False, server_default="false"
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["slack_bot_id"],
|
||||
["slack_bot.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["persona_id"],
|
||||
["persona.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
# Handle existing Slack bot tokens first
|
||||
logger.info(f"{revision}: Checking for existing Slack bot.")
|
||||
bot_token = None
|
||||
app_token = None
|
||||
first_row_id = None
|
||||
|
||||
try:
|
||||
tokens = cast(dict, get_kv_store().load("slack_bot_tokens_config_key"))
|
||||
except Exception:
|
||||
logger.warning("No existing Slack bot tokens found.")
|
||||
tokens = {}
|
||||
|
||||
bot_token = tokens.get("bot_token")
|
||||
app_token = tokens.get("app_token")
|
||||
|
||||
if bot_token and app_token:
|
||||
logger.info(f"{revision}: Found bot and app tokens.")
|
||||
|
||||
session = Session(bind=op.get_bind())
|
||||
new_slack_bot = SlackBot(
|
||||
name="Slack Bot (Migrated)",
|
||||
enabled=True,
|
||||
bot_token=bot_token,
|
||||
app_token=app_token,
|
||||
)
|
||||
session.add(new_slack_bot)
|
||||
session.commit()
|
||||
first_row_id = new_slack_bot.id
|
||||
|
||||
# Create a default bot if none exists
|
||||
# This is in case there are no slack tokens but there are channels configured
|
||||
op.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO slack_bot (name, enabled, bot_token, app_token)
|
||||
SELECT 'Default Bot', true, '', ''
|
||||
WHERE NOT EXISTS (SELECT 1 FROM slack_bot)
|
||||
RETURNING id;
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# Get the bot ID to use (either from existing migration or newly created)
|
||||
bot_id_query = sa.text(
|
||||
"""
|
||||
SELECT COALESCE(
|
||||
:first_row_id,
|
||||
(SELECT id FROM slack_bot ORDER BY id ASC LIMIT 1)
|
||||
) as bot_id;
|
||||
"""
|
||||
)
|
||||
result = op.get_bind().execute(bot_id_query, {"first_row_id": first_row_id})
|
||||
bot_id = result.scalar()
|
||||
|
||||
# CTE (Common Table Expression) that transforms the old slack_bot_config table data
|
||||
# This splits up the channel_names into their own rows
|
||||
channel_names_cte = """
|
||||
WITH channel_names AS (
|
||||
SELECT
|
||||
sbc.id as config_id,
|
||||
sbc.persona_id,
|
||||
sbc.response_type,
|
||||
sbc.enable_auto_filters,
|
||||
jsonb_array_elements_text(sbc.channel_config->'channel_names') as channel_name,
|
||||
sbc.channel_config->>'respond_tag_only' as respond_tag_only,
|
||||
sbc.channel_config->>'respond_to_bots' as respond_to_bots,
|
||||
sbc.channel_config->'respond_member_group_list' as respond_member_group_list,
|
||||
sbc.channel_config->'answer_filters' as answer_filters,
|
||||
sbc.channel_config->'follow_up_tags' as follow_up_tags
|
||||
FROM slack_bot_config sbc
|
||||
)
|
||||
"""
|
||||
|
||||
# Insert the channel names into the new slack_channel_config table
|
||||
insert_statement = """
|
||||
INSERT INTO slack_channel_config (
|
||||
slack_bot_id,
|
||||
persona_id,
|
||||
channel_config,
|
||||
response_type,
|
||||
enable_auto_filters
|
||||
)
|
||||
SELECT
|
||||
:bot_id,
|
||||
channel_name.persona_id,
|
||||
jsonb_build_object(
|
||||
'channel_name', channel_name.channel_name,
|
||||
'respond_tag_only',
|
||||
COALESCE((channel_name.respond_tag_only)::boolean, false),
|
||||
'respond_to_bots',
|
||||
COALESCE((channel_name.respond_to_bots)::boolean, false),
|
||||
'respond_member_group_list',
|
||||
COALESCE(channel_name.respond_member_group_list, '[]'::jsonb),
|
||||
'answer_filters',
|
||||
COALESCE(channel_name.answer_filters, '[]'::jsonb),
|
||||
'follow_up_tags',
|
||||
COALESCE(channel_name.follow_up_tags, '[]'::jsonb)
|
||||
),
|
||||
channel_name.response_type,
|
||||
channel_name.enable_auto_filters
|
||||
FROM channel_names channel_name;
|
||||
"""
|
||||
|
||||
op.execute(sa.text(channel_names_cte + insert_statement).bindparams(bot_id=bot_id))
|
||||
|
||||
# Clean up old tokens if they existed
|
||||
try:
|
||||
if bot_token and app_token:
|
||||
logger.info(f"{revision}: Removing old bot and app tokens.")
|
||||
get_kv_store().delete("slack_bot_tokens_config_key")
|
||||
except Exception:
|
||||
logger.warning("tried to delete tokens in dynamic config but failed")
|
||||
# Rename the table
|
||||
op.rename_table(
|
||||
"slack_bot_config__standard_answer_category",
|
||||
"slack_channel_config__standard_answer_category",
|
||||
)
|
||||
|
||||
# Rename the column
|
||||
op.alter_column(
|
||||
"slack_channel_config__standard_answer_category",
|
||||
"slack_bot_config_id",
|
||||
new_column_name="slack_channel_config_id",
|
||||
)
|
||||
|
||||
# Drop the table with CASCADE to handle dependent objects
|
||||
op.execute("DROP TABLE slack_bot_config CASCADE")
|
||||
|
||||
logger.info(f"{revision}: Migration complete.")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Recreate the old slack_bot_config table
|
||||
op.create_table(
|
||||
"slack_bot_config",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("persona_id", sa.Integer(), nullable=True),
|
||||
sa.Column("channel_config", postgresql.JSONB(), nullable=False),
|
||||
sa.Column("response_type", sa.String(), nullable=False),
|
||||
sa.Column("enable_auto_filters", sa.Boolean(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["persona_id"],
|
||||
["persona.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
# Migrate data back to the old format
|
||||
# Group by persona_id to combine channel names back into arrays
|
||||
op.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO slack_bot_config (
|
||||
persona_id,
|
||||
channel_config,
|
||||
response_type,
|
||||
enable_auto_filters
|
||||
)
|
||||
SELECT DISTINCT ON (persona_id)
|
||||
persona_id,
|
||||
jsonb_build_object(
|
||||
'channel_names', (
|
||||
SELECT jsonb_agg(c.channel_config->>'channel_name')
|
||||
FROM slack_channel_config c
|
||||
WHERE c.persona_id = scc.persona_id
|
||||
),
|
||||
'respond_tag_only', (channel_config->>'respond_tag_only')::boolean,
|
||||
'respond_to_bots', (channel_config->>'respond_to_bots')::boolean,
|
||||
'respond_member_group_list', channel_config->'respond_member_group_list',
|
||||
'answer_filters', channel_config->'answer_filters',
|
||||
'follow_up_tags', channel_config->'follow_up_tags'
|
||||
),
|
||||
response_type,
|
||||
enable_auto_filters
|
||||
FROM slack_channel_config scc
|
||||
WHERE persona_id IS NOT NULL;
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# Rename the table back
|
||||
op.rename_table(
|
||||
"slack_channel_config__standard_answer_category",
|
||||
"slack_bot_config__standard_answer_category",
|
||||
)
|
||||
|
||||
# Rename the column back
|
||||
op.alter_column(
|
||||
"slack_bot_config__standard_answer_category",
|
||||
"slack_channel_config_id",
|
||||
new_column_name="slack_bot_config_id",
|
||||
)
|
||||
|
||||
# Try to save the first bot's tokens back to KV store
|
||||
try:
|
||||
first_bot = (
|
||||
op.get_bind()
|
||||
.execute(
|
||||
sa.text(
|
||||
"SELECT bot_token, app_token FROM slack_bot ORDER BY id LIMIT 1"
|
||||
)
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if first_bot and first_bot.bot_token and first_bot.app_token:
|
||||
tokens = {
|
||||
"bot_token": first_bot.bot_token,
|
||||
"app_token": first_bot.app_token,
|
||||
}
|
||||
get_kv_store().store("slack_bot_tokens_config_key", tokens)
|
||||
except Exception:
|
||||
logger.warning("Failed to save tokens back to KV store")
|
||||
|
||||
# Drop the new tables in reverse order
|
||||
op.drop_table("slack_channel_config")
|
||||
op.drop_table("slack_bot")
|
||||
23
backend/alembic/versions/54a74a0417fc_danswerbot_onyxbot.py
Normal file
23
backend/alembic/versions/54a74a0417fc_danswerbot_onyxbot.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""danswerbot -> onyxbot
|
||||
|
||||
Revision ID: 54a74a0417fc
|
||||
Revises: 94dc3d0236f8
|
||||
Create Date: 2024-12-11 18:05:05.490737
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "54a74a0417fc"
|
||||
down_revision = "94dc3d0236f8"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.alter_column("chat_session", "danswerbot_flow", new_column_name="onyxbot_flow")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.alter_column("chat_session", "onyxbot_flow", new_column_name="danswerbot_flow")
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Track Danswerbot Explicitly
|
||||
"""Track Onyxbot Explicitly
|
||||
|
||||
Revision ID: 570282d33c49
|
||||
Revises: 7547d982db8f
|
||||
@@ -0,0 +1,70 @@
|
||||
"""nullable search settings for historic index attempts
|
||||
|
||||
Revision ID: 5b29123cd710
|
||||
Revises: 949b4a92a401
|
||||
Create Date: 2024-10-30 19:37:59.630704
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "5b29123cd710"
|
||||
down_revision = "949b4a92a401"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Drop the existing foreign key constraint
|
||||
op.drop_constraint(
|
||||
"fk_index_attempt_search_settings", "index_attempt", type_="foreignkey"
|
||||
)
|
||||
|
||||
# Modify the column to be nullable
|
||||
op.alter_column(
|
||||
"index_attempt", "search_settings_id", existing_type=sa.INTEGER(), nullable=True
|
||||
)
|
||||
|
||||
# Add back the foreign key with ON DELETE SET NULL
|
||||
op.create_foreign_key(
|
||||
"fk_index_attempt_search_settings",
|
||||
"index_attempt",
|
||||
"search_settings",
|
||||
["search_settings_id"],
|
||||
["id"],
|
||||
ondelete="SET NULL",
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Warning: This will delete all index attempts that don't have search settings
|
||||
op.execute(
|
||||
"""
|
||||
DELETE FROM index_attempt
|
||||
WHERE search_settings_id IS NULL
|
||||
"""
|
||||
)
|
||||
|
||||
# Drop foreign key constraint
|
||||
op.drop_constraint(
|
||||
"fk_index_attempt_search_settings", "index_attempt", type_="foreignkey"
|
||||
)
|
||||
|
||||
# Modify the column to be not nullable
|
||||
op.alter_column(
|
||||
"index_attempt",
|
||||
"search_settings_id",
|
||||
existing_type=sa.INTEGER(),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Add back the foreign key without ON DELETE SET NULL
|
||||
op.create_foreign_key(
|
||||
"fk_index_attempt_search_settings",
|
||||
"index_attempt",
|
||||
"search_settings",
|
||||
["search_settings_id"],
|
||||
["id"],
|
||||
)
|
||||
@@ -0,0 +1,30 @@
|
||||
"""add api_version and deployment_name to search settings
|
||||
|
||||
Revision ID: 5d12a446f5c0
|
||||
Revises: e4334d5b33ba
|
||||
Create Date: 2024-10-08 15:56:07.975636
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "5d12a446f5c0"
|
||||
down_revision = "e4334d5b33ba"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"embedding_provider", sa.Column("api_version", sa.String(), nullable=True)
|
||||
)
|
||||
op.add_column(
|
||||
"embedding_provider", sa.Column("deployment_name", sa.String(), nullable=True)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("embedding_provider", "deployment_name")
|
||||
op.drop_column("embedding_provider", "api_version")
|
||||
@@ -0,0 +1,153 @@
|
||||
"""Migrate chat_session and chat_message tables to use UUID primary keys
|
||||
|
||||
Revision ID: 6756efa39ada
|
||||
Revises: 5d12a446f5c0
|
||||
Create Date: 2024-10-15 17:47:44.108537
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "6756efa39ada"
|
||||
down_revision = "5d12a446f5c0"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
"""
|
||||
This script:
|
||||
1. Adds UUID columns to chat_session and chat_message
|
||||
2. Populates new columns with UUIDs
|
||||
3. Updates foreign key relationships
|
||||
4. Removes old integer ID columns
|
||||
|
||||
Note: Downgrade will assign new integer IDs, not restore original ones.
|
||||
"""
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute("CREATE EXTENSION IF NOT EXISTS pgcrypto;")
|
||||
|
||||
op.add_column(
|
||||
"chat_session",
|
||||
sa.Column(
|
||||
"new_id",
|
||||
sa.UUID(as_uuid=True),
|
||||
server_default=sa.text("gen_random_uuid()"),
|
||||
nullable=False,
|
||||
),
|
||||
)
|
||||
|
||||
op.execute("UPDATE chat_session SET new_id = gen_random_uuid();")
|
||||
|
||||
op.add_column(
|
||||
"chat_message",
|
||||
sa.Column("new_chat_session_id", sa.UUID(as_uuid=True), nullable=True),
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE chat_message
|
||||
SET new_chat_session_id = cs.new_id
|
||||
FROM chat_session cs
|
||||
WHERE chat_message.chat_session_id = cs.id;
|
||||
"""
|
||||
)
|
||||
|
||||
op.drop_constraint(
|
||||
"chat_message_chat_session_id_fkey", "chat_message", type_="foreignkey"
|
||||
)
|
||||
|
||||
op.drop_column("chat_message", "chat_session_id")
|
||||
op.alter_column(
|
||||
"chat_message", "new_chat_session_id", new_column_name="chat_session_id"
|
||||
)
|
||||
|
||||
op.drop_constraint("chat_session_pkey", "chat_session", type_="primary")
|
||||
op.drop_column("chat_session", "id")
|
||||
op.alter_column("chat_session", "new_id", new_column_name="id")
|
||||
|
||||
op.create_primary_key("chat_session_pkey", "chat_session", ["id"])
|
||||
|
||||
op.create_foreign_key(
|
||||
"chat_message_chat_session_id_fkey",
|
||||
"chat_message",
|
||||
"chat_session",
|
||||
["chat_session_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_constraint(
|
||||
"chat_message_chat_session_id_fkey", "chat_message", type_="foreignkey"
|
||||
)
|
||||
|
||||
op.add_column(
|
||||
"chat_session",
|
||||
sa.Column("old_id", sa.Integer, autoincrement=True, nullable=True),
|
||||
)
|
||||
|
||||
op.execute("CREATE SEQUENCE chat_session_old_id_seq OWNED BY chat_session.old_id;")
|
||||
op.execute(
|
||||
"ALTER TABLE chat_session ALTER COLUMN old_id SET DEFAULT nextval('chat_session_old_id_seq');"
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"UPDATE chat_session SET old_id = nextval('chat_session_old_id_seq') WHERE old_id IS NULL;"
|
||||
)
|
||||
|
||||
op.alter_column("chat_session", "old_id", nullable=False)
|
||||
|
||||
op.drop_constraint("chat_session_pkey", "chat_session", type_="primary")
|
||||
op.create_primary_key("chat_session_pkey", "chat_session", ["old_id"])
|
||||
|
||||
op.add_column(
|
||||
"chat_message",
|
||||
sa.Column("old_chat_session_id", sa.Integer, nullable=True),
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE chat_message
|
||||
SET old_chat_session_id = cs.old_id
|
||||
FROM chat_session cs
|
||||
WHERE chat_message.chat_session_id = cs.id;
|
||||
"""
|
||||
)
|
||||
|
||||
op.drop_column("chat_message", "chat_session_id")
|
||||
op.alter_column(
|
||||
"chat_message", "old_chat_session_id", new_column_name="chat_session_id"
|
||||
)
|
||||
|
||||
op.create_foreign_key(
|
||||
"chat_message_chat_session_id_fkey",
|
||||
"chat_message",
|
||||
"chat_session",
|
||||
["chat_session_id"],
|
||||
["old_id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
op.drop_column("chat_session", "id")
|
||||
op.alter_column("chat_session", "old_id", new_column_name="id")
|
||||
|
||||
op.alter_column(
|
||||
"chat_session",
|
||||
"id",
|
||||
type_=sa.Integer(),
|
||||
existing_type=sa.Integer(),
|
||||
existing_nullable=False,
|
||||
existing_server_default=False,
|
||||
)
|
||||
|
||||
# Rename the sequence
|
||||
op.execute("ALTER SEQUENCE chat_session_old_id_seq RENAME TO chat_session_id_seq;")
|
||||
|
||||
# Update the default value to use the renamed sequence
|
||||
op.alter_column(
|
||||
"chat_session",
|
||||
"id",
|
||||
server_default=sa.text("nextval('chat_session_id_seq'::regclass)"),
|
||||
)
|
||||
45
backend/alembic/versions/6d562f86c78b_remove_default_bot.py
Normal file
45
backend/alembic/versions/6d562f86c78b_remove_default_bot.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""remove default bot
|
||||
|
||||
Revision ID: 6d562f86c78b
|
||||
Revises: 177de57c21c9
|
||||
Create Date: 2024-11-22 11:51:29.331336
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "6d562f86c78b"
|
||||
down_revision = "177de57c21c9"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute(
|
||||
sa.text(
|
||||
"""
|
||||
DELETE FROM slack_bot
|
||||
WHERE name = 'Default Bot'
|
||||
AND bot_token = ''
|
||||
AND app_token = ''
|
||||
AND NOT EXISTS (
|
||||
SELECT 1 FROM slack_channel_config
|
||||
WHERE slack_channel_config.slack_bot_id = slack_bot.id
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO slack_bot (name, enabled, bot_token, app_token)
|
||||
SELECT 'Default Bot', true, '', ''
|
||||
WHERE NOT EXISTS (SELECT 1 FROM slack_bot)
|
||||
RETURNING id;
|
||||
"""
|
||||
)
|
||||
)
|
||||
@@ -9,7 +9,7 @@ import json
|
||||
from typing import cast
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from danswer.key_value_store.factory import get_kv_store
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "703313b75876"
|
||||
|
||||
@@ -8,9 +8,9 @@ Create Date: 2024-03-22 21:34:27.629444
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
from danswer.db.models import IndexModelStatus
|
||||
from danswer.search.enums import RecencyBiasSetting
|
||||
from danswer.search.enums import SearchType
|
||||
from onyx.db.models import IndexModelStatus
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.context.search.enums import SearchType
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "776b3bbe9092"
|
||||
|
||||
@@ -18,7 +18,7 @@ depends_on: None = None
|
||||
|
||||
def upgrade() -> None:
|
||||
# In a PR:
|
||||
# https://github.com/danswer-ai/danswer/pull/397/files#diff-f05fb341f6373790b91852579631b64ca7645797a190837156a282b67e5b19c2
|
||||
# https://github.com/onyx-dot-app/onyx/pull/397/files#diff-f05fb341f6373790b91852579631b64ca7645797a190837156a282b67e5b19c2
|
||||
# we directly changed some previous migrations. This caused some users to have native enums
|
||||
# while others wouldn't. This has caused some issues when adding new fields to these enums.
|
||||
# This migration manually changes the enum types to ensure that nobody uses native enums.
|
||||
|
||||
45
backend/alembic/versions/91a0a4d62b14_milestone.py
Normal file
45
backend/alembic/versions/91a0a4d62b14_milestone.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""Milestone
|
||||
|
||||
Revision ID: 91a0a4d62b14
|
||||
Revises: dab04867cd88
|
||||
Create Date: 2024-12-13 19:03:30.947551
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import fastapi_users_db_sqlalchemy
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "91a0a4d62b14"
|
||||
down_revision = "dab04867cd88"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"milestone",
|
||||
sa.Column("id", sa.UUID(), nullable=False),
|
||||
sa.Column("tenant_id", sa.String(), nullable=True),
|
||||
sa.Column(
|
||||
"user_id",
|
||||
fastapi_users_db_sqlalchemy.generics.GUID(),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("event_type", sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
"time_created",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("event_tracker", postgresql.JSONB(), nullable=True),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("event_type", name="uq_milestone_event_type"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("milestone")
|
||||
@@ -7,7 +7,7 @@ Create Date: 2024-03-21 12:05:23.956734
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import DocumentSource
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "91fd3b470d1a"
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
"""add web ui option to slack config
|
||||
|
||||
Revision ID: 93560ba1b118
|
||||
Revises: 6d562f86c78b
|
||||
Create Date: 2024-11-24 06:36:17.490612
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "93560ba1b118"
|
||||
down_revision = "6d562f86c78b"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add show_continue_in_web_ui with default False to all existing channel_configs
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE slack_channel_config
|
||||
SET channel_config = channel_config || '{"show_continue_in_web_ui": false}'::jsonb
|
||||
WHERE NOT channel_config ? 'show_continue_in_web_ui'
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove show_continue_in_web_ui from all channel_configs
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE slack_channel_config
|
||||
SET channel_config = channel_config - 'show_continue_in_web_ui'
|
||||
"""
|
||||
)
|
||||
72
backend/alembic/versions/949b4a92a401_remove_rt.py
Normal file
72
backend/alembic/versions/949b4a92a401_remove_rt.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""remove rt
|
||||
|
||||
Revision ID: 949b4a92a401
|
||||
Revises: 1b10e1fda030
|
||||
Create Date: 2024-10-26 13:06:06.937969
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import text
|
||||
|
||||
# Import your models and constants
|
||||
from onyx.db.models import (
|
||||
Connector,
|
||||
ConnectorCredentialPair,
|
||||
Credential,
|
||||
IndexAttempt,
|
||||
)
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "949b4a92a401"
|
||||
down_revision = "1b10e1fda030"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Deletes all RequestTracker connectors and associated data
|
||||
bind = op.get_bind()
|
||||
session = Session(bind=bind)
|
||||
|
||||
# Get connectors using raw SQL
|
||||
result = bind.execute(
|
||||
text("SELECT id FROM connector WHERE source = 'requesttracker'")
|
||||
)
|
||||
connector_ids = [row[0] for row in result]
|
||||
|
||||
if connector_ids:
|
||||
cc_pairs_to_delete = (
|
||||
session.query(ConnectorCredentialPair)
|
||||
.filter(ConnectorCredentialPair.connector_id.in_(connector_ids))
|
||||
.all()
|
||||
)
|
||||
|
||||
cc_pair_ids = [cc_pair.id for cc_pair in cc_pairs_to_delete]
|
||||
|
||||
if cc_pair_ids:
|
||||
session.query(IndexAttempt).filter(
|
||||
IndexAttempt.connector_credential_pair_id.in_(cc_pair_ids)
|
||||
).delete(synchronize_session=False)
|
||||
|
||||
session.query(ConnectorCredentialPair).filter(
|
||||
ConnectorCredentialPair.id.in_(cc_pair_ids)
|
||||
).delete(synchronize_session=False)
|
||||
|
||||
credential_ids = [cc_pair.credential_id for cc_pair in cc_pairs_to_delete]
|
||||
if credential_ids:
|
||||
session.query(Credential).filter(Credential.id.in_(credential_ids)).delete(
|
||||
synchronize_session=False
|
||||
)
|
||||
|
||||
session.query(Connector).filter(Connector.id.in_(connector_ids)).delete(
|
||||
synchronize_session=False
|
||||
)
|
||||
|
||||
session.commit()
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# No-op downgrade as we cannot restore deleted data
|
||||
pass
|
||||
@@ -0,0 +1,30 @@
|
||||
"""make document set description optional
|
||||
|
||||
Revision ID: 94dc3d0236f8
|
||||
Revises: bf7a81109301
|
||||
Create Date: 2024-12-11 11:26:10.616722
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "94dc3d0236f8"
|
||||
down_revision = "bf7a81109301"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Make document_set.description column nullable
|
||||
op.alter_column(
|
||||
"document_set", "description", existing_type=sa.String(), nullable=True
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Revert document_set.description column to non-nullable
|
||||
op.alter_column(
|
||||
"document_set", "description", existing_type=sa.String(), nullable=False
|
||||
)
|
||||
@@ -0,0 +1,30 @@
|
||||
"""add creator to cc pair
|
||||
|
||||
Revision ID: 9cf5c00f72fe
|
||||
Revises: 26b931506ecb
|
||||
Create Date: 2024-11-12 15:16:42.682902
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "9cf5c00f72fe"
|
||||
down_revision = "26b931506ecb"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"connector_credential_pair",
|
||||
sa.Column(
|
||||
"creator_id",
|
||||
sa.UUID(as_uuid=True),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("connector_credential_pair", "creator_id")
|
||||
@@ -0,0 +1,36 @@
|
||||
"""Combine Search and Chat
|
||||
|
||||
Revision ID: 9f696734098f
|
||||
Revises: a8c2065484e6
|
||||
Create Date: 2024-11-27 15:32:19.694972
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "9f696734098f"
|
||||
down_revision = "a8c2065484e6"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.alter_column("chat_session", "description", nullable=True)
|
||||
op.drop_column("chat_session", "one_shot")
|
||||
op.drop_column("slack_channel_config", "response_type")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute("UPDATE chat_session SET description = '' WHERE description IS NULL")
|
||||
op.alter_column("chat_session", "description", nullable=False)
|
||||
op.add_column(
|
||||
"chat_session",
|
||||
sa.Column("one_shot", sa.Boolean(), nullable=False, server_default=sa.false()),
|
||||
)
|
||||
op.add_column(
|
||||
"slack_channel_config",
|
||||
sa.Column(
|
||||
"response_type", sa.String(), nullable=False, server_default="citations"
|
||||
),
|
||||
)
|
||||
@@ -0,0 +1,27 @@
|
||||
"""add auto scroll to user model
|
||||
|
||||
Revision ID: a8c2065484e6
|
||||
Revises: abe7378b8217
|
||||
Create Date: 2024-11-22 17:34:09.690295
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "a8c2065484e6"
|
||||
down_revision = "abe7378b8217"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column("auto_scroll", sa.Boolean(), nullable=True, server_default=None),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("user", "auto_scroll")
|
||||
@@ -0,0 +1,30 @@
|
||||
"""add indexing trigger to cc_pair
|
||||
|
||||
Revision ID: abe7378b8217
|
||||
Revises: 6d562f86c78b
|
||||
Create Date: 2024-11-26 19:09:53.481171
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "abe7378b8217"
|
||||
down_revision = "93560ba1b118"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"connector_credential_pair",
|
||||
sa.Column(
|
||||
"indexing_trigger",
|
||||
sa.Enum("UPDATE", "REINDEX", name="indexingmode", native_enum=False),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("connector_credential_pair", "indexing_trigger")
|
||||
@@ -31,6 +31,12 @@ def upgrade() -> None:
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# First, update any null values to a default value
|
||||
op.execute(
|
||||
"UPDATE connector_credential_pair SET last_attempt_status = 'NOT_STARTED' WHERE last_attempt_status IS NULL"
|
||||
)
|
||||
|
||||
# Then, make the column non-nullable
|
||||
op.alter_column(
|
||||
"connector_credential_pair",
|
||||
"last_attempt_status",
|
||||
|
||||
@@ -10,7 +10,7 @@ from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from sqlalchemy.dialects.postgresql import ENUM
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import DocumentSource
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "b156fa702355"
|
||||
@@ -288,6 +288,15 @@ def upgrade() -> None:
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# NOTE: you will lose all chat history. This is to satisfy the non-nullable constraints
|
||||
# below
|
||||
op.execute("DELETE FROM chat_feedback")
|
||||
op.execute("DELETE FROM chat_message__search_doc")
|
||||
op.execute("DELETE FROM document_retrieval_feedback")
|
||||
op.execute("DELETE FROM document_retrieval_feedback")
|
||||
op.execute("DELETE FROM chat_message")
|
||||
op.execute("DELETE FROM chat_session")
|
||||
|
||||
op.drop_constraint(
|
||||
"chat_feedback__chat_message_fk", "chat_feedback", type_="foreignkey"
|
||||
)
|
||||
|
||||
@@ -0,0 +1,48 @@
|
||||
"""remove description from starter messages
|
||||
|
||||
Revision ID: b72ed7a5db0e
|
||||
Revises: 33cb72ea4d80
|
||||
Create Date: 2024-11-03 15:55:28.944408
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "b72ed7a5db0e"
|
||||
down_revision = "33cb72ea4d80"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET starter_messages = (
|
||||
SELECT jsonb_agg(elem - 'description')
|
||||
FROM jsonb_array_elements(starter_messages) elem
|
||||
)
|
||||
WHERE starter_messages IS NOT NULL
|
||||
AND jsonb_typeof(starter_messages) = 'array'
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET starter_messages = (
|
||||
SELECT jsonb_agg(elem || '{"description": ""}')
|
||||
FROM jsonb_array_elements(starter_messages) elem
|
||||
)
|
||||
WHERE starter_messages IS NOT NULL
|
||||
AND jsonb_typeof(starter_messages) = 'array'
|
||||
"""
|
||||
)
|
||||
)
|
||||
@@ -0,0 +1,57 @@
|
||||
"""delete_input_prompts
|
||||
|
||||
Revision ID: bf7a81109301
|
||||
Revises: f7a894b06d02
|
||||
Create Date: 2024-12-09 12:00:49.884228
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import fastapi_users_db_sqlalchemy
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "bf7a81109301"
|
||||
down_revision = "f7a894b06d02"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.drop_table("inputprompt__user")
|
||||
op.drop_table("inputprompt")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.create_table(
|
||||
"inputprompt",
|
||||
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
|
||||
sa.Column("prompt", sa.String(), nullable=False),
|
||||
sa.Column("content", sa.String(), nullable=False),
|
||||
sa.Column("active", sa.Boolean(), nullable=False),
|
||||
sa.Column("is_public", sa.Boolean(), nullable=False),
|
||||
sa.Column(
|
||||
"user_id",
|
||||
fastapi_users_db_sqlalchemy.generics.GUID(),
|
||||
nullable=True,
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["user.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_table(
|
||||
"inputprompt__user",
|
||||
sa.Column("input_prompt_id", sa.Integer(), nullable=False),
|
||||
sa.Column("user_id", sa.Integer(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["input_prompt_id"],
|
||||
["inputprompt.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["inputprompt.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("input_prompt_id", "user_id"),
|
||||
)
|
||||
87
backend/alembic/versions/c0aab6edb6dd_delete_workspace.py
Normal file
87
backend/alembic/versions/c0aab6edb6dd_delete_workspace.py
Normal file
@@ -0,0 +1,87 @@
|
||||
"""delete workspace
|
||||
|
||||
Revision ID: c0aab6edb6dd
|
||||
Revises: 35e518e0ddf4
|
||||
Create Date: 2024-12-17 14:37:07.660631
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "c0aab6edb6dd"
|
||||
down_revision = "35e518e0ddf4"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE connector
|
||||
SET connector_specific_config = connector_specific_config - 'workspace'
|
||||
WHERE source = 'SLACK'
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
import json
|
||||
from sqlalchemy import text
|
||||
from slack_sdk import WebClient
|
||||
|
||||
conn = op.get_bind()
|
||||
|
||||
# Fetch all Slack credentials
|
||||
creds_result = conn.execute(
|
||||
text("SELECT id, credential_json FROM credential WHERE source = 'SLACK'")
|
||||
)
|
||||
all_slack_creds = creds_result.fetchall()
|
||||
if not all_slack_creds:
|
||||
return
|
||||
|
||||
for cred_row in all_slack_creds:
|
||||
credential_id, credential_json = cred_row
|
||||
|
||||
credential_json = (
|
||||
credential_json.tobytes().decode("utf-8")
|
||||
if isinstance(credential_json, memoryview)
|
||||
else credential_json.decode("utf-8")
|
||||
)
|
||||
credential_data = json.loads(credential_json)
|
||||
slack_bot_token = credential_data.get("slack_bot_token")
|
||||
if not slack_bot_token:
|
||||
print(
|
||||
f"No slack_bot_token found for credential {credential_id}. "
|
||||
"Your Slack connector will not function until you upgrade and provide a valid token."
|
||||
)
|
||||
continue
|
||||
|
||||
client = WebClient(token=slack_bot_token)
|
||||
try:
|
||||
auth_response = client.auth_test()
|
||||
workspace = auth_response["url"].split("//")[1].split(".")[0]
|
||||
|
||||
# Update only the connectors linked to this credential
|
||||
# (and which are Slack connectors).
|
||||
op.execute(
|
||||
f"""
|
||||
UPDATE connector AS c
|
||||
SET connector_specific_config = jsonb_set(
|
||||
connector_specific_config,
|
||||
'{{workspace}}',
|
||||
to_jsonb('{workspace}'::text)
|
||||
)
|
||||
FROM connector_credential_pair AS ccp
|
||||
WHERE ccp.connector_id = c.id
|
||||
AND c.source = 'SLACK'
|
||||
AND ccp.credential_id = {credential_id}
|
||||
"""
|
||||
)
|
||||
except Exception:
|
||||
print(
|
||||
f"We were unable to get the workspace url for your Slack Connector with id {credential_id}."
|
||||
)
|
||||
print("This connector will no longer work until you upgrade.")
|
||||
continue
|
||||
@@ -0,0 +1,29 @@
|
||||
"""add recent assistants
|
||||
|
||||
Revision ID: c0fd6e4da83a
|
||||
Revises: b72ed7a5db0e
|
||||
Create Date: 2024-11-03 17:28:54.916618
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "c0fd6e4da83a"
|
||||
down_revision = "b72ed7a5db0e"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"recent_assistants", postgresql.JSONB(), server_default="[]", nullable=False
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("user", "recent_assistants")
|
||||
@@ -23,6 +23,56 @@ def upgrade() -> None:
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Delete chat messages and feedback first since they reference chat sessions
|
||||
# Get chat messages from sessions with null persona_id
|
||||
chat_messages_query = """
|
||||
SELECT id
|
||||
FROM chat_message
|
||||
WHERE chat_session_id IN (
|
||||
SELECT id
|
||||
FROM chat_session
|
||||
WHERE persona_id IS NULL
|
||||
)
|
||||
"""
|
||||
|
||||
# Delete dependent records first
|
||||
op.execute(
|
||||
f"""
|
||||
DELETE FROM document_retrieval_feedback
|
||||
WHERE chat_message_id IN (
|
||||
{chat_messages_query}
|
||||
)
|
||||
"""
|
||||
)
|
||||
op.execute(
|
||||
f"""
|
||||
DELETE FROM chat_message__search_doc
|
||||
WHERE chat_message_id IN (
|
||||
{chat_messages_query}
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
# Delete chat messages
|
||||
op.execute(
|
||||
"""
|
||||
DELETE FROM chat_message
|
||||
WHERE chat_session_id IN (
|
||||
SELECT id
|
||||
FROM chat_session
|
||||
WHERE persona_id IS NULL
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
# Now we can safely delete the chat sessions
|
||||
op.execute(
|
||||
"""
|
||||
DELETE FROM chat_session
|
||||
WHERE persona_id IS NULL
|
||||
"""
|
||||
)
|
||||
|
||||
op.alter_column(
|
||||
"chat_session",
|
||||
"persona_id",
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
"""Add composite index to document_by_connector_credential_pair
|
||||
|
||||
Revision ID: dab04867cd88
|
||||
Revises: 54a74a0417fc
|
||||
Create Date: 2024-12-13 22:43:20.119990
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "dab04867cd88"
|
||||
down_revision = "54a74a0417fc"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Composite index on (connector_id, credential_id)
|
||||
op.create_index(
|
||||
"idx_document_cc_pair_connector_credential",
|
||||
"document_by_connector_credential_pair",
|
||||
["connector_id", "credential_id"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index(
|
||||
"idx_document_cc_pair_connector_credential",
|
||||
table_name="document_by_connector_credential_pair",
|
||||
)
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Danswer Custom Tool Flow
|
||||
"""Onyx Custom Tool Flow
|
||||
|
||||
Revision ID: dba7f71618f5
|
||||
Revises: d5645c915d0e
|
||||
@@ -9,12 +9,12 @@ from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import table, column, String, Integer, Boolean
|
||||
|
||||
from danswer.db.search_settings import (
|
||||
from onyx.db.search_settings import (
|
||||
get_new_default_embedding_model,
|
||||
get_old_default_embedding_model,
|
||||
user_has_overridden_embedding_model,
|
||||
)
|
||||
from danswer.db.models import IndexModelStatus
|
||||
from onyx.db.models import IndexModelStatus
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "dbaa756c2ccf"
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
"""extended_role_for_non_web
|
||||
|
||||
Revision ID: dfbe9e93d3c7
|
||||
Revises: 9cf5c00f72fe
|
||||
Create Date: 2024-11-16 07:54:18.727906
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "dfbe9e93d3c7"
|
||||
down_revision = "9cf5c00f72fe"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE "user"
|
||||
SET role = 'EXT_PERM_USER'
|
||||
WHERE has_web_login = false
|
||||
"""
|
||||
)
|
||||
op.drop_column("user", "has_web_login")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column("has_web_login", sa.Boolean(), nullable=False, server_default="true"),
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE "user"
|
||||
SET has_web_login = false,
|
||||
role = 'BASIC'
|
||||
WHERE role IN ('SLACK_USER', 'EXT_PERM_USER')
|
||||
"""
|
||||
)
|
||||
@@ -8,7 +8,7 @@ Create Date: 2024-03-14 18:06:08.523106
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import DocumentSource
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "e50154680a5c"
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
"""non-nullbale slack bot id in channel config
|
||||
|
||||
Revision ID: f7a894b06d02
|
||||
Revises: 9f696734098f
|
||||
Create Date: 2024-12-06 12:55:42.845723
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "f7a894b06d02"
|
||||
down_revision = "9f696734098f"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Delete all rows with null slack_bot_id
|
||||
op.execute("DELETE FROM slack_channel_config WHERE slack_bot_id IS NULL")
|
||||
|
||||
# Make slack_bot_id non-nullable
|
||||
op.alter_column(
|
||||
"slack_channel_config",
|
||||
"slack_bot_id",
|
||||
existing_type=sa.Integer(),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Make slack_bot_id nullable again
|
||||
op.alter_column(
|
||||
"slack_channel_config",
|
||||
"slack_bot_id",
|
||||
existing_type=sa.Integer(),
|
||||
nullable=True,
|
||||
)
|
||||
@@ -1,3 +1,3 @@
|
||||
These files are for public table migrations when operating with multi tenancy.
|
||||
|
||||
If you are not a Danswer developer, you can ignore this directory entirely.
|
||||
If you are not a Onyx developer, you can ignore this directory entirely.
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
from logging.config import fileConfig
|
||||
from typing import Literal
|
||||
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.engine import Connection
|
||||
@@ -7,8 +8,8 @@ from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy.schema import SchemaItem
|
||||
|
||||
from alembic import context
|
||||
from danswer.db.engine import build_connection_string
|
||||
from danswer.db.models import PublicBase
|
||||
from onyx.db.engine import build_connection_string
|
||||
from onyx.db.models import PublicBase
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
@@ -37,8 +38,15 @@ EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
|
||||
|
||||
def include_object(
|
||||
object: SchemaItem,
|
||||
name: str,
|
||||
type_: str,
|
||||
name: str | None,
|
||||
type_: Literal[
|
||||
"schema",
|
||||
"table",
|
||||
"column",
|
||||
"index",
|
||||
"unique_constraint",
|
||||
"foreign_key_constraint",
|
||||
],
|
||||
reflected: bool,
|
||||
compare_to: SchemaItem | None,
|
||||
) -> bool:
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
import os
|
||||
|
||||
__version__ = os.environ.get("DANSWER_VERSION", "") or "0.3-dev"
|
||||
@@ -1,656 +0,0 @@
|
||||
import smtplib
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from email.mime.text import MIMEText
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
|
||||
import jwt
|
||||
from email_validator import EmailNotValidError
|
||||
from email_validator import EmailUndeliverableError
|
||||
from email_validator import validate_email
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
from fastapi import Response
|
||||
from fastapi import status
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from fastapi_users import BaseUserManager
|
||||
from fastapi_users import exceptions
|
||||
from fastapi_users import FastAPIUsers
|
||||
from fastapi_users import models
|
||||
from fastapi_users import schemas
|
||||
from fastapi_users import UUIDIDMixin
|
||||
from fastapi_users.authentication import AuthenticationBackend
|
||||
from fastapi_users.authentication import CookieTransport
|
||||
from fastapi_users.authentication import JWTStrategy
|
||||
from fastapi_users.authentication import Strategy
|
||||
from fastapi_users.authentication.strategy.db import AccessTokenDatabase
|
||||
from fastapi_users.authentication.strategy.db import DatabaseStrategy
|
||||
from fastapi_users.openapi import OpenAPIResponseType
|
||||
from fastapi_users_db_sqlalchemy import SQLAlchemyUserDatabase
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import attributes
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.invited_users import get_invited_users
|
||||
from danswer.auth.schemas import UserCreate
|
||||
from danswer.auth.schemas import UserRole
|
||||
from danswer.auth.schemas import UserUpdate
|
||||
from danswer.configs.app_configs import AUTH_TYPE
|
||||
from danswer.configs.app_configs import DISABLE_AUTH
|
||||
from danswer.configs.app_configs import EMAIL_FROM
|
||||
from danswer.configs.app_configs import MULTI_TENANT
|
||||
from danswer.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
|
||||
from danswer.configs.app_configs import SECRET_JWT_KEY
|
||||
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
|
||||
from danswer.configs.app_configs import SMTP_PASS
|
||||
from danswer.configs.app_configs import SMTP_PORT
|
||||
from danswer.configs.app_configs import SMTP_SERVER
|
||||
from danswer.configs.app_configs import SMTP_USER
|
||||
from danswer.configs.app_configs import TRACK_EXTERNAL_IDP_EXPIRY
|
||||
from danswer.configs.app_configs import USER_AUTH_SECRET
|
||||
from danswer.configs.app_configs import VALID_EMAIL_DOMAINS
|
||||
from danswer.configs.app_configs import WEB_DOMAIN
|
||||
from danswer.configs.constants import AuthType
|
||||
from danswer.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
|
||||
from danswer.configs.constants import DANSWER_API_KEY_PREFIX
|
||||
from danswer.configs.constants import UNNAMED_KEY_PLACEHOLDER
|
||||
from danswer.db.auth import get_access_token_db
|
||||
from danswer.db.auth import get_default_admin_user_emails
|
||||
from danswer.db.auth import get_user_count
|
||||
from danswer.db.auth import get_user_db
|
||||
from danswer.db.auth import SQLAlchemyUserAdminDB
|
||||
from danswer.db.engine import get_async_session_with_tenant
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.models import AccessToken
|
||||
from danswer.db.models import OAuthAccount
|
||||
from danswer.db.models import User
|
||||
from danswer.db.models import UserTenantMapping
|
||||
from danswer.db.users import get_user_by_email
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.telemetry import optional_telemetry
|
||||
from danswer.utils.telemetry import RecordType
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
from shared_configs.configs import current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def is_user_admin(user: User | None) -> bool:
|
||||
if AUTH_TYPE == AuthType.DISABLED:
|
||||
return True
|
||||
if user and user.role == UserRole.ADMIN:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def verify_auth_setting() -> None:
|
||||
if AUTH_TYPE not in [AuthType.DISABLED, AuthType.BASIC, AuthType.GOOGLE_OAUTH]:
|
||||
raise ValueError(
|
||||
"User must choose a valid user authentication method: "
|
||||
"disabled, basic, or google_oauth"
|
||||
)
|
||||
logger.notice(f"Using Auth Type: {AUTH_TYPE.value}")
|
||||
|
||||
|
||||
def get_display_email(email: str | None, space_less: bool = False) -> str:
|
||||
if email and email.endswith(DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN):
|
||||
name = email.split("@")[0]
|
||||
if name == DANSWER_API_KEY_PREFIX + UNNAMED_KEY_PLACEHOLDER:
|
||||
return "Unnamed API Key"
|
||||
|
||||
if space_less:
|
||||
return name
|
||||
|
||||
return name.replace("API_KEY__", "API Key: ")
|
||||
|
||||
return email or ""
|
||||
|
||||
|
||||
def user_needs_to_be_verified() -> bool:
|
||||
# all other auth types besides basic should require users to be
|
||||
# verified
|
||||
return AUTH_TYPE != AuthType.BASIC or REQUIRE_EMAIL_VERIFICATION
|
||||
|
||||
|
||||
def verify_email_is_invited(email: str) -> None:
|
||||
whitelist = get_invited_users()
|
||||
if not whitelist:
|
||||
return
|
||||
|
||||
if not email:
|
||||
raise PermissionError("Email must be specified")
|
||||
|
||||
try:
|
||||
email_info = validate_email(email)
|
||||
except EmailUndeliverableError:
|
||||
raise PermissionError("Email is not valid")
|
||||
|
||||
for email_whitelist in whitelist:
|
||||
try:
|
||||
# normalized emails are now being inserted into the db
|
||||
# we can remove this normalization on read after some time has passed
|
||||
email_info_whitelist = validate_email(email_whitelist)
|
||||
except EmailNotValidError:
|
||||
continue
|
||||
|
||||
# oddly, normalization does not include lowercasing the user part of the
|
||||
# email address ... which we want to allow
|
||||
if email_info.normalized.lower() == email_info_whitelist.normalized.lower():
|
||||
return
|
||||
|
||||
raise PermissionError("User not on allowed user whitelist")
|
||||
|
||||
|
||||
def verify_email_in_whitelist(email: str, tenant_id: str | None = None) -> None:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
if not get_user_by_email(email, db_session):
|
||||
verify_email_is_invited(email)
|
||||
|
||||
|
||||
def verify_email_domain(email: str) -> None:
|
||||
if VALID_EMAIL_DOMAINS:
|
||||
if email.count("@") != 1:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Email is not valid",
|
||||
)
|
||||
domain = email.split("@")[-1]
|
||||
if domain not in VALID_EMAIL_DOMAINS:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Email domain is not valid",
|
||||
)
|
||||
|
||||
|
||||
def get_tenant_id_for_email(email: str) -> str:
|
||||
if not MULTI_TENANT:
|
||||
return "public"
|
||||
# Implement logic to get tenant_id from the mapping table
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
result = db_session.execute(
|
||||
select(UserTenantMapping.tenant_id).where(UserTenantMapping.email == email)
|
||||
)
|
||||
tenant_id = result.scalar_one_or_none()
|
||||
if tenant_id is None:
|
||||
raise exceptions.UserNotExists()
|
||||
return tenant_id
|
||||
|
||||
|
||||
def send_user_verification_email(
|
||||
user_email: str,
|
||||
token: str,
|
||||
mail_from: str = EMAIL_FROM,
|
||||
) -> None:
|
||||
msg = MIMEMultipart()
|
||||
msg["Subject"] = "Danswer Email Verification"
|
||||
msg["To"] = user_email
|
||||
if mail_from:
|
||||
msg["From"] = mail_from
|
||||
|
||||
link = f"{WEB_DOMAIN}/auth/verify-email?token={token}"
|
||||
|
||||
body = MIMEText(f"Click the following link to verify your email address: {link}")
|
||||
msg.attach(body)
|
||||
|
||||
with smtplib.SMTP(SMTP_SERVER, SMTP_PORT) as s:
|
||||
s.starttls()
|
||||
# If credentials fails with gmail, check (You need an app password, not just the basic email password)
|
||||
# https://support.google.com/accounts/answer/185833?sjid=8512343437447396151-NA
|
||||
s.login(SMTP_USER, SMTP_PASS)
|
||||
s.send_message(msg)
|
||||
|
||||
|
||||
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
reset_password_token_secret = USER_AUTH_SECRET
|
||||
verification_token_secret = USER_AUTH_SECRET
|
||||
|
||||
async def create(
|
||||
self,
|
||||
user_create: schemas.UC | UserCreate,
|
||||
safe: bool = False,
|
||||
request: Optional[Request] = None,
|
||||
) -> User:
|
||||
verify_email_is_invited(user_create.email)
|
||||
verify_email_domain(user_create.email)
|
||||
if hasattr(user_create, "role"):
|
||||
user_count = await get_user_count()
|
||||
if user_count == 0 or user_create.email in get_default_admin_user_emails():
|
||||
user_create.role = UserRole.ADMIN
|
||||
else:
|
||||
user_create.role = UserRole.BASIC
|
||||
user = None
|
||||
try:
|
||||
user = await super().create(user_create, safe=safe, request=request) # type: ignore
|
||||
except exceptions.UserAlreadyExists:
|
||||
user = await self.get_by_email(user_create.email)
|
||||
# Handle case where user has used product outside of web and is now creating an account through web
|
||||
if (
|
||||
not user.has_web_login
|
||||
and hasattr(user_create, "has_web_login")
|
||||
and user_create.has_web_login
|
||||
):
|
||||
user_update = UserUpdate(
|
||||
password=user_create.password,
|
||||
has_web_login=True,
|
||||
role=user_create.role,
|
||||
is_verified=user_create.is_verified,
|
||||
)
|
||||
user = await self.update(user_update, user)
|
||||
else:
|
||||
raise exceptions.UserAlreadyExists()
|
||||
return user
|
||||
|
||||
async def on_after_login(
|
||||
self,
|
||||
user: User,
|
||||
request: Request | None = None,
|
||||
response: Response | None = None,
|
||||
) -> None:
|
||||
if response is None or not MULTI_TENANT:
|
||||
return
|
||||
|
||||
tenant_id = get_tenant_id_for_email(user.email)
|
||||
|
||||
tenant_token = jwt.encode(
|
||||
{"tenant_id": tenant_id}, SECRET_JWT_KEY, algorithm="HS256"
|
||||
)
|
||||
|
||||
response.set_cookie(
|
||||
key="tenant_details",
|
||||
value=tenant_token,
|
||||
httponly=True,
|
||||
secure=WEB_DOMAIN.startswith("https"),
|
||||
samesite="lax",
|
||||
)
|
||||
|
||||
async def oauth_callback(
|
||||
self: "BaseUserManager[models.UOAP, models.ID]",
|
||||
oauth_name: str,
|
||||
access_token: str,
|
||||
account_id: str,
|
||||
account_email: str,
|
||||
expires_at: Optional[int] = None,
|
||||
refresh_token: Optional[str] = None,
|
||||
request: Optional[Request] = None,
|
||||
*,
|
||||
associate_by_email: bool = False,
|
||||
is_verified_by_default: bool = False,
|
||||
) -> models.UOAP:
|
||||
# Get tenant_id from mapping table
|
||||
try:
|
||||
tenant_id = (
|
||||
get_tenant_id_for_email(account_email) if MULTI_TENANT else "public"
|
||||
)
|
||||
except exceptions.UserNotExists:
|
||||
raise HTTPException(status_code=401, detail="User not found")
|
||||
|
||||
if not tenant_id:
|
||||
raise HTTPException(status_code=401, detail="User not found")
|
||||
|
||||
token = None
|
||||
async with get_async_session_with_tenant(tenant_id) as db_session:
|
||||
token = current_tenant_id.set(tenant_id)
|
||||
# Print a list of tables in the current database session
|
||||
verify_email_in_whitelist(account_email, tenant_id)
|
||||
verify_email_domain(account_email)
|
||||
if MULTI_TENANT:
|
||||
tenant_user_db = SQLAlchemyUserAdminDB(db_session, User, OAuthAccount)
|
||||
self.user_db = tenant_user_db
|
||||
self.database = tenant_user_db
|
||||
|
||||
oauth_account_dict = {
|
||||
"oauth_name": oauth_name,
|
||||
"access_token": access_token,
|
||||
"account_id": account_id,
|
||||
"account_email": account_email,
|
||||
"expires_at": expires_at,
|
||||
"refresh_token": refresh_token,
|
||||
}
|
||||
|
||||
try:
|
||||
# Attempt to get user by OAuth account
|
||||
user = await self.get_by_oauth_account(oauth_name, account_id)
|
||||
|
||||
except exceptions.UserNotExists:
|
||||
try:
|
||||
# Attempt to get user by email
|
||||
user = await self.get_by_email(account_email)
|
||||
if not associate_by_email:
|
||||
raise exceptions.UserAlreadyExists()
|
||||
|
||||
user = await self.user_db.add_oauth_account(
|
||||
user, oauth_account_dict
|
||||
)
|
||||
|
||||
# If user not found by OAuth account or email, create a new user
|
||||
except exceptions.UserNotExists:
|
||||
password = self.password_helper.generate()
|
||||
user_dict = {
|
||||
"email": account_email,
|
||||
"hashed_password": self.password_helper.hash(password),
|
||||
"is_verified": is_verified_by_default,
|
||||
}
|
||||
|
||||
user = await self.user_db.create(user_dict)
|
||||
user = await self.user_db.add_oauth_account(
|
||||
user, oauth_account_dict
|
||||
)
|
||||
await self.on_after_register(user, request)
|
||||
|
||||
else:
|
||||
for existing_oauth_account in user.oauth_accounts:
|
||||
if (
|
||||
existing_oauth_account.account_id == account_id
|
||||
and existing_oauth_account.oauth_name == oauth_name
|
||||
):
|
||||
user = await self.user_db.update_oauth_account(
|
||||
user, existing_oauth_account, oauth_account_dict
|
||||
)
|
||||
|
||||
# NOTE: Most IdPs have very short expiry times, and we don't want to force the user to
|
||||
# re-authenticate that frequently, so by default this is disabled
|
||||
|
||||
if expires_at and TRACK_EXTERNAL_IDP_EXPIRY:
|
||||
oidc_expiry = datetime.fromtimestamp(expires_at, tz=timezone.utc)
|
||||
await self.user_db.update(
|
||||
user, update_dict={"oidc_expiry": oidc_expiry}
|
||||
)
|
||||
|
||||
# Handle case where user has used product outside of web and is now creating an account through web
|
||||
if not user.has_web_login: # type: ignore
|
||||
await self.user_db.update(
|
||||
user,
|
||||
{
|
||||
"is_verified": is_verified_by_default,
|
||||
"has_web_login": True,
|
||||
},
|
||||
)
|
||||
user.is_verified = is_verified_by_default
|
||||
user.has_web_login = True # type: ignore
|
||||
|
||||
# this is needed if an organization goes from `TRACK_EXTERNAL_IDP_EXPIRY=true` to `false`
|
||||
# otherwise, the oidc expiry will always be old, and the user will never be able to login
|
||||
if (
|
||||
user.oidc_expiry is not None # type: ignore
|
||||
and not TRACK_EXTERNAL_IDP_EXPIRY
|
||||
):
|
||||
await self.user_db.update(user, {"oidc_expiry": None})
|
||||
user.oidc_expiry = None # type: ignore
|
||||
|
||||
if token:
|
||||
current_tenant_id.reset(token)
|
||||
|
||||
return user
|
||||
|
||||
async def on_after_register(
|
||||
self, user: User, request: Optional[Request] = None
|
||||
) -> None:
|
||||
logger.notice(f"User {user.id} has registered.")
|
||||
optional_telemetry(
|
||||
record_type=RecordType.SIGN_UP,
|
||||
data={"action": "create"},
|
||||
user_id=str(user.id),
|
||||
)
|
||||
|
||||
async def on_after_forgot_password(
|
||||
self, user: User, token: str, request: Optional[Request] = None
|
||||
) -> None:
|
||||
logger.notice(f"User {user.id} has forgot their password. Reset token: {token}")
|
||||
|
||||
async def on_after_request_verify(
|
||||
self, user: User, token: str, request: Optional[Request] = None
|
||||
) -> None:
|
||||
verify_email_domain(user.email)
|
||||
|
||||
logger.notice(
|
||||
f"Verification requested for user {user.id}. Verification token: {token}"
|
||||
)
|
||||
|
||||
send_user_verification_email(user.email, token)
|
||||
|
||||
async def authenticate(
|
||||
self, credentials: OAuth2PasswordRequestForm
|
||||
) -> Optional[User]:
|
||||
email = credentials.username
|
||||
|
||||
# Get tenant_id from mapping table
|
||||
|
||||
tenant_id = get_tenant_id_for_email(email)
|
||||
if not tenant_id:
|
||||
# User not found in mapping
|
||||
self.password_helper.hash(credentials.password)
|
||||
return None
|
||||
|
||||
# Create a tenant-specific session
|
||||
async with get_async_session_with_tenant(tenant_id) as tenant_session:
|
||||
tenant_user_db: SQLAlchemyUserDatabase = SQLAlchemyUserDatabase(
|
||||
tenant_session, User
|
||||
)
|
||||
self.user_db = tenant_user_db
|
||||
|
||||
# Proceed with authentication
|
||||
try:
|
||||
user = await self.get_by_email(email)
|
||||
|
||||
except exceptions.UserNotExists:
|
||||
self.password_helper.hash(credentials.password)
|
||||
return None
|
||||
|
||||
has_web_login = attributes.get_attribute(user, "has_web_login")
|
||||
|
||||
if not has_web_login:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="NO_WEB_LOGIN_AND_HAS_NO_PASSWORD",
|
||||
)
|
||||
|
||||
verified, updated_password_hash = self.password_helper.verify_and_update(
|
||||
credentials.password, user.hashed_password
|
||||
)
|
||||
if not verified:
|
||||
return None
|
||||
|
||||
if updated_password_hash is not None:
|
||||
await self.user_db.update(
|
||||
user, {"hashed_password": updated_password_hash}
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
async def get_user_manager(
|
||||
user_db: SQLAlchemyUserDatabase = Depends(get_user_db),
|
||||
) -> AsyncGenerator[UserManager, None]:
|
||||
yield UserManager(user_db)
|
||||
|
||||
|
||||
cookie_transport = CookieTransport(
|
||||
cookie_max_age=SESSION_EXPIRE_TIME_SECONDS,
|
||||
cookie_secure=WEB_DOMAIN.startswith("https"),
|
||||
)
|
||||
|
||||
|
||||
def get_jwt_strategy() -> JWTStrategy:
|
||||
return JWTStrategy(
|
||||
secret=USER_AUTH_SECRET,
|
||||
lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS,
|
||||
)
|
||||
|
||||
|
||||
def get_database_strategy(
|
||||
access_token_db: AccessTokenDatabase[AccessToken] = Depends(get_access_token_db),
|
||||
) -> DatabaseStrategy:
|
||||
return DatabaseStrategy(
|
||||
access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS # type: ignore
|
||||
)
|
||||
|
||||
|
||||
auth_backend = AuthenticationBackend(
|
||||
name="jwt" if MULTI_TENANT else "database",
|
||||
transport=cookie_transport,
|
||||
get_strategy=get_jwt_strategy if MULTI_TENANT else get_database_strategy, # type: ignore
|
||||
) # type: ignore
|
||||
|
||||
|
||||
class FastAPIUserWithLogoutRouter(FastAPIUsers[models.UP, models.ID]):
|
||||
def get_logout_router(
|
||||
self,
|
||||
backend: AuthenticationBackend,
|
||||
requires_verification: bool = REQUIRE_EMAIL_VERIFICATION,
|
||||
) -> APIRouter:
|
||||
"""
|
||||
Provide a router for logout only for OAuth/OIDC Flows.
|
||||
This way the login router does not need to be included
|
||||
"""
|
||||
router = APIRouter()
|
||||
|
||||
get_current_user_token = self.authenticator.current_user_token(
|
||||
active=True, verified=requires_verification
|
||||
)
|
||||
|
||||
logout_responses: OpenAPIResponseType = {
|
||||
**{
|
||||
status.HTTP_401_UNAUTHORIZED: {
|
||||
"description": "Missing token or inactive user."
|
||||
}
|
||||
},
|
||||
**backend.transport.get_openapi_logout_responses_success(),
|
||||
}
|
||||
|
||||
@router.post(
|
||||
"/logout", name=f"auth:{backend.name}.logout", responses=logout_responses
|
||||
)
|
||||
async def logout(
|
||||
user_token: Tuple[models.UP, str] = Depends(get_current_user_token),
|
||||
strategy: Strategy[models.UP, models.ID] = Depends(backend.get_strategy),
|
||||
) -> Response:
|
||||
user, token = user_token
|
||||
return await backend.logout(strategy, user, token)
|
||||
|
||||
return router
|
||||
|
||||
|
||||
fastapi_users = FastAPIUserWithLogoutRouter[User, uuid.UUID](
|
||||
get_user_manager, [auth_backend]
|
||||
)
|
||||
|
||||
|
||||
# NOTE: verified=REQUIRE_EMAIL_VERIFICATION is not used here since we
|
||||
# take care of that in `double_check_user` ourself. This is needed, since
|
||||
# we want the /me endpoint to still return a user even if they are not
|
||||
# yet verified, so that the frontend knows they exist
|
||||
optional_fastapi_current_user = fastapi_users.current_user(active=True, optional=True)
|
||||
|
||||
|
||||
async def optional_user_(
|
||||
request: Request,
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
) -> User | None:
|
||||
"""NOTE: `request` and `db_session` are not used here, but are included
|
||||
for the EE version of this function."""
|
||||
return user
|
||||
|
||||
|
||||
async def optional_user(
|
||||
request: Request,
|
||||
db_session: Session = Depends(get_session),
|
||||
user: User | None = Depends(optional_fastapi_current_user),
|
||||
) -> User | None:
|
||||
versioned_fetch_user = fetch_versioned_implementation(
|
||||
"danswer.auth.users", "optional_user_"
|
||||
)
|
||||
return await versioned_fetch_user(request, user, db_session)
|
||||
|
||||
|
||||
async def double_check_user(
|
||||
user: User | None,
|
||||
optional: bool = DISABLE_AUTH,
|
||||
include_expired: bool = False,
|
||||
) -> User | None:
|
||||
if optional:
|
||||
return None
|
||||
|
||||
if user is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied. User is not authenticated.",
|
||||
)
|
||||
|
||||
if user_needs_to_be_verified() and not user.is_verified:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied. User is not verified.",
|
||||
)
|
||||
|
||||
if (
|
||||
user.oidc_expiry
|
||||
and user.oidc_expiry < datetime.now(timezone.utc)
|
||||
and not include_expired
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied. User's OIDC token has expired.",
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
async def current_user_with_expired_token(
|
||||
user: User | None = Depends(optional_user),
|
||||
) -> User | None:
|
||||
return await double_check_user(user, include_expired=True)
|
||||
|
||||
|
||||
async def current_user(
|
||||
user: User | None = Depends(optional_user),
|
||||
) -> User | None:
|
||||
return await double_check_user(user)
|
||||
|
||||
|
||||
async def current_curator_or_admin_user(
|
||||
user: User | None = Depends(current_user),
|
||||
) -> User | None:
|
||||
if DISABLE_AUTH:
|
||||
return None
|
||||
|
||||
if not user or not hasattr(user, "role"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied. User is not authenticated or lacks role information.",
|
||||
)
|
||||
|
||||
allowed_roles = {UserRole.GLOBAL_CURATOR, UserRole.CURATOR, UserRole.ADMIN}
|
||||
if user.role not in allowed_roles:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied. User is not a curator or admin.",
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
async def current_admin_user(user: User | None = Depends(current_user)) -> User | None:
|
||||
if DISABLE_AUTH:
|
||||
return None
|
||||
|
||||
if not user or not hasattr(user, "role") or user.role != UserRole.ADMIN:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied. User must be an admin to perform this action.",
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
def get_default_admin_user_emails_() -> list[str]:
|
||||
# No default seeding available for Danswer MIT
|
||||
return []
|
||||
@@ -1,494 +0,0 @@
|
||||
import logging
|
||||
import time
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
|
||||
import redis
|
||||
from celery import bootsteps # type: ignore
|
||||
from celery import Celery
|
||||
from celery import current_task
|
||||
from celery import signals
|
||||
from celery import Task
|
||||
from celery.exceptions import WorkerShutdown
|
||||
from celery.signals import beat_init
|
||||
from celery.signals import worker_init
|
||||
from celery.signals import worker_ready
|
||||
from celery.signals import worker_shutdown
|
||||
from celery.states import READY_STATES
|
||||
from celery.utils.log import get_task_logger
|
||||
|
||||
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
|
||||
from danswer.background.celery.celery_redis import RedisConnectorDeletion
|
||||
from danswer.background.celery.celery_redis import RedisConnectorPruning
|
||||
from danswer.background.celery.celery_redis import RedisDocumentSet
|
||||
from danswer.background.celery.celery_redis import RedisUserGroup
|
||||
from danswer.background.celery.celery_utils import celery_is_worker_primary
|
||||
from danswer.background.update import get_all_tenant_ids
|
||||
from danswer.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME
|
||||
from danswer.configs.constants import POSTGRES_CELERY_WORKER_HEAVY_APP_NAME
|
||||
from danswer.configs.constants import POSTGRES_CELERY_WORKER_LIGHT_APP_NAME
|
||||
from danswer.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME
|
||||
from danswer.db.engine import SqlEngine
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.utils.logger import ColoredFormatter
|
||||
from danswer.utils.logger import PlainFormatter
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# use this within celery tasks to get celery task specific logging
|
||||
task_logger = get_task_logger(__name__)
|
||||
|
||||
celery_app = Celery(__name__)
|
||||
celery_app.config_from_object(
|
||||
"danswer.background.celery.celeryconfig"
|
||||
) # Load configuration from 'celeryconfig.py'
|
||||
|
||||
|
||||
@signals.task_postrun.connect
|
||||
def celery_task_postrun(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple | None = None,
|
||||
kwargs: dict | None = None,
|
||||
retval: Any | None = None,
|
||||
state: str | None = None,
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
"""We handle this signal in order to remove completed tasks
|
||||
from their respective tasksets. This allows us to track the progress of document set
|
||||
and user group syncs.
|
||||
|
||||
This function runs after any task completes (both success and failure)
|
||||
Note that this signal does not fire on a task that failed to complete and is going
|
||||
to be retried.
|
||||
"""
|
||||
if not task:
|
||||
return
|
||||
|
||||
task_logger.debug(f"Task {task.name} (ID: {task_id}) completed with state: {state}")
|
||||
|
||||
if state not in READY_STATES:
|
||||
return
|
||||
|
||||
if not task_id:
|
||||
return
|
||||
|
||||
r = get_redis_client()
|
||||
|
||||
if task_id.startswith(RedisConnectorCredentialPair.PREFIX):
|
||||
r.srem(RedisConnectorCredentialPair.get_taskset_key(), task_id)
|
||||
return
|
||||
|
||||
if task_id.startswith(RedisDocumentSet.PREFIX):
|
||||
document_set_id = RedisDocumentSet.get_id_from_task_id(task_id)
|
||||
if document_set_id is not None:
|
||||
rds = RedisDocumentSet(document_set_id)
|
||||
r.srem(rds.taskset_key, task_id)
|
||||
return
|
||||
|
||||
if task_id.startswith(RedisUserGroup.PREFIX):
|
||||
usergroup_id = RedisUserGroup.get_id_from_task_id(task_id)
|
||||
if usergroup_id is not None:
|
||||
rug = RedisUserGroup(usergroup_id)
|
||||
r.srem(rug.taskset_key, task_id)
|
||||
return
|
||||
|
||||
if task_id.startswith(RedisConnectorDeletion.PREFIX):
|
||||
cc_pair_id = RedisConnectorDeletion.get_id_from_task_id(task_id)
|
||||
if cc_pair_id is not None:
|
||||
rcd = RedisConnectorDeletion(cc_pair_id)
|
||||
r.srem(rcd.taskset_key, task_id)
|
||||
return
|
||||
|
||||
if task_id.startswith(RedisConnectorPruning.SUBTASK_PREFIX):
|
||||
cc_pair_id = RedisConnectorPruning.get_id_from_task_id(task_id)
|
||||
if cc_pair_id is not None:
|
||||
rcp = RedisConnectorPruning(cc_pair_id)
|
||||
r.srem(rcp.taskset_key, task_id)
|
||||
return
|
||||
|
||||
|
||||
@beat_init.connect
|
||||
def on_beat_init(sender: Any, **kwargs: Any) -> None:
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_BEAT_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=2, max_overflow=0)
|
||||
|
||||
|
||||
@worker_init.connect
|
||||
def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
# decide some initial startup settings based on the celery worker's hostname
|
||||
# (set at the command line)
|
||||
hostname = sender.hostname
|
||||
if hostname.startswith("light"):
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_LIGHT_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8)
|
||||
elif hostname.startswith("heavy"):
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_HEAVY_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=8, max_overflow=0)
|
||||
else:
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=8, max_overflow=0)
|
||||
|
||||
r = get_redis_client()
|
||||
|
||||
WAIT_INTERVAL = 5
|
||||
WAIT_LIMIT = 60
|
||||
|
||||
time_start = time.monotonic()
|
||||
logger.info("Redis: Readiness check starting.")
|
||||
while True:
|
||||
try:
|
||||
if r.ping():
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
logger.info(
|
||||
f"Redis: Ping failed. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
|
||||
)
|
||||
if time_elapsed > WAIT_LIMIT:
|
||||
msg = (
|
||||
f"Redis: Readiness check did not succeed within the timeout "
|
||||
f"({WAIT_LIMIT} seconds). Exiting..."
|
||||
)
|
||||
logger.error(msg)
|
||||
raise WorkerShutdown(msg)
|
||||
|
||||
time.sleep(WAIT_INTERVAL)
|
||||
|
||||
logger.info("Redis: Readiness check succeeded. Continuing...")
|
||||
|
||||
if not celery_is_worker_primary(sender):
|
||||
logger.info("Running as a secondary celery worker.")
|
||||
logger.info("Waiting for primary worker to be ready...")
|
||||
time_start = time.monotonic()
|
||||
while True:
|
||||
if r.exists(DanswerRedisLocks.PRIMARY_WORKER):
|
||||
break
|
||||
|
||||
time.monotonic()
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
logger.info(
|
||||
f"Primary worker is not ready yet. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
|
||||
)
|
||||
if time_elapsed > WAIT_LIMIT:
|
||||
msg = (
|
||||
f"Primary worker was not ready within the timeout. "
|
||||
f"({WAIT_LIMIT} seconds). Exiting..."
|
||||
)
|
||||
logger.error(msg)
|
||||
raise WorkerShutdown(msg)
|
||||
|
||||
time.sleep(WAIT_INTERVAL)
|
||||
|
||||
logger.info("Wait for primary worker completed successfully. Continuing...")
|
||||
return
|
||||
|
||||
logger.info("Running as the primary celery worker.")
|
||||
|
||||
# This is singleton work that should be done on startup exactly once
|
||||
# by the primary worker
|
||||
r = get_redis_client()
|
||||
|
||||
# For the moment, we're assuming that we are the only primary worker
|
||||
# that should be running.
|
||||
# TODO: maybe check for or clean up another zombie primary worker if we detect it
|
||||
r.delete(DanswerRedisLocks.PRIMARY_WORKER)
|
||||
|
||||
# this process wide lock is taken to help other workers start up in order.
|
||||
# it is planned to use this lock to enforce singleton behavior on the primary
|
||||
# worker, since the primary worker does redis cleanup on startup, but this isn't
|
||||
# implemented yet.
|
||||
lock = r.lock(
|
||||
DanswerRedisLocks.PRIMARY_WORKER,
|
||||
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
logger.info("Primary worker lock: Acquire starting.")
|
||||
acquired = lock.acquire(blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2)
|
||||
if acquired:
|
||||
logger.info("Primary worker lock: Acquire succeeded.")
|
||||
else:
|
||||
logger.error("Primary worker lock: Acquire failed!")
|
||||
raise WorkerShutdown("Primary worker lock could not be acquired!")
|
||||
|
||||
sender.primary_worker_lock = lock
|
||||
|
||||
r.delete(DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK)
|
||||
r.delete(DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
|
||||
|
||||
r.delete(RedisConnectorCredentialPair.get_taskset_key())
|
||||
r.delete(RedisConnectorCredentialPair.get_fence_key())
|
||||
|
||||
for key in r.scan_iter(RedisDocumentSet.TASKSET_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisUserGroup.TASKSET_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorDeletion.TASKSET_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorPruning.TASKSET_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorPruning.GENERATOR_COMPLETE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorPruning.GENERATOR_PROGRESS_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
|
||||
@worker_ready.connect
|
||||
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
|
||||
task_logger.info("worker_ready signal received.")
|
||||
|
||||
|
||||
@worker_shutdown.connect
|
||||
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
|
||||
if not celery_is_worker_primary(sender):
|
||||
return
|
||||
|
||||
if not sender.primary_worker_lock:
|
||||
return
|
||||
|
||||
logger.info("Releasing primary worker lock.")
|
||||
lock = sender.primary_worker_lock
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
sender.primary_worker_lock = None
|
||||
|
||||
|
||||
class CeleryTaskPlainFormatter(PlainFormatter):
|
||||
def format(self, record: logging.LogRecord) -> str:
|
||||
task = current_task
|
||||
if task and task.request:
|
||||
record.__dict__.update(task_id=task.request.id, task_name=task.name)
|
||||
record.msg = f"[{task.name}({task.request.id})] {record.msg}"
|
||||
|
||||
return super().format(record)
|
||||
|
||||
|
||||
class CeleryTaskColoredFormatter(ColoredFormatter):
|
||||
def format(self, record: logging.LogRecord) -> str:
|
||||
task = current_task
|
||||
if task and task.request:
|
||||
record.__dict__.update(task_id=task.request.id, task_name=task.name)
|
||||
record.msg = f"[{task.name}({task.request.id})] {record.msg}"
|
||||
|
||||
return super().format(record)
|
||||
|
||||
|
||||
@signals.setup_logging.connect
|
||||
def on_setup_logging(
|
||||
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
|
||||
) -> None:
|
||||
# TODO: could unhardcode format and colorize and accept these as options from
|
||||
# celery's config
|
||||
|
||||
# reformats celery's worker logger
|
||||
root_logger = logging.getLogger()
|
||||
|
||||
root_handler = logging.StreamHandler() # Set up a handler for the root logger
|
||||
root_formatter = ColoredFormatter(
|
||||
"%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
|
||||
datefmt="%m/%d/%Y %I:%M:%S %p",
|
||||
)
|
||||
root_handler.setFormatter(root_formatter)
|
||||
root_logger.addHandler(root_handler) # Apply the handler to the root logger
|
||||
|
||||
if logfile:
|
||||
root_file_handler = logging.FileHandler(logfile)
|
||||
root_file_formatter = PlainFormatter(
|
||||
"%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
|
||||
datefmt="%m/%d/%Y %I:%M:%S %p",
|
||||
)
|
||||
root_file_handler.setFormatter(root_file_formatter)
|
||||
root_logger.addHandler(root_file_handler)
|
||||
|
||||
root_logger.setLevel(loglevel)
|
||||
|
||||
# reformats celery's task logger
|
||||
task_formatter = CeleryTaskColoredFormatter(
|
||||
"%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
|
||||
datefmt="%m/%d/%Y %I:%M:%S %p",
|
||||
)
|
||||
task_handler = logging.StreamHandler() # Set up a handler for the task logger
|
||||
task_handler.setFormatter(task_formatter)
|
||||
task_logger.addHandler(task_handler) # Apply the handler to the task logger
|
||||
|
||||
if logfile:
|
||||
task_file_handler = logging.FileHandler(logfile)
|
||||
task_file_formatter = CeleryTaskPlainFormatter(
|
||||
"%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
|
||||
datefmt="%m/%d/%Y %I:%M:%S %p",
|
||||
)
|
||||
task_file_handler.setFormatter(task_file_formatter)
|
||||
task_logger.addHandler(task_file_handler)
|
||||
|
||||
task_logger.setLevel(loglevel)
|
||||
task_logger.propagate = False
|
||||
|
||||
|
||||
class HubPeriodicTask(bootsteps.StartStopStep):
|
||||
"""Regularly reacquires the primary worker lock outside of the task queue.
|
||||
Use the task_logger in this class to avoid double logging.
|
||||
|
||||
This cannot be done inside a regular beat task because it must run on schedule and
|
||||
a queue of existing work would starve the task from running.
|
||||
"""
|
||||
|
||||
# it's unclear to me whether using the hub's timer or the bootstep timer is better
|
||||
requires = {"celery.worker.components:Hub"}
|
||||
|
||||
def __init__(self, worker: Any, **kwargs: Any) -> None:
|
||||
self.interval = CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 8 # Interval in seconds
|
||||
self.task_tref = None
|
||||
|
||||
def start(self, worker: Any) -> None:
|
||||
if not celery_is_worker_primary(worker):
|
||||
return
|
||||
|
||||
# Access the worker's event loop (hub)
|
||||
hub = worker.consumer.controller.hub
|
||||
|
||||
# Schedule the periodic task
|
||||
self.task_tref = hub.call_repeatedly(
|
||||
self.interval, self.run_periodic_task, worker
|
||||
)
|
||||
task_logger.info("Scheduled periodic task with hub.")
|
||||
|
||||
def run_periodic_task(self, worker: Any) -> None:
|
||||
try:
|
||||
if not worker.primary_worker_lock:
|
||||
return
|
||||
|
||||
if not hasattr(worker, "primary_worker_lock"):
|
||||
return
|
||||
|
||||
r = get_redis_client()
|
||||
|
||||
lock: redis.lock.Lock = worker.primary_worker_lock
|
||||
|
||||
if lock.owned():
|
||||
task_logger.debug("Reacquiring primary worker lock.")
|
||||
lock.reacquire()
|
||||
else:
|
||||
task_logger.warning(
|
||||
"Full acquisition of primary worker lock. "
|
||||
"Reasons could be computer sleep or a clock change."
|
||||
)
|
||||
lock = r.lock(
|
||||
DanswerRedisLocks.PRIMARY_WORKER,
|
||||
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
task_logger.info("Primary worker lock: Acquire starting.")
|
||||
acquired = lock.acquire(
|
||||
blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2
|
||||
)
|
||||
if acquired:
|
||||
task_logger.info("Primary worker lock: Acquire succeeded.")
|
||||
else:
|
||||
task_logger.error("Primary worker lock: Acquire failed!")
|
||||
raise TimeoutError("Primary worker lock could not be acquired!")
|
||||
|
||||
worker.primary_worker_lock = lock
|
||||
except Exception:
|
||||
task_logger.exception("HubPeriodicTask.run_periodic_task exceptioned.")
|
||||
|
||||
def stop(self, worker: Any) -> None:
|
||||
# Cancel the scheduled task when the worker stops
|
||||
if self.task_tref:
|
||||
self.task_tref.cancel()
|
||||
task_logger.info("Canceled periodic task with hub.")
|
||||
|
||||
|
||||
celery_app.steps["worker"].add(HubPeriodicTask)
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
[
|
||||
"danswer.background.celery.tasks.connector_deletion",
|
||||
"danswer.background.celery.tasks.periodic",
|
||||
"danswer.background.celery.tasks.pruning",
|
||||
"danswer.background.celery.tasks.shared",
|
||||
"danswer.background.celery.tasks.vespa",
|
||||
]
|
||||
)
|
||||
|
||||
#####
|
||||
# Celery Beat (Periodic Tasks) Settings
|
||||
#####
|
||||
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
|
||||
tasks_to_schedule = [
|
||||
{
|
||||
"name": "check-for-vespa-sync",
|
||||
"task": "check_for_vespa_sync_task",
|
||||
"schedule": timedelta(seconds=5),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-connector-deletion",
|
||||
"task": "check_for_connector_deletion_task",
|
||||
"schedule": timedelta(seconds=60),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-prune",
|
||||
"task": "check_for_prune_task_2",
|
||||
"schedule": timedelta(seconds=10),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "kombu-message-cleanup",
|
||||
"task": "kombu_message_cleanup_task",
|
||||
"schedule": timedelta(seconds=3600),
|
||||
"options": {"priority": DanswerCeleryPriority.LOWEST},
|
||||
},
|
||||
{
|
||||
"name": "monitor-vespa-sync",
|
||||
"task": "monitor_vespa_sync",
|
||||
"schedule": timedelta(seconds=5),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
]
|
||||
|
||||
# Build the celery beat schedule dynamically
|
||||
beat_schedule = {}
|
||||
|
||||
for tenant_id in tenant_ids:
|
||||
for task in tasks_to_schedule:
|
||||
task_name = f"{task['name']}-{tenant_id}" # Unique name for each scheduled task
|
||||
beat_schedule[task_name] = {
|
||||
"task": task["task"],
|
||||
"schedule": task["schedule"],
|
||||
"options": task["options"],
|
||||
"args": (tenant_id,), # Must pass tenant_id as an argument
|
||||
}
|
||||
|
||||
# Include any existing beat schedules
|
||||
existing_beat_schedule = celery_app.conf.beat_schedule or {}
|
||||
beat_schedule.update(existing_beat_schedule)
|
||||
|
||||
# Update the Celery app configuration once
|
||||
celery_app.conf.beat_schedule = beat_schedule
|
||||
@@ -1,492 +0,0 @@
|
||||
# These are helper objects for tracking the keys we need to write in redis
|
||||
import time
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
import redis
|
||||
from celery import Celery
|
||||
from redis import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.celeryconfig import CELERY_SEPARATOR
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.document import construct_document_select_for_connector_credential_pair
|
||||
from danswer.db.document import (
|
||||
construct_document_select_for_connector_credential_pair_by_needs_sync,
|
||||
)
|
||||
from danswer.db.document_set import construct_document_select_by_docset
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
from danswer.utils.variable_functionality import global_version
|
||||
|
||||
|
||||
class RedisObjectHelper(ABC):
|
||||
PREFIX = "base"
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def __init__(self, id: int):
|
||||
self._id: int = id
|
||||
|
||||
@property
|
||||
def task_id_prefix(self) -> str:
|
||||
return f"{self.PREFIX}_{self._id}"
|
||||
|
||||
@property
|
||||
def fence_key(self) -> str:
|
||||
# example: documentset_fence_1
|
||||
return f"{self.FENCE_PREFIX}_{self._id}"
|
||||
|
||||
@property
|
||||
def taskset_key(self) -> str:
|
||||
# example: documentset_taskset_1
|
||||
return f"{self.TASKSET_PREFIX}_{self._id}"
|
||||
|
||||
@staticmethod
|
||||
def get_id_from_fence_key(key: str) -> int | None:
|
||||
"""
|
||||
Extracts the object ID from a fence key in the format `PREFIX_fence_X`.
|
||||
|
||||
Args:
|
||||
key (str): The fence key string.
|
||||
|
||||
Returns:
|
||||
Optional[int]: The extracted ID if the key is in the correct format, otherwise None.
|
||||
"""
|
||||
parts = key.split("_")
|
||||
if len(parts) != 3:
|
||||
return None
|
||||
|
||||
try:
|
||||
object_id = int(parts[2])
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
return object_id
|
||||
|
||||
@staticmethod
|
||||
def get_id_from_task_id(task_id: str) -> int | None:
|
||||
"""
|
||||
Extracts the object ID from a task ID string.
|
||||
|
||||
This method assumes the task ID is formatted as `prefix_objectid_suffix`, where:
|
||||
- `prefix` is an arbitrary string (e.g., the name of the task or entity),
|
||||
- `objectid` is the ID you want to extract,
|
||||
- `suffix` is another arbitrary string (e.g., a UUID).
|
||||
|
||||
Example:
|
||||
If the input `task_id` is `documentset_1_cbfdc96a-80ca-4312-a242-0bb68da3c1dc`,
|
||||
this method will return the string `"1"`.
|
||||
|
||||
Args:
|
||||
task_id (str): The task ID string from which to extract the object ID.
|
||||
|
||||
Returns:
|
||||
str | None: The extracted object ID if the task ID is in the correct format, otherwise None.
|
||||
"""
|
||||
# example: task_id=documentset_1_cbfdc96a-80ca-4312-a242-0bb68da3c1dc
|
||||
parts = task_id.split("_")
|
||||
if len(parts) != 3:
|
||||
return None
|
||||
|
||||
try:
|
||||
object_id = int(parts[1])
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
return object_id
|
||||
|
||||
@abstractmethod
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
pass
|
||||
|
||||
|
||||
class RedisDocumentSet(RedisObjectHelper):
|
||||
PREFIX = "documentset"
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
last_lock_time = time.monotonic()
|
||||
|
||||
async_results = []
|
||||
stmt = construct_document_select_by_docset(self._id, current_only=False)
|
||||
for doc in db_session.scalars(stmt).yield_per(1):
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
|
||||
):
|
||||
lock.reacquire()
|
||||
last_lock_time = current_time
|
||||
|
||||
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# we prefix the task id so it's easier to keep track of who created the task
|
||||
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
|
||||
|
||||
# add to the set BEFORE creating the task.
|
||||
redis_client.sadd(self.taskset_key, custom_task_id)
|
||||
|
||||
result = celery_app.send_task(
|
||||
"vespa_metadata_sync_task",
|
||||
kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
|
||||
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.LOW,
|
||||
)
|
||||
|
||||
async_results.append(result)
|
||||
|
||||
return len(async_results)
|
||||
|
||||
|
||||
class RedisUserGroup(RedisObjectHelper):
|
||||
PREFIX = "usergroup"
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
last_lock_time = time.monotonic()
|
||||
|
||||
async_results = []
|
||||
|
||||
if not global_version.is_ee_version():
|
||||
return 0
|
||||
|
||||
try:
|
||||
construct_document_select_by_usergroup = fetch_versioned_implementation(
|
||||
"danswer.db.user_group",
|
||||
"construct_document_select_by_usergroup",
|
||||
)
|
||||
except ModuleNotFoundError:
|
||||
return 0
|
||||
|
||||
stmt = construct_document_select_by_usergroup(self._id)
|
||||
for doc in db_session.scalars(stmt).yield_per(1):
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
|
||||
):
|
||||
lock.reacquire()
|
||||
last_lock_time = current_time
|
||||
|
||||
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# we prefix the task id so it's easier to keep track of who created the task
|
||||
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
|
||||
|
||||
# add to the set BEFORE creating the task.
|
||||
redis_client.sadd(self.taskset_key, custom_task_id)
|
||||
|
||||
result = celery_app.send_task(
|
||||
"vespa_metadata_sync_task",
|
||||
kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
|
||||
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.LOW,
|
||||
)
|
||||
|
||||
async_results.append(result)
|
||||
|
||||
return len(async_results)
|
||||
|
||||
|
||||
class RedisConnectorCredentialPair(RedisObjectHelper):
|
||||
"""This class differs from the default in that the taskset used spans
|
||||
all connectors and is not per connector."""
|
||||
|
||||
PREFIX = "connectorsync"
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
@classmethod
|
||||
def get_fence_key(cls) -> str:
|
||||
return RedisConnectorCredentialPair.FENCE_PREFIX
|
||||
|
||||
@classmethod
|
||||
def get_taskset_key(cls) -> str:
|
||||
return RedisConnectorCredentialPair.TASKSET_PREFIX
|
||||
|
||||
@property
|
||||
def taskset_key(self) -> str:
|
||||
"""Notice that this is intentionally reusing the same taskset for all
|
||||
connector syncs"""
|
||||
# example: connector_taskset
|
||||
return f"{self.TASKSET_PREFIX}"
|
||||
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
last_lock_time = time.monotonic()
|
||||
|
||||
async_results = []
|
||||
cc_pair = get_connector_credential_pair_from_id(self._id, db_session)
|
||||
if not cc_pair:
|
||||
return None
|
||||
|
||||
stmt = construct_document_select_for_connector_credential_pair_by_needs_sync(
|
||||
cc_pair.connector_id, cc_pair.credential_id
|
||||
)
|
||||
for doc in db_session.scalars(stmt).yield_per(1):
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
|
||||
):
|
||||
lock.reacquire()
|
||||
last_lock_time = current_time
|
||||
|
||||
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# we prefix the task id so it's easier to keep track of who created the task
|
||||
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
|
||||
|
||||
# add to the tracking taskset in redis BEFORE creating the celery task.
|
||||
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
|
||||
redis_client.sadd(
|
||||
RedisConnectorCredentialPair.get_taskset_key(), custom_task_id
|
||||
)
|
||||
|
||||
# Priority on sync's triggered by new indexing should be medium
|
||||
result = celery_app.send_task(
|
||||
"vespa_metadata_sync_task",
|
||||
kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
|
||||
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.MEDIUM,
|
||||
)
|
||||
|
||||
async_results.append(result)
|
||||
|
||||
return len(async_results)
|
||||
|
||||
|
||||
class RedisConnectorDeletion(RedisObjectHelper):
|
||||
PREFIX = "connectordeletion"
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
last_lock_time = time.monotonic()
|
||||
|
||||
async_results = []
|
||||
cc_pair = get_connector_credential_pair_from_id(self._id, db_session)
|
||||
if not cc_pair:
|
||||
return None
|
||||
|
||||
stmt = construct_document_select_for_connector_credential_pair(
|
||||
cc_pair.connector_id, cc_pair.credential_id
|
||||
)
|
||||
for doc in db_session.scalars(stmt).yield_per(1):
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
|
||||
):
|
||||
lock.reacquire()
|
||||
last_lock_time = current_time
|
||||
|
||||
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# we prefix the task id so it's easier to keep track of who created the task
|
||||
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
|
||||
|
||||
# add to the tracking taskset in redis BEFORE creating the celery task.
|
||||
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
|
||||
redis_client.sadd(self.taskset_key, custom_task_id)
|
||||
|
||||
# Priority on sync's triggered by new indexing should be medium
|
||||
result = celery_app.send_task(
|
||||
"document_by_cc_pair_cleanup_task",
|
||||
kwargs=dict(
|
||||
document_id=doc.id,
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
tenant_id=tenant_id,
|
||||
),
|
||||
queue=DanswerCeleryQueues.CONNECTOR_DELETION,
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.MEDIUM,
|
||||
)
|
||||
|
||||
async_results.append(result)
|
||||
|
||||
return len(async_results)
|
||||
|
||||
|
||||
class RedisConnectorPruning(RedisObjectHelper):
|
||||
"""Celery will kick off a long running generator task to crawl the connector and
|
||||
find any missing docs, which will each then get a new cleanup task. The progress of
|
||||
those tasks will then be monitored to completion.
|
||||
|
||||
Example rough happy path order:
|
||||
Check connectorpruning_fence_1
|
||||
Send generator task with id connectorpruning+generator_1_{uuid}
|
||||
|
||||
generator runs connector with callbacks that increment connectorpruning_generator_progress_1
|
||||
generator creates many subtasks with id connectorpruning+sub_1_{uuid}
|
||||
in taskset connectorpruning_taskset_1
|
||||
on completion, generator sets connectorpruning_generator_complete_1
|
||||
|
||||
celery postrun removes subtasks from taskset
|
||||
monitor beat task cleans up when taskset reaches 0 items
|
||||
"""
|
||||
|
||||
PREFIX = "connectorpruning"
|
||||
FENCE_PREFIX = PREFIX + "_fence" # a fence for the entire pruning process
|
||||
GENERATOR_TASK_PREFIX = PREFIX + "+generator"
|
||||
|
||||
TASKSET_PREFIX = PREFIX + "_taskset" # stores a list of prune tasks id's
|
||||
SUBTASK_PREFIX = PREFIX + "+sub"
|
||||
|
||||
GENERATOR_PROGRESS_PREFIX = (
|
||||
PREFIX + "_generator_progress"
|
||||
) # a signal that contains generator progress
|
||||
GENERATOR_COMPLETE_PREFIX = (
|
||||
PREFIX + "_generator_complete"
|
||||
) # a signal that the generator has finished
|
||||
|
||||
def __init__(self, id: int) -> None:
|
||||
"""id: the cc_pair_id of the connector credential pair"""
|
||||
|
||||
super().__init__(id)
|
||||
self.documents_to_prune: set[str] = set()
|
||||
|
||||
@property
|
||||
def generator_task_id_prefix(self) -> str:
|
||||
return f"{self.GENERATOR_TASK_PREFIX}_{self._id}"
|
||||
|
||||
@property
|
||||
def generator_progress_key(self) -> str:
|
||||
# example: connectorpruning_generator_progress_1
|
||||
return f"{self.GENERATOR_PROGRESS_PREFIX}_{self._id}"
|
||||
|
||||
@property
|
||||
def generator_complete_key(self) -> str:
|
||||
# example: connectorpruning_generator_complete_1
|
||||
return f"{self.GENERATOR_COMPLETE_PREFIX}_{self._id}"
|
||||
|
||||
@property
|
||||
def subtask_id_prefix(self) -> str:
|
||||
return f"{self.SUBTASK_PREFIX}_{self._id}"
|
||||
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock | None,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
last_lock_time = time.monotonic()
|
||||
|
||||
async_results = []
|
||||
cc_pair = get_connector_credential_pair_from_id(self._id, db_session)
|
||||
if not cc_pair:
|
||||
return None
|
||||
|
||||
for doc_id in self.documents_to_prune:
|
||||
current_time = time.monotonic()
|
||||
if lock and current_time - last_lock_time >= (
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
|
||||
):
|
||||
lock.reacquire()
|
||||
last_lock_time = current_time
|
||||
|
||||
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# we prefix the task id so it's easier to keep track of who created the task
|
||||
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
custom_task_id = f"{self.subtask_id_prefix}_{uuid4()}"
|
||||
|
||||
# add to the tracking taskset in redis BEFORE creating the celery task.
|
||||
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
|
||||
redis_client.sadd(self.taskset_key, custom_task_id)
|
||||
|
||||
# Priority on sync's triggered by new indexing should be medium
|
||||
result = celery_app.send_task(
|
||||
"document_by_cc_pair_cleanup_task",
|
||||
kwargs=dict(
|
||||
document_id=doc_id,
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
tenant_id=tenant_id,
|
||||
),
|
||||
queue=DanswerCeleryQueues.CONNECTOR_DELETION,
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.MEDIUM,
|
||||
)
|
||||
|
||||
async_results.append(result)
|
||||
|
||||
return len(async_results)
|
||||
|
||||
def is_pruning(self, db_session: Session, redis_client: Redis) -> bool:
|
||||
"""A single example of a helper method being refactored into the redis helper"""
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
cc_pair_id=self._id, db_session=db_session
|
||||
)
|
||||
if not cc_pair:
|
||||
raise ValueError(f"cc_pair_id {self._id} does not exist.")
|
||||
|
||||
if redis_client.exists(self.fence_key):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def celery_get_queue_length(queue: str, r: Redis) -> int:
|
||||
"""This is a redis specific way to get the length of a celery queue.
|
||||
It is priority aware and knows how to count across the multiple redis lists
|
||||
used to implement task prioritization.
|
||||
This operation is not atomic."""
|
||||
total_length = 0
|
||||
for i in range(len(DanswerCeleryPriority)):
|
||||
queue_name = queue
|
||||
if i > 0:
|
||||
queue_name += CELERY_SEPARATOR
|
||||
queue_name += str(i)
|
||||
|
||||
length = r.llen(queue_name)
|
||||
total_length += cast(int, length)
|
||||
|
||||
return total_length
|
||||
@@ -1,9 +0,0 @@
|
||||
"""Entry point for running celery worker / celery beat."""
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
|
||||
|
||||
set_is_ee_based_on_env_variable()
|
||||
celery_app = fetch_versioned_implementation(
|
||||
"danswer.background.celery.celery_app", "celery_app"
|
||||
)
|
||||
@@ -1,133 +0,0 @@
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.celery_redis import RedisConnectorDeletion
|
||||
from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
|
||||
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||
rate_limit_builder,
|
||||
)
|
||||
from danswer.connectors.interfaces import BaseConnector
|
||||
from danswer.connectors.interfaces import IdConnector
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair
|
||||
from danswer.db.enums import TaskStatus
|
||||
from danswer.db.models import TaskQueueState
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.server.documents.models import DeletionAttemptSnapshot
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _get_deletion_status(
|
||||
connector_id: int, credential_id: int, db_session: Session
|
||||
) -> TaskQueueState | None:
|
||||
"""We no longer store TaskQueueState in the DB for a deletion attempt.
|
||||
This function populates TaskQueueState by just checking redis.
|
||||
"""
|
||||
cc_pair = get_connector_credential_pair(
|
||||
connector_id=connector_id, credential_id=credential_id, db_session=db_session
|
||||
)
|
||||
if not cc_pair:
|
||||
return None
|
||||
|
||||
rcd = RedisConnectorDeletion(cc_pair.id)
|
||||
|
||||
r = get_redis_client()
|
||||
if not r.exists(rcd.fence_key):
|
||||
return None
|
||||
|
||||
return TaskQueueState(
|
||||
task_id="", task_name=rcd.fence_key, status=TaskStatus.STARTED
|
||||
)
|
||||
|
||||
|
||||
def get_deletion_attempt_snapshot(
|
||||
connector_id: int, credential_id: int, db_session: Session
|
||||
) -> DeletionAttemptSnapshot | None:
|
||||
deletion_task = _get_deletion_status(connector_id, credential_id, db_session)
|
||||
if not deletion_task:
|
||||
return None
|
||||
|
||||
return DeletionAttemptSnapshot(
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
status=deletion_task.status,
|
||||
)
|
||||
|
||||
|
||||
def document_batch_to_ids(doc_batch: list[Document]) -> set[str]:
|
||||
return {doc.id for doc in doc_batch}
|
||||
|
||||
|
||||
def extract_ids_from_runnable_connector(
|
||||
runnable_connector: BaseConnector,
|
||||
progress_callback: Callable[[int], None] | None = None,
|
||||
) -> set[str]:
|
||||
"""
|
||||
If the PruneConnector hasnt been implemented for the given connector, just pull
|
||||
all docs using the load_from_state and grab out the IDs.
|
||||
|
||||
Optionally, a callback can be passed to handle the length of each document batch.
|
||||
"""
|
||||
all_connector_doc_ids: set[str] = set()
|
||||
|
||||
doc_batch_generator = None
|
||||
if isinstance(runnable_connector, IdConnector):
|
||||
all_connector_doc_ids = runnable_connector.retrieve_all_source_ids()
|
||||
elif isinstance(runnable_connector, LoadConnector):
|
||||
doc_batch_generator = runnable_connector.load_from_state()
|
||||
elif isinstance(runnable_connector, PollConnector):
|
||||
start = datetime(1970, 1, 1, tzinfo=timezone.utc).timestamp()
|
||||
end = datetime.now(timezone.utc).timestamp()
|
||||
doc_batch_generator = runnable_connector.poll_source(start=start, end=end)
|
||||
else:
|
||||
raise RuntimeError("Pruning job could not find a valid runnable_connector.")
|
||||
|
||||
if doc_batch_generator:
|
||||
doc_batch_processing_func = document_batch_to_ids
|
||||
if MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE:
|
||||
doc_batch_processing_func = rate_limit_builder(
|
||||
max_calls=MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE, period=60
|
||||
)(document_batch_to_ids)
|
||||
for doc_batch in doc_batch_generator:
|
||||
if progress_callback:
|
||||
progress_callback(len(doc_batch))
|
||||
all_connector_doc_ids.update(doc_batch_processing_func(doc_batch))
|
||||
|
||||
return all_connector_doc_ids
|
||||
|
||||
|
||||
def celery_is_listening_to_queue(worker: Any, name: str) -> bool:
|
||||
"""Checks to see if we're listening to the named queue"""
|
||||
|
||||
# how to get a list of queues this worker is listening to
|
||||
# https://stackoverflow.com/questions/29790523/how-to-determine-which-queues-a-celery-worker-is-consuming-at-runtime
|
||||
queue_names = list(worker.app.amqp.queues.consume_from.keys())
|
||||
for queue_name in queue_names:
|
||||
if queue_name == name:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def celery_is_worker_primary(worker: Any) -> bool:
|
||||
"""There are multiple approaches that could be taken to determine if a celery worker
|
||||
is 'primary', as defined by us. But the way we do it is to check the hostname set
|
||||
for the celery worker, which can be done either in celeryconfig.py or on the
|
||||
command line with '--hostname'."""
|
||||
hostname = worker.hostname
|
||||
if hostname.startswith("light"):
|
||||
return False
|
||||
|
||||
if hostname.startswith("heavy"):
|
||||
return False
|
||||
|
||||
return True
|
||||
@@ -1,113 +0,0 @@
|
||||
import redis
|
||||
from celery import shared_task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from redis import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm.exc import ObjectDeletedError
|
||||
|
||||
from danswer.background.celery.celery_app import celery_app
|
||||
from danswer.background.celery.celery_app import task_logger
|
||||
from danswer.background.celery.celery_redis import RedisConnectorDeletion
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
|
||||
|
||||
@shared_task(
|
||||
name="check_for_connector_deletion_task",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
trail=False,
|
||||
)
|
||||
def check_for_connector_deletion_task(tenant_id: str | None) -> None:
|
||||
r = get_redis_client()
|
||||
|
||||
lock_beat = r.lock(
|
||||
DanswerRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
try:
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return
|
||||
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
cc_pairs = get_connector_credential_pairs(db_session)
|
||||
for cc_pair in cc_pairs:
|
||||
try_generate_document_cc_pair_cleanup_tasks(
|
||||
cc_pair, db_session, r, lock_beat, tenant_id
|
||||
)
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception("Unexpected exception")
|
||||
finally:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
|
||||
|
||||
def try_generate_document_cc_pair_cleanup_tasks(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
lock_beat: redis.lock.Lock,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
|
||||
Note that syncing can still be required even if the number of sync tasks generated is zero.
|
||||
Returns None if no syncing is required.
|
||||
"""
|
||||
|
||||
lock_beat.reacquire()
|
||||
|
||||
rcd = RedisConnectorDeletion(cc_pair.id)
|
||||
|
||||
# don't generate sync tasks if tasks are still pending
|
||||
if r.exists(rcd.fence_key):
|
||||
return None
|
||||
|
||||
# we need to refresh the state of the object inside the fence
|
||||
# to avoid a race condition with db.commit/fence deletion
|
||||
# at the end of this taskset
|
||||
try:
|
||||
db_session.refresh(cc_pair)
|
||||
except ObjectDeletedError:
|
||||
return None
|
||||
|
||||
if cc_pair.status != ConnectorCredentialPairStatus.DELETING:
|
||||
return None
|
||||
|
||||
# add tasks to celery and build up the task set to monitor in redis
|
||||
r.delete(rcd.taskset_key)
|
||||
|
||||
# Add all documents that need to be updated into the queue
|
||||
task_logger.info(
|
||||
f"RedisConnectorDeletion.generate_tasks starting. cc_pair_id={cc_pair.id}"
|
||||
)
|
||||
tasks_generated = rcd.generate_tasks(
|
||||
celery_app, db_session, r, lock_beat, tenant_id
|
||||
)
|
||||
if tasks_generated is None:
|
||||
return None
|
||||
|
||||
# Currently we are allowing the sync to proceed with 0 tasks.
|
||||
# It's possible for sets/groups to be generated initially with no entries
|
||||
# and they still need to be marked as up to date.
|
||||
# if tasks_generated == 0:
|
||||
# return 0
|
||||
|
||||
task_logger.info(
|
||||
f"RedisConnectorDeletion.generate_tasks finished. "
|
||||
f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}"
|
||||
)
|
||||
|
||||
# set this only after all tasks have been added
|
||||
r.set(rcd.fence_key, tasks_generated)
|
||||
return tasks_generated
|
||||
@@ -1,251 +0,0 @@
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from uuid import uuid4
|
||||
|
||||
import redis
|
||||
from celery import shared_task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from redis import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.celery_app import celery_app
|
||||
from danswer.background.celery.celery_app import task_logger
|
||||
from danswer.background.celery.celery_redis import RedisConnectorPruning
|
||||
from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector
|
||||
from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.connectors.factory import instantiate_connector
|
||||
from danswer.connectors.models import InputType
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from danswer.db.document import get_documents_for_connector_credential_pair
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@shared_task(
|
||||
name="check_for_prune_task_2",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
)
|
||||
def check_for_prune_task_2(tenant_id: str | None) -> None:
|
||||
r = get_redis_client()
|
||||
|
||||
lock_beat = r.lock(
|
||||
DanswerRedisLocks.CHECK_PRUNE_BEAT_LOCK,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
try:
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pairs = get_connector_credential_pairs(db_session)
|
||||
for cc_pair in cc_pairs:
|
||||
tasks_created = ccpair_pruning_generator_task_creation_helper(
|
||||
cc_pair, db_session, tenant_id, r, lock_beat
|
||||
)
|
||||
if not tasks_created:
|
||||
continue
|
||||
|
||||
task_logger.info(f"Pruning started: cc_pair_id={cc_pair.id}")
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception("Unexpected exception")
|
||||
finally:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
|
||||
|
||||
def ccpair_pruning_generator_task_creation_helper(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
db_session: Session,
|
||||
tenant_id: str | None,
|
||||
r: Redis,
|
||||
lock_beat: redis.lock.Lock,
|
||||
) -> int | None:
|
||||
"""Returns an int if pruning is triggered.
|
||||
The int represents the number of prune tasks generated (in this case, only one
|
||||
because the task is a long running generator task.)
|
||||
Returns None if no pruning is triggered (due to not being needed or
|
||||
other reasons such as simultaneous pruning restrictions.
|
||||
|
||||
Checks for scheduling related conditions, then delegates the rest of the checks to
|
||||
try_creating_prune_generator_task.
|
||||
"""
|
||||
|
||||
lock_beat.reacquire()
|
||||
|
||||
# skip pruning if no prune frequency is set
|
||||
# pruning can still be forced via the API which will run a pruning task directly
|
||||
if not cc_pair.connector.prune_freq:
|
||||
return None
|
||||
|
||||
# skip pruning if the next scheduled prune time hasn't been reached yet
|
||||
last_pruned = cc_pair.last_pruned
|
||||
if not last_pruned:
|
||||
# if never pruned, use the connector time created as the last_pruned time
|
||||
last_pruned = cc_pair.connector.time_created
|
||||
|
||||
next_prune = last_pruned + timedelta(seconds=cc_pair.connector.prune_freq)
|
||||
if datetime.now(timezone.utc) < next_prune:
|
||||
return None
|
||||
|
||||
return try_creating_prune_generator_task(cc_pair, db_session, r, tenant_id)
|
||||
|
||||
|
||||
def try_creating_prune_generator_task(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
"""Checks for any conditions that should block the pruning generator task from being
|
||||
created, then creates the task.
|
||||
|
||||
Does not check for scheduling related conditions as this function
|
||||
is used to trigger prunes immediately.
|
||||
"""
|
||||
|
||||
if not ALLOW_SIMULTANEOUS_PRUNING:
|
||||
for key in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
|
||||
return None
|
||||
|
||||
rcp = RedisConnectorPruning(cc_pair.id)
|
||||
|
||||
# skip pruning if already pruning
|
||||
if r.exists(rcp.fence_key):
|
||||
return None
|
||||
|
||||
# skip pruning if the cc_pair is deleting
|
||||
db_session.refresh(cc_pair)
|
||||
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
|
||||
return None
|
||||
|
||||
# add a long running generator task to the queue
|
||||
r.delete(rcp.generator_complete_key)
|
||||
r.delete(rcp.taskset_key)
|
||||
|
||||
custom_task_id = f"{rcp.generator_task_id_prefix}_{uuid4()}"
|
||||
|
||||
celery_app.send_task(
|
||||
"connector_pruning_generator_task",
|
||||
kwargs=dict(
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
tenant_id=tenant_id,
|
||||
),
|
||||
queue=DanswerCeleryQueues.CONNECTOR_PRUNING,
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.LOW,
|
||||
)
|
||||
|
||||
# set this only after all tasks have been added
|
||||
r.set(rcp.fence_key, 1)
|
||||
return 1
|
||||
|
||||
|
||||
@shared_task(name="connector_pruning_generator_task", soft_time_limit=JOB_TIMEOUT)
|
||||
def connector_pruning_generator_task(
|
||||
connector_id: int, credential_id: int, tenant_id: str | None
|
||||
) -> None:
|
||||
"""connector pruning task. For a cc pair, this task pulls all document IDs from the source
|
||||
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
|
||||
from the most recently pulled document ID list"""
|
||||
|
||||
r = get_redis_client()
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
try:
|
||||
cc_pair = get_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
|
||||
if not cc_pair:
|
||||
task_logger.warning(
|
||||
f"ccpair not found for {connector_id} {credential_id}"
|
||||
)
|
||||
return
|
||||
|
||||
rcp = RedisConnectorPruning(cc_pair.id)
|
||||
|
||||
# Define the callback function
|
||||
def redis_increment_callback(amount: int) -> None:
|
||||
r.incrby(rcp.generator_progress_key, amount)
|
||||
|
||||
runnable_connector = instantiate_connector(
|
||||
db_session,
|
||||
cc_pair.connector.source,
|
||||
InputType.PRUNE,
|
||||
cc_pair.connector.connector_specific_config,
|
||||
cc_pair.credential,
|
||||
)
|
||||
|
||||
# a list of docs in the source
|
||||
all_connector_doc_ids: set[str] = extract_ids_from_runnable_connector(
|
||||
runnable_connector, redis_increment_callback
|
||||
)
|
||||
|
||||
# a list of docs in our local index
|
||||
all_indexed_document_ids = {
|
||||
doc.id
|
||||
for doc in get_documents_for_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
}
|
||||
|
||||
# generate list of docs to remove (no longer in the source)
|
||||
doc_ids_to_remove = list(all_indexed_document_ids - all_connector_doc_ids)
|
||||
|
||||
task_logger.info(
|
||||
f"Pruning set collected: "
|
||||
f"cc_pair_id={cc_pair.id} "
|
||||
f"docs_to_remove={len(doc_ids_to_remove)} "
|
||||
f"doc_source={cc_pair.connector.source}"
|
||||
)
|
||||
|
||||
rcp.documents_to_prune = set(doc_ids_to_remove)
|
||||
|
||||
task_logger.info(
|
||||
f"RedisConnectorPruning.generate_tasks starting. cc_pair_id={cc_pair.id}"
|
||||
)
|
||||
tasks_generated = rcp.generate_tasks(
|
||||
celery_app, db_session, r, None, tenant_id
|
||||
)
|
||||
if tasks_generated is None:
|
||||
return None
|
||||
|
||||
task_logger.info(
|
||||
f"RedisConnectorPruning.generate_tasks finished. "
|
||||
f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}"
|
||||
)
|
||||
|
||||
r.set(rcp.generator_complete_key, tasks_generated)
|
||||
except Exception as e:
|
||||
task_logger.exception(
|
||||
f"Failed to run pruning for connector id {connector_id}."
|
||||
)
|
||||
|
||||
r.delete(rcp.generator_progress_key)
|
||||
r.delete(rcp.taskset_key)
|
||||
r.delete(rcp.fence_key)
|
||||
raise e
|
||||
@@ -1,126 +0,0 @@
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
|
||||
from danswer.access.access import get_access_for_document
|
||||
from danswer.background.celery.celery_app import task_logger
|
||||
from danswer.db.document import delete_document_by_connector_credential_pair__no_commit
|
||||
from danswer.db.document import delete_documents_complete__no_commit
|
||||
from danswer.db.document import get_document
|
||||
from danswer.db.document import get_document_connector_count
|
||||
from danswer.db.document import mark_document_as_synced
|
||||
from danswer.db.document_set import fetch_document_sets_for_document
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.document_index.document_index_utils import get_both_index_names
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.document_index.interfaces import VespaDocumentFields
|
||||
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
|
||||
|
||||
|
||||
@shared_task(
|
||||
name="document_by_cc_pair_cleanup_task",
|
||||
bind=True,
|
||||
soft_time_limit=45,
|
||||
time_limit=60,
|
||||
max_retries=3,
|
||||
)
|
||||
def document_by_cc_pair_cleanup_task(
|
||||
self: Task,
|
||||
document_id: str,
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
tenant_id: str | None,
|
||||
) -> bool:
|
||||
"""A lightweight subtask used to clean up document to cc pair relationships.
|
||||
Created by connection deletion and connector pruning parent tasks."""
|
||||
|
||||
"""
|
||||
To delete a connector / credential pair:
|
||||
(1) find all documents associated with connector / credential pair where there
|
||||
this the is only connector / credential pair that has indexed it
|
||||
(2) delete all documents from document stores
|
||||
(3) delete all entries from postgres
|
||||
(4) find all documents associated with connector / credential pair where there
|
||||
are multiple connector / credential pairs that have indexed it
|
||||
(5) update document store entries to remove access associated with the
|
||||
connector / credential pair from the access list
|
||||
(6) delete all relevant entries from postgres
|
||||
"""
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
action = "skip"
|
||||
chunks_affected = 0
|
||||
|
||||
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
|
||||
document_index = get_default_document_index(
|
||||
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
|
||||
)
|
||||
|
||||
count = get_document_connector_count(db_session, document_id)
|
||||
if count == 1:
|
||||
# count == 1 means this is the only remaining cc_pair reference to the doc
|
||||
# delete it from vespa and the db
|
||||
action = "delete"
|
||||
|
||||
chunks_affected = document_index.delete_single(document_id)
|
||||
delete_documents_complete__no_commit(
|
||||
db_session=db_session,
|
||||
document_ids=[document_id],
|
||||
)
|
||||
elif count > 1:
|
||||
action = "update"
|
||||
|
||||
# count > 1 means the document still has cc_pair references
|
||||
doc = get_document(document_id, db_session)
|
||||
if not doc:
|
||||
return False
|
||||
|
||||
# the below functions do not include cc_pairs being deleted.
|
||||
# i.e. they will correctly omit access for the current cc_pair
|
||||
doc_access = get_access_for_document(
|
||||
document_id=document_id, db_session=db_session
|
||||
)
|
||||
|
||||
doc_sets = fetch_document_sets_for_document(document_id, db_session)
|
||||
update_doc_sets: set[str] = set(doc_sets)
|
||||
|
||||
fields = VespaDocumentFields(
|
||||
document_sets=update_doc_sets,
|
||||
access=doc_access,
|
||||
boost=doc.boost,
|
||||
hidden=doc.hidden,
|
||||
)
|
||||
|
||||
# update Vespa. OK if doc doesn't exist. Raises exception otherwise.
|
||||
chunks_affected = document_index.update_single(
|
||||
document_id, fields=fields
|
||||
)
|
||||
|
||||
# there are still other cc_pair references to the doc, so just resync to Vespa
|
||||
delete_document_by_connector_credential_pair__no_commit(
|
||||
db_session=db_session,
|
||||
document_id=document_id,
|
||||
connector_credential_pair_identifier=ConnectorCredentialPairIdentifier(
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
),
|
||||
)
|
||||
|
||||
mark_document_as_synced(document_id, db_session)
|
||||
else:
|
||||
pass
|
||||
|
||||
task_logger.info(
|
||||
f"document_id={document_id} refcount={count} action={action} chunks={chunks_affected}"
|
||||
)
|
||||
db_session.commit()
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(f"SoftTimeLimitExceeded exception. doc_id={document_id}")
|
||||
except Exception as e:
|
||||
task_logger.exception("Unexpected exception")
|
||||
|
||||
# Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
|
||||
countdown = 2 ** (self.request.retries + 4)
|
||||
self.retry(exc=e, countdown=countdown)
|
||||
|
||||
return True
|
||||
@@ -1,621 +0,0 @@
|
||||
import traceback
|
||||
from typing import cast
|
||||
|
||||
import redis
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from redis import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.access.access import get_access_for_document
|
||||
from danswer.background.celery.celery_app import celery_app
|
||||
from danswer.background.celery.celery_app import task_logger
|
||||
from danswer.background.celery.celery_redis import celery_get_queue_length
|
||||
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
|
||||
from danswer.background.celery.celery_redis import RedisConnectorDeletion
|
||||
from danswer.background.celery.celery_redis import RedisConnectorPruning
|
||||
from danswer.background.celery.celery_redis import RedisDocumentSet
|
||||
from danswer.background.celery.celery_redis import RedisUserGroup
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.db.connector import fetch_connector_by_id
|
||||
from danswer.db.connector import mark_ccpair_as_pruned
|
||||
from danswer.db.connector_credential_pair import add_deletion_failure_message
|
||||
from danswer.db.connector_credential_pair import (
|
||||
delete_connector_credential_pair__no_commit,
|
||||
)
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from danswer.db.document import count_documents_by_needs_sync
|
||||
from danswer.db.document import get_document
|
||||
from danswer.db.document import mark_document_as_synced
|
||||
from danswer.db.document_set import delete_document_set
|
||||
from danswer.db.document_set import delete_document_set_cc_pair_relationship__no_commit
|
||||
from danswer.db.document_set import fetch_document_sets
|
||||
from danswer.db.document_set import fetch_document_sets_for_document
|
||||
from danswer.db.document_set import get_document_set_by_id
|
||||
from danswer.db.document_set import mark_document_set_as_synced
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.index_attempt import delete_index_attempts
|
||||
from danswer.db.models import DocumentSet
|
||||
from danswer.db.models import UserGroup
|
||||
from danswer.document_index.document_index_utils import get_both_index_names
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.document_index.interfaces import UpdateRequest
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
from danswer.utils.variable_functionality import (
|
||||
fetch_versioned_implementation_with_fallback,
|
||||
)
|
||||
from danswer.utils.variable_functionality import global_version
|
||||
from danswer.utils.variable_functionality import noop_fallback
|
||||
|
||||
|
||||
# celery auto associates tasks created inside another task,
|
||||
# which bloats the result metadata considerably. trail=False prevents this.
|
||||
@shared_task(
|
||||
name="check_for_vespa_sync_task",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
trail=False,
|
||||
)
|
||||
def check_for_vespa_sync_task(tenant_id: str | None) -> None:
|
||||
"""Runs periodically to check if any document needs syncing.
|
||||
Generates sets of tasks for Celery if syncing is needed."""
|
||||
|
||||
r = get_redis_client()
|
||||
|
||||
lock_beat = r.lock(
|
||||
DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
try:
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
try_generate_stale_document_sync_tasks(db_session, r, lock_beat, tenant_id)
|
||||
|
||||
# check if any document sets are not synced
|
||||
document_set_info = fetch_document_sets(
|
||||
user_id=None, db_session=db_session, include_outdated=True
|
||||
)
|
||||
for document_set, _ in document_set_info:
|
||||
try_generate_document_set_sync_tasks(
|
||||
document_set, db_session, r, lock_beat, tenant_id
|
||||
)
|
||||
|
||||
# check if any user groups are not synced
|
||||
if global_version.is_ee_version():
|
||||
try:
|
||||
fetch_user_groups = fetch_versioned_implementation(
|
||||
"danswer.db.user_group", "fetch_user_groups"
|
||||
)
|
||||
|
||||
user_groups = fetch_user_groups(
|
||||
db_session=db_session, only_up_to_date=False
|
||||
)
|
||||
for usergroup in user_groups:
|
||||
try_generate_user_group_sync_tasks(
|
||||
usergroup, db_session, r, lock_beat, tenant_id
|
||||
)
|
||||
except ModuleNotFoundError:
|
||||
# Always exceptions on the MIT version, which is expected
|
||||
# We shouldn't actually get here if the ee version check works
|
||||
pass
|
||||
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception("Unexpected exception")
|
||||
finally:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
|
||||
|
||||
def try_generate_stale_document_sync_tasks(
|
||||
db_session: Session, r: Redis, lock_beat: redis.lock.Lock, tenant_id: str | None
|
||||
) -> int | None:
|
||||
# the fence is up, do nothing
|
||||
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
|
||||
return None
|
||||
|
||||
r.delete(RedisConnectorCredentialPair.get_taskset_key()) # delete the taskset
|
||||
|
||||
# add tasks to celery and build up the task set to monitor in redis
|
||||
stale_doc_count = count_documents_by_needs_sync(db_session)
|
||||
if stale_doc_count == 0:
|
||||
return None
|
||||
|
||||
task_logger.info(
|
||||
f"Stale documents found (at least {stale_doc_count}). Generating sync tasks by cc pair."
|
||||
)
|
||||
|
||||
task_logger.info("RedisConnector.generate_tasks starting by cc_pair.")
|
||||
|
||||
# rkuo: we could technically sync all stale docs in one big pass.
|
||||
# but I feel it's more understandable to group the docs by cc_pair
|
||||
total_tasks_generated = 0
|
||||
cc_pairs = get_connector_credential_pairs(db_session)
|
||||
for cc_pair in cc_pairs:
|
||||
rc = RedisConnectorCredentialPair(cc_pair.id)
|
||||
tasks_generated = rc.generate_tasks(
|
||||
celery_app, db_session, r, lock_beat, tenant_id
|
||||
)
|
||||
|
||||
if tasks_generated is None:
|
||||
continue
|
||||
|
||||
if tasks_generated == 0:
|
||||
continue
|
||||
|
||||
task_logger.info(
|
||||
f"RedisConnector.generate_tasks finished for single cc_pair. "
|
||||
f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}"
|
||||
)
|
||||
|
||||
total_tasks_generated += tasks_generated
|
||||
|
||||
task_logger.info(
|
||||
f"RedisConnector.generate_tasks finished for all cc_pairs. total_tasks_generated={total_tasks_generated}"
|
||||
)
|
||||
|
||||
r.set(RedisConnectorCredentialPair.get_fence_key(), total_tasks_generated)
|
||||
return total_tasks_generated
|
||||
|
||||
|
||||
def try_generate_document_set_sync_tasks(
|
||||
document_set: DocumentSet,
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
lock_beat: redis.lock.Lock,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
lock_beat.reacquire()
|
||||
|
||||
rds = RedisDocumentSet(document_set.id)
|
||||
|
||||
# don't generate document set sync tasks if tasks are still pending
|
||||
if r.exists(rds.fence_key):
|
||||
return None
|
||||
|
||||
# don't generate sync tasks if we're up to date
|
||||
# race condition with the monitor/cleanup function if we use a cached result!
|
||||
db_session.refresh(document_set)
|
||||
if document_set.is_up_to_date:
|
||||
return None
|
||||
|
||||
# add tasks to celery and build up the task set to monitor in redis
|
||||
r.delete(rds.taskset_key)
|
||||
|
||||
task_logger.info(
|
||||
f"RedisDocumentSet.generate_tasks starting. document_set_id={document_set.id}"
|
||||
)
|
||||
|
||||
# Add all documents that need to be updated into the queue
|
||||
tasks_generated = rds.generate_tasks(
|
||||
celery_app, db_session, r, lock_beat, tenant_id
|
||||
)
|
||||
if tasks_generated is None:
|
||||
return None
|
||||
|
||||
# Currently we are allowing the sync to proceed with 0 tasks.
|
||||
# It's possible for sets/groups to be generated initially with no entries
|
||||
# and they still need to be marked as up to date.
|
||||
# if tasks_generated == 0:
|
||||
# return 0
|
||||
|
||||
task_logger.info(
|
||||
f"RedisDocumentSet.generate_tasks finished. "
|
||||
f"document_set_id={document_set.id} tasks_generated={tasks_generated}"
|
||||
)
|
||||
|
||||
# set this only after all tasks have been added
|
||||
r.set(rds.fence_key, tasks_generated)
|
||||
return tasks_generated
|
||||
|
||||
|
||||
def try_generate_user_group_sync_tasks(
|
||||
usergroup: UserGroup,
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
lock_beat: redis.lock.Lock,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
lock_beat.reacquire()
|
||||
|
||||
rug = RedisUserGroup(usergroup.id)
|
||||
|
||||
# don't generate sync tasks if tasks are still pending
|
||||
if r.exists(rug.fence_key):
|
||||
return None
|
||||
|
||||
# race condition with the monitor/cleanup function if we use a cached result!
|
||||
db_session.refresh(usergroup)
|
||||
if usergroup.is_up_to_date:
|
||||
return None
|
||||
|
||||
# add tasks to celery and build up the task set to monitor in redis
|
||||
r.delete(rug.taskset_key)
|
||||
|
||||
# Add all documents that need to be updated into the queue
|
||||
task_logger.info(
|
||||
f"RedisUserGroup.generate_tasks starting. usergroup_id={usergroup.id}"
|
||||
)
|
||||
tasks_generated = rug.generate_tasks(
|
||||
celery_app, db_session, r, lock_beat, tenant_id
|
||||
)
|
||||
if tasks_generated is None:
|
||||
return None
|
||||
|
||||
# Currently we are allowing the sync to proceed with 0 tasks.
|
||||
# It's possible for sets/groups to be generated initially with no entries
|
||||
# and they still need to be marked as up to date.
|
||||
# if tasks_generated == 0:
|
||||
# return 0
|
||||
|
||||
task_logger.info(
|
||||
f"RedisUserGroup.generate_tasks finished. "
|
||||
f"usergroup_id={usergroup.id} tasks_generated={tasks_generated}"
|
||||
)
|
||||
|
||||
# set this only after all tasks have been added
|
||||
r.set(rug.fence_key, tasks_generated)
|
||||
return tasks_generated
|
||||
|
||||
|
||||
def monitor_connector_taskset(r: Redis) -> None:
|
||||
fence_value = r.get(RedisConnectorCredentialPair.get_fence_key())
|
||||
if fence_value is None:
|
||||
return
|
||||
|
||||
try:
|
||||
initial_count = int(cast(int, fence_value))
|
||||
except ValueError:
|
||||
task_logger.error("The value is not an integer.")
|
||||
return
|
||||
|
||||
count = r.scard(RedisConnectorCredentialPair.get_taskset_key())
|
||||
task_logger.info(
|
||||
f"Stale document sync progress: remaining={count} initial={initial_count}"
|
||||
)
|
||||
if count == 0:
|
||||
r.delete(RedisConnectorCredentialPair.get_taskset_key())
|
||||
r.delete(RedisConnectorCredentialPair.get_fence_key())
|
||||
task_logger.info(f"Successfully synced stale documents. count={initial_count}")
|
||||
|
||||
|
||||
def monitor_document_set_taskset(
|
||||
key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
document_set_id = RedisDocumentSet.get_id_from_fence_key(fence_key)
|
||||
if document_set_id is None:
|
||||
task_logger.warning(f"could not parse document set id from {fence_key}")
|
||||
return
|
||||
|
||||
rds = RedisDocumentSet(document_set_id)
|
||||
|
||||
fence_value = r.get(rds.fence_key)
|
||||
if fence_value is None:
|
||||
return
|
||||
|
||||
try:
|
||||
initial_count = int(cast(int, fence_value))
|
||||
except ValueError:
|
||||
task_logger.error("The value is not an integer.")
|
||||
return
|
||||
|
||||
count = cast(int, r.scard(rds.taskset_key))
|
||||
task_logger.info(
|
||||
f"Document set sync progress: document_set_id={document_set_id} remaining={count} initial={initial_count}"
|
||||
)
|
||||
if count > 0:
|
||||
return
|
||||
|
||||
document_set = cast(
|
||||
DocumentSet,
|
||||
get_document_set_by_id(db_session=db_session, document_set_id=document_set_id),
|
||||
) # casting since we "know" a document set with this ID exists
|
||||
if document_set:
|
||||
if not document_set.connector_credential_pairs:
|
||||
# if there are no connectors, then delete the document set.
|
||||
delete_document_set(document_set_row=document_set, db_session=db_session)
|
||||
task_logger.info(
|
||||
f"Successfully deleted document set with ID: '{document_set_id}'!"
|
||||
)
|
||||
else:
|
||||
mark_document_set_as_synced(document_set_id, db_session)
|
||||
task_logger.info(
|
||||
f"Successfully synced document set with ID: '{document_set_id}'!"
|
||||
)
|
||||
|
||||
r.delete(rds.taskset_key)
|
||||
r.delete(rds.fence_key)
|
||||
|
||||
|
||||
def monitor_connector_deletion_taskset(key_bytes: bytes, r: Redis) -> None:
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
cc_pair_id = RedisConnectorDeletion.get_id_from_fence_key(fence_key)
|
||||
if cc_pair_id is None:
|
||||
task_logger.warning(f"could not parse cc_pair_id from {fence_key}")
|
||||
return
|
||||
|
||||
rcd = RedisConnectorDeletion(cc_pair_id)
|
||||
|
||||
fence_value = r.get(rcd.fence_key)
|
||||
if fence_value is None:
|
||||
return
|
||||
|
||||
try:
|
||||
initial_count = int(cast(int, fence_value))
|
||||
except ValueError:
|
||||
task_logger.error("The value is not an integer.")
|
||||
return
|
||||
|
||||
count = cast(int, r.scard(rcd.taskset_key))
|
||||
task_logger.info(
|
||||
f"Connector deletion progress: cc_pair_id={cc_pair_id} remaining={count} initial={initial_count}"
|
||||
)
|
||||
if count > 0:
|
||||
return
|
||||
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
|
||||
if not cc_pair:
|
||||
task_logger.warning(
|
||||
f"monitor_connector_deletion_taskset - cc_pair_id not found: cc_pair_id={cc_pair_id}"
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
# clean up the rest of the related Postgres entities
|
||||
# index attempts
|
||||
delete_index_attempts(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair.id,
|
||||
)
|
||||
|
||||
# document sets
|
||||
delete_document_set_cc_pair_relationship__no_commit(
|
||||
db_session=db_session,
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
)
|
||||
|
||||
# user groups
|
||||
cleanup_user_groups = fetch_versioned_implementation_with_fallback(
|
||||
"danswer.db.user_group",
|
||||
"delete_user_group_cc_pair_relationship__no_commit",
|
||||
noop_fallback,
|
||||
)
|
||||
cleanup_user_groups(
|
||||
cc_pair_id=cc_pair.id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# finally, delete the cc-pair
|
||||
delete_connector_credential_pair__no_commit(
|
||||
db_session=db_session,
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
)
|
||||
# if there are no credentials left, delete the connector
|
||||
connector = fetch_connector_by_id(
|
||||
db_session=db_session,
|
||||
connector_id=cc_pair.connector_id,
|
||||
)
|
||||
if not connector or not len(connector.credentials):
|
||||
task_logger.info(
|
||||
"Found no credentials left for connector, deleting connector"
|
||||
)
|
||||
db_session.delete(connector)
|
||||
db_session.commit()
|
||||
except Exception as e:
|
||||
stack_trace = traceback.format_exc()
|
||||
error_message = f"Error: {str(e)}\n\nStack Trace:\n{stack_trace}"
|
||||
add_deletion_failure_message(db_session, cc_pair.id, error_message)
|
||||
task_logger.exception(
|
||||
f"Failed to run connector_deletion. "
|
||||
f"cc_pair_id={cc_pair_id} connector_id={cc_pair.connector_id} credential_id={cc_pair.credential_id}"
|
||||
)
|
||||
raise e
|
||||
|
||||
task_logger.info(
|
||||
f"Successfully deleted cc_pair: "
|
||||
f"cc_pair_id={cc_pair_id} "
|
||||
f"connector_id={cc_pair.connector_id} "
|
||||
f"credential_id={cc_pair.credential_id} "
|
||||
f"docs_deleted={initial_count}"
|
||||
)
|
||||
|
||||
r.delete(rcd.taskset_key)
|
||||
r.delete(rcd.fence_key)
|
||||
|
||||
|
||||
def monitor_ccpair_pruning_taskset(
|
||||
key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
cc_pair_id = RedisConnectorPruning.get_id_from_fence_key(fence_key)
|
||||
if cc_pair_id is None:
|
||||
task_logger.warning(
|
||||
f"monitor_connector_pruning_taskset: could not parse cc_pair_id from {fence_key}"
|
||||
)
|
||||
return
|
||||
|
||||
rcp = RedisConnectorPruning(cc_pair_id)
|
||||
|
||||
fence_value = r.get(rcp.fence_key)
|
||||
if fence_value is None:
|
||||
return
|
||||
|
||||
generator_value = r.get(rcp.generator_complete_key)
|
||||
if generator_value is None:
|
||||
return
|
||||
|
||||
try:
|
||||
initial_count = int(cast(int, generator_value))
|
||||
except ValueError:
|
||||
task_logger.error("The value is not an integer.")
|
||||
return
|
||||
|
||||
count = cast(int, r.scard(rcp.taskset_key))
|
||||
task_logger.info(
|
||||
f"Connector pruning progress: cc_pair_id={cc_pair_id} remaining={count} initial={initial_count}"
|
||||
)
|
||||
if count > 0:
|
||||
return
|
||||
|
||||
mark_ccpair_as_pruned(cc_pair_id, db_session)
|
||||
task_logger.info(
|
||||
f"Successfully pruned connector credential pair. cc_pair_id={cc_pair_id}"
|
||||
)
|
||||
|
||||
r.delete(rcp.taskset_key)
|
||||
r.delete(rcp.generator_progress_key)
|
||||
r.delete(rcp.generator_complete_key)
|
||||
r.delete(rcp.fence_key)
|
||||
|
||||
|
||||
@shared_task(name="monitor_vespa_sync", soft_time_limit=300, bind=True)
|
||||
def monitor_vespa_sync(self: Task, tenant_id: str | None) -> None:
|
||||
"""This is a celery beat task that monitors and finalizes metadata sync tasksets.
|
||||
It scans for fence values and then gets the counts of any associated tasksets.
|
||||
If the count is 0, that means all tasks finished and we should clean up.
|
||||
|
||||
This task lock timeout is CELERY_METADATA_SYNC_BEAT_LOCK_TIMEOUT seconds, so don't
|
||||
do anything too expensive in this function!
|
||||
"""
|
||||
r = get_redis_client()
|
||||
|
||||
lock_beat: redis.lock.Lock = r.lock(
|
||||
DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
try:
|
||||
# prevent overlapping tasks
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return
|
||||
|
||||
# print current queue lengths
|
||||
r_celery = self.app.broker_connection().channel().client # type: ignore
|
||||
n_celery = celery_get_queue_length("celery", r)
|
||||
n_sync = celery_get_queue_length(
|
||||
DanswerCeleryQueues.VESPA_METADATA_SYNC, r_celery
|
||||
)
|
||||
n_deletion = celery_get_queue_length(
|
||||
DanswerCeleryQueues.CONNECTOR_DELETION, r_celery
|
||||
)
|
||||
n_pruning = celery_get_queue_length(
|
||||
DanswerCeleryQueues.CONNECTOR_PRUNING, r_celery
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"Queue lengths: celery={n_celery} sync={n_sync} deletion={n_deletion} pruning={n_pruning}"
|
||||
)
|
||||
|
||||
lock_beat.reacquire()
|
||||
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
|
||||
monitor_connector_taskset(r)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
|
||||
monitor_connector_deletion_taskset(key_bytes, r)
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
|
||||
monitor_document_set_taskset(key_bytes, r, db_session)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
|
||||
monitor_usergroup_taskset = (
|
||||
fetch_versioned_implementation_with_fallback(
|
||||
"danswer.background.celery.tasks.vespa.tasks",
|
||||
"monitor_usergroup_taskset",
|
||||
noop_fallback,
|
||||
)
|
||||
)
|
||||
monitor_usergroup_taskset(key_bytes, r, db_session)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
|
||||
monitor_ccpair_pruning_taskset(key_bytes, r, db_session)
|
||||
|
||||
# uncomment for debugging if needed
|
||||
# r_celery = celery_app.broker_connection().channel().client
|
||||
# length = celery_get_queue_length(DanswerCeleryQueues.VESPA_METADATA_SYNC, r_celery)
|
||||
# task_logger.warning(f"queue={DanswerCeleryQueues.VESPA_METADATA_SYNC} length={length}")
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
)
|
||||
finally:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
|
||||
|
||||
@shared_task(
|
||||
name="vespa_metadata_sync_task",
|
||||
bind=True,
|
||||
soft_time_limit=45,
|
||||
time_limit=60,
|
||||
max_retries=3,
|
||||
)
|
||||
def vespa_metadata_sync_task(
|
||||
self: Task, document_id: str, tenant_id: str | None
|
||||
) -> bool:
|
||||
task_logger.info(f"document_id={document_id}")
|
||||
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
|
||||
document_index = get_default_document_index(
|
||||
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
|
||||
)
|
||||
|
||||
doc = get_document(document_id, db_session)
|
||||
if not doc:
|
||||
return False
|
||||
|
||||
# document set sync
|
||||
doc_sets = fetch_document_sets_for_document(document_id, db_session)
|
||||
update_doc_sets: set[str] = set(doc_sets)
|
||||
|
||||
# User group sync
|
||||
doc_access = get_access_for_document(
|
||||
document_id=document_id, db_session=db_session
|
||||
)
|
||||
update_request = UpdateRequest(
|
||||
document_ids=[document_id],
|
||||
document_sets=update_doc_sets,
|
||||
access=doc_access,
|
||||
boost=doc.boost,
|
||||
hidden=doc.hidden,
|
||||
)
|
||||
|
||||
# update Vespa
|
||||
document_index.update(update_requests=[update_request])
|
||||
|
||||
# update db last. Worst case = we crash right before this and
|
||||
# the sync might repeat again later
|
||||
mark_document_as_synced(document_id, db_session)
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(f"SoftTimeLimitExceeded exception. doc_id={document_id}")
|
||||
except Exception as e:
|
||||
task_logger.exception("Unexpected exception")
|
||||
|
||||
# Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
|
||||
countdown = 2 ** (self.request.retries + 4)
|
||||
self.retry(exc=e, countdown=countdown)
|
||||
|
||||
return True
|
||||
@@ -1,574 +0,0 @@
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
import dask
|
||||
from dask.distributed import Client
|
||||
from dask.distributed import Future
|
||||
from distributed import LocalCluster
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.exc import ProgrammingError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.indexing.dask_utils import ResourceLogger
|
||||
from danswer.background.indexing.job_client import SimpleJob
|
||||
from danswer.background.indexing.job_client import SimpleJobClient
|
||||
from danswer.background.indexing.run_indexing import run_indexing_entrypoint
|
||||
from danswer.configs.app_configs import CLEANUP_INDEXING_JOBS_TIMEOUT
|
||||
from danswer.configs.app_configs import DASK_JOB_CLIENT_ENABLED
|
||||
from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
|
||||
from danswer.configs.app_configs import MULTI_TENANT
|
||||
from danswer.configs.app_configs import NUM_INDEXING_WORKERS
|
||||
from danswer.configs.app_configs import NUM_SECONDARY_INDEXING_WORKERS
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import POSTGRES_INDEXER_APP_NAME
|
||||
from danswer.configs.constants import TENANT_ID_PREFIX
|
||||
from danswer.db.connector import fetch_connectors
|
||||
from danswer.db.connector_credential_pair import fetch_connector_credential_pairs
|
||||
from danswer.db.engine import get_db_current_time
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.engine import SqlEngine
|
||||
from danswer.db.index_attempt import create_index_attempt
|
||||
from danswer.db.index_attempt import get_index_attempt
|
||||
from danswer.db.index_attempt import get_inprogress_index_attempts
|
||||
from danswer.db.index_attempt import get_last_attempt_for_cc_pair
|
||||
from danswer.db.index_attempt import get_not_started_index_attempts
|
||||
from danswer.db.index_attempt import mark_attempt_failed
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.models import IndexAttempt
|
||||
from danswer.db.models import IndexingStatus
|
||||
from danswer.db.models import IndexModelStatus
|
||||
from danswer.db.models import SearchSettings
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.db.search_settings import get_secondary_search_settings
|
||||
from danswer.db.swap_index import check_index_swap
|
||||
from danswer.document_index.vespa.index import VespaIndex
|
||||
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import global_version
|
||||
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
|
||||
from shared_configs.configs import INDEXING_MODEL_SERVER_PORT
|
||||
from shared_configs.configs import LOG_LEVEL
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# If the indexing dies, it's most likely due to resource constraints,
|
||||
# restarting just delays the eventual failure, not useful to the user
|
||||
dask.config.set({"distributed.scheduler.allowed-failures": 0})
|
||||
|
||||
_UNEXPECTED_STATE_FAILURE_REASON = (
|
||||
"Stopped mid run, likely due to the background process being killed"
|
||||
)
|
||||
|
||||
|
||||
def _should_create_new_indexing(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
last_index: IndexAttempt | None,
|
||||
search_settings_instance: SearchSettings,
|
||||
secondary_index_building: bool,
|
||||
db_session: Session,
|
||||
) -> bool:
|
||||
connector = cc_pair.connector
|
||||
|
||||
# don't kick off indexing for `NOT_APPLICABLE` sources
|
||||
if connector.source == DocumentSource.NOT_APPLICABLE:
|
||||
return False
|
||||
|
||||
# User can still manually create single indexing attempts via the UI for the
|
||||
# currently in use index
|
||||
if DISABLE_INDEX_UPDATE_ON_SWAP:
|
||||
if (
|
||||
search_settings_instance.status == IndexModelStatus.PRESENT
|
||||
and secondary_index_building
|
||||
):
|
||||
return False
|
||||
|
||||
# When switching over models, always index at least once
|
||||
if search_settings_instance.status == IndexModelStatus.FUTURE:
|
||||
if last_index:
|
||||
# No new index if the last index attempt succeeded
|
||||
# Once is enough. The model will never be able to swap otherwise.
|
||||
if last_index.status == IndexingStatus.SUCCESS:
|
||||
return False
|
||||
|
||||
# No new index if the last index attempt is waiting to start
|
||||
if last_index.status == IndexingStatus.NOT_STARTED:
|
||||
return False
|
||||
|
||||
# No new index if the last index attempt is running
|
||||
if last_index.status == IndexingStatus.IN_PROGRESS:
|
||||
return False
|
||||
else:
|
||||
if (
|
||||
connector.id == 0 or connector.source == DocumentSource.INGESTION_API
|
||||
): # Ingestion API
|
||||
return False
|
||||
return True
|
||||
|
||||
# If the connector is paused or is the ingestion API, don't index
|
||||
# NOTE: during an embedding model switch over, the following logic
|
||||
# is bypassed by the above check for a future model
|
||||
if (
|
||||
not cc_pair.status.is_active()
|
||||
or connector.id == 0
|
||||
or connector.source == DocumentSource.INGESTION_API
|
||||
):
|
||||
return False
|
||||
|
||||
if not last_index:
|
||||
return True
|
||||
|
||||
if connector.refresh_freq is None:
|
||||
return False
|
||||
|
||||
# Only one scheduled/ongoing job per connector at a time
|
||||
# this prevents cases where
|
||||
# (1) the "latest" index_attempt is scheduled so we show
|
||||
# that in the UI despite another index_attempt being in-progress
|
||||
# (2) multiple scheduled index_attempts at a time
|
||||
if (
|
||||
last_index.status == IndexingStatus.NOT_STARTED
|
||||
or last_index.status == IndexingStatus.IN_PROGRESS
|
||||
):
|
||||
return False
|
||||
|
||||
current_db_time = get_db_current_time(db_session)
|
||||
time_since_index = current_db_time - last_index.time_updated
|
||||
return time_since_index.total_seconds() >= connector.refresh_freq
|
||||
|
||||
|
||||
def _mark_run_failed(
|
||||
db_session: Session, index_attempt: IndexAttempt, failure_reason: str
|
||||
) -> None:
|
||||
"""Marks the `index_attempt` row as failed + updates the `
|
||||
connector_credential_pair` to reflect that the run failed"""
|
||||
logger.warning(
|
||||
f"Marking in-progress attempt 'connector: {index_attempt.connector_credential_pair.connector_id}, "
|
||||
f"credential: {index_attempt.connector_credential_pair.credential_id}' as failed due to {failure_reason}"
|
||||
)
|
||||
mark_attempt_failed(
|
||||
index_attempt=index_attempt,
|
||||
db_session=db_session,
|
||||
failure_reason=failure_reason,
|
||||
)
|
||||
|
||||
|
||||
"""Main funcs"""
|
||||
|
||||
|
||||
def create_indexing_jobs(
|
||||
existing_jobs: dict[int, Future | SimpleJob], tenant_id: str | None
|
||||
) -> None:
|
||||
"""Creates new indexing jobs for each connector / credential pair which is:
|
||||
1. Enabled
|
||||
2. `refresh_frequency` time has passed since the last indexing run for this pair
|
||||
3. There is not already an ongoing indexing attempt for this pair
|
||||
"""
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
ongoing: set[tuple[int | None, int]] = set()
|
||||
for attempt_id in existing_jobs:
|
||||
attempt = get_index_attempt(
|
||||
db_session=db_session, index_attempt_id=attempt_id
|
||||
)
|
||||
if attempt is None:
|
||||
logger.error(
|
||||
f"Unable to find IndexAttempt for ID '{attempt_id}' when creating "
|
||||
"indexing jobs"
|
||||
)
|
||||
continue
|
||||
ongoing.add(
|
||||
(
|
||||
attempt.connector_credential_pair_id,
|
||||
attempt.search_settings_id,
|
||||
)
|
||||
)
|
||||
|
||||
# Get the primary search settings
|
||||
primary_search_settings = get_current_search_settings(db_session)
|
||||
search_settings = [primary_search_settings]
|
||||
|
||||
# Check for secondary search settings
|
||||
secondary_search_settings = get_secondary_search_settings(db_session)
|
||||
if secondary_search_settings is not None:
|
||||
# If secondary settings exist, add them to the list
|
||||
search_settings.append(secondary_search_settings)
|
||||
|
||||
all_connector_credential_pairs = fetch_connector_credential_pairs(db_session)
|
||||
for cc_pair in all_connector_credential_pairs:
|
||||
for search_settings_instance in search_settings:
|
||||
# Check if there is an ongoing indexing attempt for this connector credential pair
|
||||
if (cc_pair.id, search_settings_instance.id) in ongoing:
|
||||
continue
|
||||
|
||||
last_attempt = get_last_attempt_for_cc_pair(
|
||||
cc_pair.id, search_settings_instance.id, db_session
|
||||
)
|
||||
if not _should_create_new_indexing(
|
||||
cc_pair=cc_pair,
|
||||
last_index=last_attempt,
|
||||
search_settings_instance=search_settings_instance,
|
||||
secondary_index_building=len(search_settings) > 1,
|
||||
db_session=db_session,
|
||||
):
|
||||
continue
|
||||
|
||||
create_index_attempt(
|
||||
cc_pair.id, search_settings_instance.id, db_session
|
||||
)
|
||||
|
||||
|
||||
def cleanup_indexing_jobs(
|
||||
existing_jobs: dict[int, Future | SimpleJob],
|
||||
tenant_id: str | None,
|
||||
timeout_hours: int = CLEANUP_INDEXING_JOBS_TIMEOUT,
|
||||
) -> dict[int, Future | SimpleJob]:
|
||||
existing_jobs_copy = existing_jobs.copy()
|
||||
# clean up completed jobs
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
for attempt_id, job in existing_jobs.items():
|
||||
index_attempt = get_index_attempt(
|
||||
db_session=db_session, index_attempt_id=attempt_id
|
||||
)
|
||||
|
||||
# do nothing for ongoing jobs that haven't been stopped
|
||||
if not job.done():
|
||||
if not index_attempt:
|
||||
continue
|
||||
|
||||
if not index_attempt.is_finished():
|
||||
continue
|
||||
|
||||
if job.status == "error":
|
||||
logger.error(job.exception())
|
||||
|
||||
job.release()
|
||||
del existing_jobs_copy[attempt_id]
|
||||
|
||||
if not index_attempt:
|
||||
logger.error(
|
||||
f"Unable to find IndexAttempt for ID '{attempt_id}' when cleaning "
|
||||
"up indexing jobs"
|
||||
)
|
||||
continue
|
||||
|
||||
if (
|
||||
index_attempt.status == IndexingStatus.IN_PROGRESS
|
||||
or job.status == "error"
|
||||
):
|
||||
_mark_run_failed(
|
||||
db_session=db_session,
|
||||
index_attempt=index_attempt,
|
||||
failure_reason=_UNEXPECTED_STATE_FAILURE_REASON,
|
||||
)
|
||||
|
||||
# clean up in-progress jobs that were never completed
|
||||
try:
|
||||
connectors = fetch_connectors(db_session)
|
||||
for connector in connectors:
|
||||
in_progress_indexing_attempts = get_inprogress_index_attempts(
|
||||
connector.id, db_session
|
||||
)
|
||||
|
||||
for index_attempt in in_progress_indexing_attempts:
|
||||
if index_attempt.id in existing_jobs:
|
||||
# If index attempt is canceled, stop the run
|
||||
if index_attempt.status == IndexingStatus.FAILED:
|
||||
existing_jobs[index_attempt.id].cancel()
|
||||
# check to see if the job has been updated in last `timeout_hours` hours, if not
|
||||
# assume it to frozen in some bad state and just mark it as failed. Note: this relies
|
||||
# on the fact that the `time_updated` field is constantly updated every
|
||||
# batch of documents indexed
|
||||
current_db_time = get_db_current_time(db_session=db_session)
|
||||
time_since_update = current_db_time - index_attempt.time_updated
|
||||
if time_since_update.total_seconds() > 60 * 60 * timeout_hours:
|
||||
existing_jobs[index_attempt.id].cancel()
|
||||
_mark_run_failed(
|
||||
db_session=db_session,
|
||||
index_attempt=index_attempt,
|
||||
failure_reason="Indexing run frozen - no updates in the last three hours. "
|
||||
"The run will be re-attempted at next scheduled indexing time.",
|
||||
)
|
||||
else:
|
||||
# If job isn't known, simply mark it as failed
|
||||
_mark_run_failed(
|
||||
db_session=db_session,
|
||||
index_attempt=index_attempt,
|
||||
failure_reason=_UNEXPECTED_STATE_FAILURE_REASON,
|
||||
)
|
||||
except ProgrammingError:
|
||||
logger.debug(f"No Connector Table exists for: {tenant_id}")
|
||||
return existing_jobs_copy
|
||||
|
||||
|
||||
def kickoff_indexing_jobs(
|
||||
existing_jobs: dict[int, Future | SimpleJob],
|
||||
client: Client | SimpleJobClient,
|
||||
secondary_client: Client | SimpleJobClient,
|
||||
tenant_id: str | None,
|
||||
) -> dict[int, Future | SimpleJob]:
|
||||
existing_jobs_copy = existing_jobs.copy()
|
||||
|
||||
current_session = get_session_with_tenant(tenant_id)
|
||||
|
||||
# Don't include jobs waiting in the Dask queue that just haven't started running
|
||||
# Also (rarely) don't include for jobs that started but haven't updated the indexing tables yet
|
||||
with current_session as db_session:
|
||||
# get_not_started_index_attempts orders its returned results from oldest to newest
|
||||
# we must process attempts in a FIFO manner to prevent connector starvation
|
||||
new_indexing_attempts = [
|
||||
(attempt, attempt.search_settings)
|
||||
for attempt in get_not_started_index_attempts(db_session)
|
||||
if attempt.id not in existing_jobs
|
||||
]
|
||||
|
||||
logger.debug(f"Found {len(new_indexing_attempts)} new indexing task(s).")
|
||||
|
||||
if not new_indexing_attempts:
|
||||
return existing_jobs
|
||||
|
||||
indexing_attempt_count = 0
|
||||
|
||||
primary_client_full = False
|
||||
secondary_client_full = False
|
||||
for attempt, search_settings in new_indexing_attempts:
|
||||
if primary_client_full and secondary_client_full:
|
||||
break
|
||||
|
||||
use_secondary_index = (
|
||||
search_settings.status == IndexModelStatus.FUTURE
|
||||
if search_settings is not None
|
||||
else False
|
||||
)
|
||||
if attempt.connector_credential_pair.connector is None:
|
||||
logger.warning(
|
||||
f"Skipping index attempt as Connector has been deleted: {attempt}"
|
||||
)
|
||||
with current_session as db_session:
|
||||
mark_attempt_failed(
|
||||
attempt, db_session, failure_reason="Connector is null"
|
||||
)
|
||||
continue
|
||||
if attempt.connector_credential_pair.credential is None:
|
||||
logger.warning(
|
||||
f"Skipping index attempt as Credential has been deleted: {attempt}"
|
||||
)
|
||||
with current_session as db_session:
|
||||
mark_attempt_failed(
|
||||
attempt, db_session, failure_reason="Credential is null"
|
||||
)
|
||||
continue
|
||||
|
||||
if not use_secondary_index:
|
||||
if not primary_client_full:
|
||||
run = client.submit(
|
||||
run_indexing_entrypoint,
|
||||
attempt.id,
|
||||
tenant_id,
|
||||
attempt.connector_credential_pair_id,
|
||||
global_version.is_ee_version(),
|
||||
pure=False,
|
||||
)
|
||||
if not run:
|
||||
primary_client_full = True
|
||||
else:
|
||||
if not secondary_client_full:
|
||||
run = secondary_client.submit(
|
||||
run_indexing_entrypoint,
|
||||
attempt.id,
|
||||
tenant_id,
|
||||
attempt.connector_credential_pair_id,
|
||||
global_version.is_ee_version(),
|
||||
pure=False,
|
||||
)
|
||||
if not run:
|
||||
secondary_client_full = True
|
||||
|
||||
if run:
|
||||
if indexing_attempt_count == 0:
|
||||
logger.info(
|
||||
f"Indexing dispatch starts: pending={len(new_indexing_attempts)}"
|
||||
)
|
||||
|
||||
indexing_attempt_count += 1
|
||||
secondary_str = " (secondary index)" if use_secondary_index else ""
|
||||
logger.info(
|
||||
f"Indexing dispatched{secondary_str}: "
|
||||
f"attempt_id={attempt.id} "
|
||||
f"connector='{attempt.connector_credential_pair.connector.name}' "
|
||||
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
|
||||
f"credentials='{attempt.connector_credential_pair.credential_id}'"
|
||||
)
|
||||
existing_jobs_copy[attempt.id] = run
|
||||
|
||||
if indexing_attempt_count > 0:
|
||||
logger.info(
|
||||
f"Indexing dispatch results: "
|
||||
f"initial_pending={len(new_indexing_attempts)} "
|
||||
f"started={indexing_attempt_count} "
|
||||
f"remaining={len(new_indexing_attempts) - indexing_attempt_count}"
|
||||
)
|
||||
|
||||
return existing_jobs_copy
|
||||
|
||||
|
||||
def get_all_tenant_ids() -> list[str] | list[None]:
|
||||
if not MULTI_TENANT:
|
||||
return [None]
|
||||
with get_session_with_tenant(tenant_id="public") as session:
|
||||
result = session.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT schema_name
|
||||
FROM information_schema.schemata
|
||||
WHERE schema_name NOT IN ('pg_catalog', 'information_schema', 'public')"""
|
||||
)
|
||||
)
|
||||
tenant_ids = [row[0] for row in result]
|
||||
|
||||
valid_tenants = [
|
||||
tenant
|
||||
for tenant in tenant_ids
|
||||
if tenant is None or tenant.startswith(TENANT_ID_PREFIX)
|
||||
]
|
||||
|
||||
return valid_tenants
|
||||
|
||||
|
||||
def update_loop(
|
||||
delay: int = 10,
|
||||
num_workers: int = NUM_INDEXING_WORKERS,
|
||||
num_secondary_workers: int = NUM_SECONDARY_INDEXING_WORKERS,
|
||||
) -> None:
|
||||
if not MULTI_TENANT:
|
||||
# We can use this function as we are certain only the public schema exists
|
||||
# (explicitly for the non-`MULTI_TENANT` case)
|
||||
engine = get_sqlalchemy_engine()
|
||||
with Session(engine) as db_session:
|
||||
check_index_swap(db_session=db_session)
|
||||
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
|
||||
# So that the first time users aren't surprised by really slow speed of first
|
||||
# batch of documents indexed
|
||||
if search_settings.provider_type is None:
|
||||
logger.notice("Running a first inference to warm up embedding model")
|
||||
embedding_model = EmbeddingModel.from_db_model(
|
||||
search_settings=search_settings,
|
||||
server_host=INDEXING_MODEL_SERVER_HOST,
|
||||
server_port=INDEXING_MODEL_SERVER_PORT,
|
||||
)
|
||||
|
||||
warm_up_bi_encoder(
|
||||
embedding_model=embedding_model,
|
||||
)
|
||||
logger.notice("First inference complete.")
|
||||
|
||||
client_primary: Client | SimpleJobClient
|
||||
client_secondary: Client | SimpleJobClient
|
||||
if DASK_JOB_CLIENT_ENABLED:
|
||||
cluster_primary = LocalCluster(
|
||||
n_workers=num_workers,
|
||||
threads_per_worker=1,
|
||||
silence_logs=logging.ERROR,
|
||||
)
|
||||
cluster_secondary = LocalCluster(
|
||||
n_workers=num_secondary_workers,
|
||||
threads_per_worker=1,
|
||||
silence_logs=logging.ERROR,
|
||||
)
|
||||
client_primary = Client(cluster_primary)
|
||||
client_secondary = Client(cluster_secondary)
|
||||
if LOG_LEVEL.lower() == "debug":
|
||||
client_primary.register_worker_plugin(ResourceLogger())
|
||||
else:
|
||||
client_primary = SimpleJobClient(n_workers=num_workers)
|
||||
client_secondary = SimpleJobClient(n_workers=num_secondary_workers)
|
||||
|
||||
existing_jobs: dict[str | None, dict[int, Future | SimpleJob]] = {}
|
||||
|
||||
logger.notice("Startup complete. Waiting for indexing jobs...")
|
||||
while True:
|
||||
start = time.time()
|
||||
start_time_utc = datetime.utcfromtimestamp(start).strftime("%Y-%m-%d %H:%M:%S")
|
||||
logger.debug(f"Running update, current UTC time: {start_time_utc}")
|
||||
|
||||
if existing_jobs:
|
||||
logger.debug(
|
||||
"Found existing indexing jobs: "
|
||||
f"{[(tenant_id, list(jobs.keys())) for tenant_id, jobs in existing_jobs.items()]}"
|
||||
)
|
||||
|
||||
try:
|
||||
tenants = get_all_tenant_ids()
|
||||
|
||||
for tenant_id in tenants:
|
||||
try:
|
||||
logger.debug(
|
||||
f"Processing {'index attempts' if tenant_id is None else f'tenant {tenant_id}'}"
|
||||
)
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
index_to_expire = check_index_swap(db_session=db_session)
|
||||
|
||||
if index_to_expire and tenant_id and MULTI_TENANT:
|
||||
VespaIndex.delete_entries_by_tenant_id(
|
||||
tenant_id=tenant_id,
|
||||
index_name=index_to_expire.index_name,
|
||||
)
|
||||
|
||||
if not MULTI_TENANT:
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
if search_settings.provider_type is None:
|
||||
logger.notice(
|
||||
"Running a first inference to warm up embedding model"
|
||||
)
|
||||
embedding_model = EmbeddingModel.from_db_model(
|
||||
search_settings=search_settings,
|
||||
server_host=INDEXING_MODEL_SERVER_HOST,
|
||||
server_port=INDEXING_MODEL_SERVER_PORT,
|
||||
)
|
||||
warm_up_bi_encoder(embedding_model=embedding_model)
|
||||
logger.notice("First inference complete.")
|
||||
|
||||
tenant_jobs = existing_jobs.get(tenant_id, {})
|
||||
|
||||
tenant_jobs = cleanup_indexing_jobs(
|
||||
existing_jobs=tenant_jobs, tenant_id=tenant_id
|
||||
)
|
||||
create_indexing_jobs(existing_jobs=tenant_jobs, tenant_id=tenant_id)
|
||||
tenant_jobs = kickoff_indexing_jobs(
|
||||
existing_jobs=tenant_jobs,
|
||||
client=client_primary,
|
||||
secondary_client=client_secondary,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
existing_jobs[tenant_id] = tenant_jobs
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to process tenant {tenant_id or 'default'}: {e}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to run update due to {e}")
|
||||
|
||||
sleep_time = delay - (time.time() - start)
|
||||
if sleep_time > 0:
|
||||
time.sleep(sleep_time)
|
||||
|
||||
|
||||
def update__main() -> None:
|
||||
set_is_ee_based_on_env_variable()
|
||||
|
||||
# initialize the Postgres connection pool
|
||||
SqlEngine.set_app_name(POSTGRES_INDEXER_APP_NAME)
|
||||
|
||||
logger.notice("Starting indexing service")
|
||||
update_loop()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
update__main()
|
||||
@@ -1,168 +0,0 @@
|
||||
import re
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import LlmDoc
|
||||
from danswer.db.chat import get_chat_messages_by_session
|
||||
from danswer.db.models import ChatMessage
|
||||
from danswer.llm.answering.models import PreviousMessage
|
||||
from danswer.search.models import InferenceSection
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def llm_doc_from_inference_section(inference_section: InferenceSection) -> LlmDoc:
|
||||
return LlmDoc(
|
||||
document_id=inference_section.center_chunk.document_id,
|
||||
# This one is using the combined content of all the chunks of the section
|
||||
# In default settings, this is the same as just the content of base chunk
|
||||
content=inference_section.combined_content,
|
||||
blurb=inference_section.center_chunk.blurb,
|
||||
semantic_identifier=inference_section.center_chunk.semantic_identifier,
|
||||
source_type=inference_section.center_chunk.source_type,
|
||||
metadata=inference_section.center_chunk.metadata,
|
||||
updated_at=inference_section.center_chunk.updated_at,
|
||||
link=inference_section.center_chunk.source_links[0]
|
||||
if inference_section.center_chunk.source_links
|
||||
else None,
|
||||
source_links=inference_section.center_chunk.source_links,
|
||||
)
|
||||
|
||||
|
||||
def create_chat_chain(
|
||||
chat_session_id: int,
|
||||
db_session: Session,
|
||||
prefetch_tool_calls: bool = True,
|
||||
# Optional id at which we finish processing
|
||||
stop_at_message_id: int | None = None,
|
||||
) -> tuple[ChatMessage, list[ChatMessage]]:
|
||||
"""Build the linear chain of messages without including the root message"""
|
||||
mainline_messages: list[ChatMessage] = []
|
||||
|
||||
all_chat_messages = get_chat_messages_by_session(
|
||||
chat_session_id=chat_session_id,
|
||||
user_id=None,
|
||||
db_session=db_session,
|
||||
skip_permission_check=True,
|
||||
prefetch_tool_calls=prefetch_tool_calls,
|
||||
)
|
||||
id_to_msg = {msg.id: msg for msg in all_chat_messages}
|
||||
|
||||
if not all_chat_messages:
|
||||
raise RuntimeError("No messages in Chat Session")
|
||||
|
||||
root_message = all_chat_messages[0]
|
||||
if root_message.parent_message is not None:
|
||||
raise RuntimeError(
|
||||
"Invalid root message, unable to fetch valid chat message sequence"
|
||||
)
|
||||
|
||||
current_message: ChatMessage | None = root_message
|
||||
while current_message is not None:
|
||||
child_msg = current_message.latest_child_message
|
||||
|
||||
# Break if at the end of the chain
|
||||
# or have reached the `final_id` of the submitted message
|
||||
if not child_msg or (
|
||||
stop_at_message_id and current_message.id == stop_at_message_id
|
||||
):
|
||||
break
|
||||
current_message = id_to_msg.get(child_msg)
|
||||
|
||||
if current_message is None:
|
||||
raise RuntimeError(
|
||||
"Invalid message chain,"
|
||||
"could not find next message in the same session"
|
||||
)
|
||||
|
||||
mainline_messages.append(current_message)
|
||||
|
||||
if not mainline_messages:
|
||||
raise RuntimeError("Could not trace chat message history")
|
||||
|
||||
return mainline_messages[-1], mainline_messages[:-1]
|
||||
|
||||
|
||||
def combine_message_chain(
|
||||
messages: list[ChatMessage] | list[PreviousMessage],
|
||||
token_limit: int,
|
||||
msg_limit: int | None = None,
|
||||
) -> str:
|
||||
"""Used for secondary LLM flows that require the chat history,"""
|
||||
message_strs: list[str] = []
|
||||
total_token_count = 0
|
||||
|
||||
if msg_limit is not None:
|
||||
messages = messages[-msg_limit:]
|
||||
|
||||
for message in cast(list[ChatMessage] | list[PreviousMessage], reversed(messages)):
|
||||
message_token_count = message.token_count
|
||||
|
||||
if total_token_count + message_token_count > token_limit:
|
||||
break
|
||||
|
||||
role = message.message_type.value.upper()
|
||||
message_strs.insert(0, f"{role}:\n{message.message}")
|
||||
total_token_count += message_token_count
|
||||
|
||||
return "\n\n".join(message_strs)
|
||||
|
||||
|
||||
def reorganize_citations(
|
||||
answer: str, citations: list[CitationInfo]
|
||||
) -> tuple[str, list[CitationInfo]]:
|
||||
"""For a complete, citation-aware response, we want to reorganize the citations so that
|
||||
they are in the order of the documents that were used in the response. This just looks nicer / avoids
|
||||
confusion ("Why is there [7] when only 2 documents are cited?")."""
|
||||
|
||||
# Regular expression to find all instances of [[x]](LINK)
|
||||
pattern = r"\[\[(.*?)\]\]\((.*?)\)"
|
||||
|
||||
all_citation_matches = re.findall(pattern, answer)
|
||||
|
||||
new_citation_info: dict[int, CitationInfo] = {}
|
||||
for citation_match in all_citation_matches:
|
||||
try:
|
||||
citation_num = int(citation_match[0])
|
||||
if citation_num in new_citation_info:
|
||||
continue
|
||||
|
||||
matching_citation = next(
|
||||
iter([c for c in citations if c.citation_num == int(citation_num)]),
|
||||
None,
|
||||
)
|
||||
if matching_citation is None:
|
||||
continue
|
||||
|
||||
new_citation_info[citation_num] = CitationInfo(
|
||||
citation_num=len(new_citation_info) + 1,
|
||||
document_id=matching_citation.document_id,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Function to replace citations with their new number
|
||||
def slack_link_format(match: re.Match) -> str:
|
||||
link_text = match.group(1)
|
||||
try:
|
||||
citation_num = int(link_text)
|
||||
if citation_num in new_citation_info:
|
||||
link_text = new_citation_info[citation_num].citation_num
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
link_url = match.group(2)
|
||||
return f"[[{link_text}]]({link_url})"
|
||||
|
||||
# Substitute all matches in the input text
|
||||
new_answer = re.sub(pattern, slack_link_format, answer)
|
||||
|
||||
# if any citations weren't parsable, just add them back to be safe
|
||||
for citation in citations:
|
||||
if citation.citation_num not in new_citation_info:
|
||||
new_citation_info[citation.citation_num] = citation
|
||||
|
||||
return new_answer, list(new_citation_info.values())
|
||||
@@ -1,24 +0,0 @@
|
||||
input_prompts:
|
||||
- id: -5
|
||||
prompt: "Elaborate"
|
||||
content: "Elaborate on the above, give me a more in depth explanation."
|
||||
active: true
|
||||
is_public: true
|
||||
|
||||
- id: -4
|
||||
prompt: "Reword"
|
||||
content: "Help me rewrite the following politely and concisely for professional communication:\n"
|
||||
active: true
|
||||
is_public: true
|
||||
|
||||
- id: -3
|
||||
prompt: "Email"
|
||||
content: "Write a professional email for me including a subject line, signature, etc. Template the parts that need editing with [ ]. The email should cover the following points:\n"
|
||||
active: true
|
||||
is_public: true
|
||||
|
||||
- id: -2
|
||||
prompt: "Debug"
|
||||
content: "Provide step-by-step troubleshooting instructions for the following issue:\n"
|
||||
active: true
|
||||
is_public: true
|
||||
@@ -1,185 +0,0 @@
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.search.enums import QueryFlow
|
||||
from danswer.search.enums import SearchType
|
||||
from danswer.search.models import RetrievalDocs
|
||||
from danswer.search.models import SearchResponse
|
||||
from danswer.tools.custom.base_tool_types import ToolResultType
|
||||
|
||||
|
||||
class LlmDoc(BaseModel):
|
||||
"""This contains the minimal set information for the LLM portion including citations"""
|
||||
|
||||
document_id: str
|
||||
content: str
|
||||
blurb: str
|
||||
semantic_identifier: str
|
||||
source_type: DocumentSource
|
||||
metadata: dict[str, str | list[str]]
|
||||
updated_at: datetime | None
|
||||
link: str | None
|
||||
source_links: dict[int, str] | None
|
||||
|
||||
|
||||
# First chunk of info for streaming QA
|
||||
class QADocsResponse(RetrievalDocs):
|
||||
rephrased_query: str | None = None
|
||||
predicted_flow: QueryFlow | None
|
||||
predicted_search: SearchType | None
|
||||
applied_source_filters: list[DocumentSource] | None
|
||||
applied_time_cutoff: datetime | None
|
||||
recency_bias_multiplier: float
|
||||
|
||||
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
|
||||
initial_dict = super().model_dump(mode="json", *args, **kwargs) # type: ignore
|
||||
initial_dict["applied_time_cutoff"] = (
|
||||
self.applied_time_cutoff.isoformat() if self.applied_time_cutoff else None
|
||||
)
|
||||
|
||||
return initial_dict
|
||||
|
||||
|
||||
class StreamStopReason(Enum):
|
||||
CONTEXT_LENGTH = "context_length"
|
||||
CANCELLED = "cancelled"
|
||||
|
||||
|
||||
class StreamStopInfo(BaseModel):
|
||||
stop_reason: StreamStopReason
|
||||
|
||||
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
|
||||
data = super().model_dump(mode="json", *args, **kwargs) # type: ignore
|
||||
data["stop_reason"] = self.stop_reason.name
|
||||
return data
|
||||
|
||||
|
||||
class LLMRelevanceFilterResponse(BaseModel):
|
||||
llm_selected_doc_indices: list[int]
|
||||
|
||||
|
||||
class FinalUsedContextDocsResponse(BaseModel):
|
||||
final_context_docs: list[LlmDoc]
|
||||
|
||||
|
||||
class RelevanceAnalysis(BaseModel):
|
||||
relevant: bool
|
||||
content: str | None = None
|
||||
|
||||
|
||||
class SectionRelevancePiece(RelevanceAnalysis):
|
||||
"""LLM analysis mapped to an Inference Section"""
|
||||
|
||||
document_id: str
|
||||
chunk_id: int # ID of the center chunk for a given inference section
|
||||
|
||||
|
||||
class DocumentRelevance(BaseModel):
|
||||
"""Contains all relevance information for a given search"""
|
||||
|
||||
relevance_summaries: dict[str, RelevanceAnalysis]
|
||||
|
||||
|
||||
class DanswerAnswerPiece(BaseModel):
|
||||
# A small piece of a complete answer. Used for streaming back answers.
|
||||
answer_piece: str | None # if None, specifies the end of an Answer
|
||||
|
||||
|
||||
# An intermediate representation of citations, later translated into
|
||||
# a mapping of the citation [n] number to SearchDoc
|
||||
class CitationInfo(BaseModel):
|
||||
citation_num: int
|
||||
document_id: str
|
||||
|
||||
|
||||
class AllCitations(BaseModel):
|
||||
citations: list[CitationInfo]
|
||||
|
||||
|
||||
# This is a mapping of the citation number to the document index within
|
||||
# the result search doc set
|
||||
class MessageSpecificCitations(BaseModel):
|
||||
citation_map: dict[int, int]
|
||||
|
||||
|
||||
class MessageResponseIDInfo(BaseModel):
|
||||
user_message_id: int | None
|
||||
reserved_assistant_message_id: int
|
||||
|
||||
|
||||
class StreamingError(BaseModel):
|
||||
error: str
|
||||
stack_trace: str | None = None
|
||||
|
||||
|
||||
class DanswerQuote(BaseModel):
|
||||
# This is during inference so everything is a string by this point
|
||||
quote: str
|
||||
document_id: str
|
||||
link: str | None
|
||||
source_type: str
|
||||
semantic_identifier: str
|
||||
blurb: str
|
||||
|
||||
|
||||
class DanswerQuotes(BaseModel):
|
||||
quotes: list[DanswerQuote]
|
||||
|
||||
|
||||
class DanswerContext(BaseModel):
|
||||
content: str
|
||||
document_id: str
|
||||
semantic_identifier: str
|
||||
blurb: str
|
||||
|
||||
|
||||
class DanswerContexts(BaseModel):
|
||||
contexts: list[DanswerContext]
|
||||
|
||||
|
||||
class DanswerAnswer(BaseModel):
|
||||
answer: str | None
|
||||
|
||||
|
||||
class QAResponse(SearchResponse, DanswerAnswer):
|
||||
quotes: list[DanswerQuote] | None
|
||||
contexts: list[DanswerContexts] | None
|
||||
predicted_flow: QueryFlow
|
||||
predicted_search: SearchType
|
||||
eval_res_valid: bool | None = None
|
||||
llm_selected_doc_indices: list[int] | None = None
|
||||
error_msg: str | None = None
|
||||
|
||||
|
||||
class ImageGenerationDisplay(BaseModel):
|
||||
file_ids: list[str]
|
||||
|
||||
|
||||
class CustomToolResponse(BaseModel):
|
||||
response: ToolResultType
|
||||
tool_name: str
|
||||
|
||||
|
||||
AnswerQuestionPossibleReturn = (
|
||||
DanswerAnswerPiece
|
||||
| DanswerQuotes
|
||||
| CitationInfo
|
||||
| DanswerContexts
|
||||
| ImageGenerationDisplay
|
||||
| CustomToolResponse
|
||||
| StreamingError
|
||||
| StreamStopInfo
|
||||
)
|
||||
|
||||
|
||||
AnswerQuestionStreamReturn = Iterator[AnswerQuestionPossibleReturn]
|
||||
|
||||
|
||||
class LLMMetricsContainer(BaseModel):
|
||||
prompt_tokens: int
|
||||
response_tokens: int
|
||||
@@ -1,93 +0,0 @@
|
||||
# Currently in the UI, each Persona only has one prompt, which is why there are 3 very similar personas defined below.
|
||||
|
||||
personas:
|
||||
# This id field can be left blank for other default personas, however an id 0 persona must exist
|
||||
# this is for DanswerBot to use when tagged in a non-configured channel
|
||||
# Careful setting specific IDs, this won't autoincrement the next ID value for postgres
|
||||
- id: 0
|
||||
name: "Knowledge"
|
||||
description: >
|
||||
Assistant with access to documents from your Connected Sources.
|
||||
# Default Prompt objects attached to the persona, see prompts.yaml
|
||||
prompts:
|
||||
- "Answer-Question"
|
||||
# Default number of chunks to include as context, set to 0 to disable retrieval
|
||||
# Remove the field to set to the system default number of chunks/tokens to pass to Gen AI
|
||||
# Each chunk is 512 tokens long
|
||||
num_chunks: 10
|
||||
# Enable/Disable usage of the LLM chunk filter feature whereby each chunk is passed to the LLM to determine
|
||||
# if the chunk is useful or not towards the latest user query
|
||||
# This feature can be overriden for all personas via DISABLE_LLM_DOC_RELEVANCE env variable
|
||||
llm_relevance_filter: true
|
||||
# Enable/Disable usage of the LLM to extract query time filters including source type and time range filters
|
||||
llm_filter_extraction: true
|
||||
# Decay documents priority as they age, options are:
|
||||
# - favor_recent (2x base by default, configurable)
|
||||
# - base_decay
|
||||
# - no_decay
|
||||
# - auto (model chooses between favor_recent and base_decay based on user query)
|
||||
recency_bias: "auto"
|
||||
# Default Document Sets for this persona, specified as a list of names here.
|
||||
# If the document set by the name exists, it will be attached to the persona
|
||||
# If the document set by the name does not exist, it will be created as an empty document set with no connectors
|
||||
# The admin can then use the UI to add new connectors to the document set
|
||||
# Example:
|
||||
# document_sets:
|
||||
# - "HR Resources"
|
||||
# - "Engineer Onboarding"
|
||||
# - "Benefits"
|
||||
document_sets: []
|
||||
icon_shape: 23013
|
||||
icon_color: "#6FB1FF"
|
||||
display_priority: 1
|
||||
is_visible: true
|
||||
|
||||
- id: 1
|
||||
name: "General"
|
||||
description: >
|
||||
Assistant with no access to documents. Chat with just the Large Language Model.
|
||||
prompts:
|
||||
- "OnlyLLM"
|
||||
num_chunks: 0
|
||||
llm_relevance_filter: true
|
||||
llm_filter_extraction: true
|
||||
recency_bias: "auto"
|
||||
document_sets: []
|
||||
icon_shape: 50910
|
||||
icon_color: "#FF6F6F"
|
||||
display_priority: 0
|
||||
is_visible: true
|
||||
|
||||
- id: 2
|
||||
name: "Paraphrase"
|
||||
description: >
|
||||
Assistant that is heavily constrained and only provides exact quotes from Connected Sources.
|
||||
prompts:
|
||||
- "Paraphrase"
|
||||
num_chunks: 10
|
||||
llm_relevance_filter: true
|
||||
llm_filter_extraction: true
|
||||
recency_bias: "auto"
|
||||
document_sets: []
|
||||
icon_shape: 45519
|
||||
icon_color: "#6FFF8D"
|
||||
display_priority: 2
|
||||
is_visible: false
|
||||
|
||||
|
||||
- id: 3
|
||||
name: "Art"
|
||||
description: >
|
||||
Assistant for generating images based on descriptions.
|
||||
prompts:
|
||||
- "ImageGeneration"
|
||||
num_chunks: 0
|
||||
llm_relevance_filter: false
|
||||
llm_filter_extraction: false
|
||||
recency_bias: "no_decay"
|
||||
document_sets: []
|
||||
icon_shape: 234124
|
||||
icon_color: "#9B59B6"
|
||||
image_generation: true
|
||||
display_priority: 3
|
||||
is_visible: true
|
||||
@@ -1,115 +0,0 @@
|
||||
from typing_extensions import TypedDict # noreorder
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from danswer.prompts.chat_tools import DANSWER_TOOL_DESCRIPTION
|
||||
from danswer.prompts.chat_tools import DANSWER_TOOL_NAME
|
||||
from danswer.prompts.chat_tools import TOOL_FOLLOWUP
|
||||
from danswer.prompts.chat_tools import TOOL_LESS_FOLLOWUP
|
||||
from danswer.prompts.chat_tools import TOOL_LESS_PROMPT
|
||||
from danswer.prompts.chat_tools import TOOL_TEMPLATE
|
||||
from danswer.prompts.chat_tools import USER_INPUT
|
||||
|
||||
|
||||
class ToolInfo(TypedDict):
|
||||
name: str
|
||||
description: str
|
||||
|
||||
|
||||
class DanswerChatModelOut(BaseModel):
|
||||
model_raw: str
|
||||
action: str
|
||||
action_input: str
|
||||
|
||||
|
||||
def call_tool(
|
||||
model_actions: DanswerChatModelOut,
|
||||
) -> str:
|
||||
raise NotImplementedError("There are no additional tool integrations right now")
|
||||
|
||||
|
||||
def form_user_prompt_text(
|
||||
query: str,
|
||||
tool_text: str | None,
|
||||
hint_text: str | None,
|
||||
user_input_prompt: str = USER_INPUT,
|
||||
tool_less_prompt: str = TOOL_LESS_PROMPT,
|
||||
) -> str:
|
||||
user_prompt = tool_text or tool_less_prompt
|
||||
|
||||
user_prompt += user_input_prompt.format(user_input=query)
|
||||
|
||||
if hint_text:
|
||||
if user_prompt[-1] != "\n":
|
||||
user_prompt += "\n"
|
||||
user_prompt += "\nHint: " + hint_text
|
||||
|
||||
return user_prompt.strip()
|
||||
|
||||
|
||||
def form_tool_section_text(
|
||||
tools: list[ToolInfo] | None, retrieval_enabled: bool, template: str = TOOL_TEMPLATE
|
||||
) -> str | None:
|
||||
if not tools and not retrieval_enabled:
|
||||
return None
|
||||
|
||||
if retrieval_enabled and tools:
|
||||
tools.append(
|
||||
{"name": DANSWER_TOOL_NAME, "description": DANSWER_TOOL_DESCRIPTION}
|
||||
)
|
||||
|
||||
tools_intro = []
|
||||
if tools:
|
||||
num_tools = len(tools)
|
||||
for tool in tools:
|
||||
description_formatted = tool["description"].replace("\n", " ")
|
||||
tools_intro.append(f"> {tool['name']}: {description_formatted}")
|
||||
|
||||
prefix = "Must be one of " if num_tools > 1 else "Must be "
|
||||
|
||||
tools_intro_text = "\n".join(tools_intro)
|
||||
tool_names_text = prefix + ", ".join([tool["name"] for tool in tools])
|
||||
|
||||
else:
|
||||
return None
|
||||
|
||||
return template.format(
|
||||
tool_overviews=tools_intro_text, tool_names=tool_names_text
|
||||
).strip()
|
||||
|
||||
|
||||
def form_tool_followup_text(
|
||||
tool_output: str,
|
||||
query: str,
|
||||
hint_text: str | None,
|
||||
tool_followup_prompt: str = TOOL_FOLLOWUP,
|
||||
ignore_hint: bool = False,
|
||||
) -> str:
|
||||
# If multi-line query, it likely confuses the model more than helps
|
||||
if "\n" not in query:
|
||||
optional_reminder = f"\nAs a reminder, my query was: {query}\n"
|
||||
else:
|
||||
optional_reminder = ""
|
||||
|
||||
if not ignore_hint and hint_text:
|
||||
hint_text_spaced = f"\nHint: {hint_text}\n"
|
||||
else:
|
||||
hint_text_spaced = ""
|
||||
|
||||
return tool_followup_prompt.format(
|
||||
tool_output=tool_output,
|
||||
optional_reminder=optional_reminder,
|
||||
hint=hint_text_spaced,
|
||||
).strip()
|
||||
|
||||
|
||||
def form_tool_less_followup_text(
|
||||
tool_output: str,
|
||||
query: str,
|
||||
hint_text: str | None,
|
||||
tool_followup_prompt: str = TOOL_LESS_FOLLOWUP,
|
||||
) -> str:
|
||||
hint = f"Hint: {hint_text}" if hint_text else ""
|
||||
return tool_followup_prompt.format(
|
||||
context_str=tool_output, user_query=query, hint_text=hint
|
||||
).strip()
|
||||
@@ -1,32 +0,0 @@
|
||||
import bs4
|
||||
|
||||
|
||||
def build_confluence_document_id(base_url: str, content_url: str) -> str:
|
||||
"""For confluence, the document id is the page url for a page based document
|
||||
or the attachment download url for an attachment based document
|
||||
|
||||
Args:
|
||||
base_url (str): The base url of the Confluence instance
|
||||
content_url (str): The url of the page or attachment download url
|
||||
|
||||
Returns:
|
||||
str: The document id
|
||||
"""
|
||||
return f"{base_url}{content_url}"
|
||||
|
||||
|
||||
def get_used_attachments(text: str) -> list[str]:
|
||||
"""Parse a Confluence html page to generate a list of current
|
||||
attachment in used
|
||||
|
||||
Args:
|
||||
text (str): The page content
|
||||
|
||||
Returns:
|
||||
list[str]: List of filenames currently in use by the page text
|
||||
"""
|
||||
files_in_used = []
|
||||
soup = bs4.BeautifulSoup(text, "html.parser")
|
||||
for attachment in soup.findAll("ri:attachment"):
|
||||
files_in_used.append(attachment.attrs["ri:filename"])
|
||||
return files_in_used
|
||||
@@ -1,796 +0,0 @@
|
||||
import io
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Collection
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from functools import lru_cache
|
||||
from typing import Any
|
||||
from urllib.parse import parse_qs
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import bs4
|
||||
from atlassian import Confluence # type:ignore
|
||||
from requests import HTTPError
|
||||
|
||||
from danswer.configs.app_configs import (
|
||||
CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD,
|
||||
)
|
||||
from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
|
||||
from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_INDEX_ARCHIVED_PAGES
|
||||
from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_LABELS_TO_SKIP
|
||||
from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING
|
||||
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.confluence.confluence_utils import (
|
||||
build_confluence_document_id,
|
||||
)
|
||||
from danswer.connectors.confluence.confluence_utils import get_used_attachments
|
||||
from danswer.connectors.confluence.rate_limit_handler import (
|
||||
make_confluence_call_handle_rate_limit,
|
||||
)
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.models import BasicExpertInfo
|
||||
from danswer.connectors.models import ConnectorMissingCredentialError
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.file_processing.extract_file_text import extract_file_text
|
||||
from danswer.file_processing.html_utils import format_document_soup
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Potential Improvements
|
||||
# 1. Include attachments, etc
|
||||
# 2. Segment into Sections for more accurate linking, can split by headers but make sure no text/ordering is lost
|
||||
|
||||
|
||||
NO_PERMISSIONS_TO_VIEW_ATTACHMENTS_ERROR_STR = (
|
||||
"User not permitted to view attachments on content"
|
||||
)
|
||||
NO_PARENT_OR_NO_PERMISSIONS_ERROR_STR = (
|
||||
"No parent or not permitted to view content with id"
|
||||
)
|
||||
|
||||
|
||||
class DanswerConfluence(Confluence):
|
||||
"""
|
||||
This is a custom Confluence class that overrides the default Confluence class to add a custom CQL method.
|
||||
This is necessary because the default Confluence class does not properly support cql expansions.
|
||||
"""
|
||||
|
||||
def __init__(self, url: str, *args: Any, **kwargs: Any) -> None:
|
||||
super(DanswerConfluence, self).__init__(url, *args, **kwargs)
|
||||
|
||||
def danswer_cql(
|
||||
self,
|
||||
cql: str,
|
||||
expand: str | None = None,
|
||||
cursor: str | None = None,
|
||||
limit: int = 500,
|
||||
include_archived_spaces: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
url_suffix = f"rest/api/content/search?cql={cql}"
|
||||
if expand:
|
||||
url_suffix += f"&expand={expand}"
|
||||
if cursor:
|
||||
url_suffix += f"&cursor={cursor}"
|
||||
url_suffix += f"&limit={limit}"
|
||||
if include_archived_spaces:
|
||||
url_suffix += "&includeArchivedSpaces=true"
|
||||
try:
|
||||
response = self.get(url_suffix)
|
||||
return response
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def _get_user(user_id: str, confluence_client: DanswerConfluence) -> str:
|
||||
"""Get Confluence Display Name based on the account-id or userkey value
|
||||
|
||||
Args:
|
||||
user_id (str): The user id (i.e: the account-id or userkey)
|
||||
confluence_client (Confluence): The Confluence Client
|
||||
|
||||
Returns:
|
||||
str: The User Display Name. 'Unknown User' if the user is deactivated or not found
|
||||
"""
|
||||
user_not_found = "Unknown User"
|
||||
|
||||
get_user_details_by_accountid = make_confluence_call_handle_rate_limit(
|
||||
confluence_client.get_user_details_by_accountid
|
||||
)
|
||||
try:
|
||||
return get_user_details_by_accountid(user_id).get("displayName", user_not_found)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Unable to get the User Display Name with the id: '{user_id}' - {e}"
|
||||
)
|
||||
return user_not_found
|
||||
|
||||
|
||||
def parse_html_page(text: str, confluence_client: DanswerConfluence) -> str:
|
||||
"""Parse a Confluence html page and replace the 'user Id' by the real
|
||||
User Display Name
|
||||
|
||||
Args:
|
||||
text (str): The page content
|
||||
confluence_client (Confluence): Confluence client
|
||||
|
||||
Returns:
|
||||
str: loaded and formated Confluence page
|
||||
"""
|
||||
soup = bs4.BeautifulSoup(text, "html.parser")
|
||||
for user in soup.findAll("ri:user"):
|
||||
user_id = (
|
||||
user.attrs["ri:account-id"]
|
||||
if "ri:account-id" in user.attrs
|
||||
else user.get("ri:userkey")
|
||||
)
|
||||
if not user_id:
|
||||
logger.warning(
|
||||
"ri:userkey not found in ri:user element. " f"Found attrs: {user.attrs}"
|
||||
)
|
||||
continue
|
||||
# Include @ sign for tagging, more clear for LLM
|
||||
user.replaceWith("@" + _get_user(user_id, confluence_client))
|
||||
return format_document_soup(soup)
|
||||
|
||||
|
||||
def _comment_dfs(
|
||||
comments_str: str,
|
||||
comment_pages: Collection[dict[str, Any]],
|
||||
confluence_client: DanswerConfluence,
|
||||
) -> str:
|
||||
get_page_child_by_type = make_confluence_call_handle_rate_limit(
|
||||
confluence_client.get_page_child_by_type
|
||||
)
|
||||
|
||||
for comment_page in comment_pages:
|
||||
comment_html = comment_page["body"]["storage"]["value"]
|
||||
comments_str += "\nComment:\n" + parse_html_page(
|
||||
comment_html, confluence_client
|
||||
)
|
||||
try:
|
||||
child_comment_pages = get_page_child_by_type(
|
||||
comment_page["id"],
|
||||
type="comment",
|
||||
start=None,
|
||||
limit=None,
|
||||
expand="body.storage.value",
|
||||
)
|
||||
comments_str = _comment_dfs(
|
||||
comments_str, child_comment_pages, confluence_client
|
||||
)
|
||||
except HTTPError as e:
|
||||
# not the cleanest, but I'm not aware of a nicer way to check the error
|
||||
if NO_PARENT_OR_NO_PERMISSIONS_ERROR_STR not in str(e):
|
||||
raise
|
||||
|
||||
return comments_str
|
||||
|
||||
|
||||
def _datetime_from_string(datetime_string: str) -> datetime:
|
||||
datetime_object = datetime.fromisoformat(datetime_string)
|
||||
|
||||
if datetime_object.tzinfo is None:
|
||||
# If no timezone info, assume it is UTC
|
||||
datetime_object = datetime_object.replace(tzinfo=timezone.utc)
|
||||
else:
|
||||
# If not in UTC, translate it
|
||||
datetime_object = datetime_object.astimezone(timezone.utc)
|
||||
|
||||
return datetime_object
|
||||
|
||||
|
||||
class RecursiveIndexer:
|
||||
def __init__(
|
||||
self,
|
||||
batch_size: int,
|
||||
confluence_client: Confluence,
|
||||
index_recursively: bool,
|
||||
origin_page_id: str,
|
||||
) -> None:
|
||||
self.batch_size = batch_size
|
||||
self.confluence_client = confluence_client
|
||||
self.index_recursively = index_recursively
|
||||
self.origin_page_id = origin_page_id
|
||||
self.pages = self.recurse_children_pages(self.origin_page_id)
|
||||
|
||||
def get_origin_page(self) -> list[dict[str, Any]]:
|
||||
return [self._fetch_origin_page()]
|
||||
|
||||
def get_pages(self) -> list[dict[str, Any]]:
|
||||
return self.pages
|
||||
|
||||
def _fetch_origin_page(self) -> dict[str, Any]:
|
||||
get_page_by_id = make_confluence_call_handle_rate_limit(
|
||||
self.confluence_client.get_page_by_id
|
||||
)
|
||||
try:
|
||||
origin_page = get_page_by_id(
|
||||
self.origin_page_id, expand="body.storage.value,version,space"
|
||||
)
|
||||
return origin_page
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Appending origin page with id {self.origin_page_id} failed: {e}"
|
||||
)
|
||||
return {}
|
||||
|
||||
def recurse_children_pages(
|
||||
self,
|
||||
page_id: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
pages: list[dict[str, Any]] = []
|
||||
queue: list[str] = [page_id]
|
||||
visited_pages: set[str] = set()
|
||||
|
||||
get_page_child_by_type = make_confluence_call_handle_rate_limit(
|
||||
self.confluence_client.get_page_child_by_type
|
||||
)
|
||||
|
||||
while queue:
|
||||
current_page_id = queue.pop(0)
|
||||
if current_page_id in visited_pages:
|
||||
continue
|
||||
visited_pages.add(current_page_id)
|
||||
|
||||
try:
|
||||
# Fetch the page itself
|
||||
page = self.confluence_client.get_page_by_id(
|
||||
current_page_id, expand="body.storage.value,version,space"
|
||||
)
|
||||
pages.append(page)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch page {current_page_id}: {e}")
|
||||
continue
|
||||
|
||||
if not self.index_recursively:
|
||||
continue
|
||||
|
||||
# Fetch child pages
|
||||
start = 0
|
||||
while True:
|
||||
child_pages_response = get_page_child_by_type(
|
||||
current_page_id,
|
||||
type="page",
|
||||
start=start,
|
||||
limit=self.batch_size,
|
||||
expand="",
|
||||
)
|
||||
if not child_pages_response:
|
||||
break
|
||||
for child_page in child_pages_response:
|
||||
child_page_id = child_page["id"]
|
||||
queue.append(child_page_id)
|
||||
start += len(child_pages_response)
|
||||
|
||||
return pages
|
||||
|
||||
|
||||
class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
def __init__(
|
||||
self,
|
||||
wiki_base: str,
|
||||
is_cloud: bool,
|
||||
space: str = "",
|
||||
page_id: str = "",
|
||||
index_recursively: bool = True,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE,
|
||||
# if a page has one of the labels specified in this list, we will just
|
||||
# skip it. This is generally used to avoid indexing extra sensitive
|
||||
# pages.
|
||||
labels_to_skip: list[str] = CONFLUENCE_CONNECTOR_LABELS_TO_SKIP,
|
||||
cql_query: str | None = None,
|
||||
) -> None:
|
||||
self.batch_size = batch_size
|
||||
self.continue_on_failure = continue_on_failure
|
||||
self.labels_to_skip = set(labels_to_skip)
|
||||
self.recursive_indexer: RecursiveIndexer | None = None
|
||||
self.index_recursively = False if cql_query else index_recursively
|
||||
|
||||
# Remove trailing slash from wiki_base if present
|
||||
self.wiki_base = wiki_base.rstrip("/")
|
||||
self.page_id = "" if cql_query else page_id
|
||||
self.space_level_scan = bool(not self.page_id)
|
||||
|
||||
self.is_cloud = is_cloud
|
||||
|
||||
self.confluence_client: DanswerConfluence | None = None
|
||||
|
||||
# if a cql_query is provided, we will use it to fetch the pages
|
||||
# if no cql_query is provided, we will use the space to fetch the pages
|
||||
# if no space is provided and no cql_query, we will default to fetching all pages, regardless of space
|
||||
if cql_query:
|
||||
self.cql_query = cql_query
|
||||
elif space:
|
||||
self.cql_query = f"type=page and space='{space}'"
|
||||
else:
|
||||
self.cql_query = "type=page"
|
||||
|
||||
logger.info(
|
||||
f"wiki_base: {self.wiki_base}, space: {space}, page_id: {self.page_id},"
|
||||
+ f" space_level_scan: {self.space_level_scan}, index_recursively: {self.index_recursively},"
|
||||
+ f" cql_query: {self.cql_query}"
|
||||
)
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
username = credentials["confluence_username"]
|
||||
access_token = credentials["confluence_access_token"]
|
||||
self.confluence_client = DanswerConfluence(
|
||||
url=self.wiki_base,
|
||||
username=username if self.is_cloud else None,
|
||||
password=access_token if self.is_cloud else None,
|
||||
token=access_token if not self.is_cloud else None,
|
||||
)
|
||||
return None
|
||||
|
||||
def _fetch_pages(
|
||||
self,
|
||||
cursor: str | None,
|
||||
) -> tuple[list[dict[str, Any]], str | None]:
|
||||
if self.confluence_client is None:
|
||||
raise Exception("Confluence client is not initialized")
|
||||
|
||||
def _fetch_space(
|
||||
cursor: str | None, batch_size: int
|
||||
) -> tuple[list[dict[str, Any]], str | None]:
|
||||
if not self.confluence_client:
|
||||
raise Exception("Confluence client is not initialized")
|
||||
get_all_pages = make_confluence_call_handle_rate_limit(
|
||||
self.confluence_client.danswer_cql
|
||||
)
|
||||
|
||||
include_archived_spaces = (
|
||||
CONFLUENCE_CONNECTOR_INDEX_ARCHIVED_PAGES
|
||||
if not self.is_cloud
|
||||
else False
|
||||
)
|
||||
|
||||
try:
|
||||
response = get_all_pages(
|
||||
cql=self.cql_query,
|
||||
cursor=cursor,
|
||||
limit=batch_size,
|
||||
expand="body.storage.value,version,space",
|
||||
include_archived_spaces=include_archived_spaces,
|
||||
)
|
||||
pages = response.get("results", [])
|
||||
next_cursor = None
|
||||
if "_links" in response and "next" in response["_links"]:
|
||||
next_link = response["_links"]["next"]
|
||||
parsed_url = urlparse(next_link)
|
||||
query_params = parse_qs(parsed_url.query)
|
||||
cursor_list = query_params.get("cursor", [])
|
||||
if cursor_list:
|
||||
next_cursor = cursor_list[0]
|
||||
return pages, next_cursor
|
||||
except Exception:
|
||||
logger.warning(
|
||||
f"Batch failed with cql {self.cql_query} with cursor {cursor} "
|
||||
f"and size {batch_size}, processing pages individually..."
|
||||
)
|
||||
|
||||
view_pages: list[dict[str, Any]] = []
|
||||
for _ in range(self.batch_size):
|
||||
try:
|
||||
response = get_all_pages(
|
||||
cql=self.cql_query,
|
||||
cursor=cursor,
|
||||
limit=1,
|
||||
expand="body.view.value,version,space",
|
||||
include_archived_spaces=include_archived_spaces,
|
||||
)
|
||||
pages = response.get("results", [])
|
||||
view_pages.extend(pages)
|
||||
if "_links" in response and "next" in response["_links"]:
|
||||
next_link = response["_links"]["next"]
|
||||
parsed_url = urlparse(next_link)
|
||||
query_params = parse_qs(parsed_url.query)
|
||||
cursor_list = query_params.get("cursor", [])
|
||||
if cursor_list:
|
||||
cursor = cursor_list[0]
|
||||
else:
|
||||
cursor = None
|
||||
else:
|
||||
cursor = None
|
||||
break
|
||||
except HTTPError as e:
|
||||
logger.warning(
|
||||
f"Page failed with cql {self.cql_query} with cursor {cursor}, "
|
||||
f"trying alternative expand option: {e}"
|
||||
)
|
||||
response = get_all_pages(
|
||||
cql=self.cql_query,
|
||||
cursor=cursor,
|
||||
limit=1,
|
||||
expand="body.view.value,version,space",
|
||||
)
|
||||
pages = response.get("results", [])
|
||||
view_pages.extend(pages)
|
||||
if "_links" in response and "next" in response["_links"]:
|
||||
next_link = response["_links"]["next"]
|
||||
parsed_url = urlparse(next_link)
|
||||
query_params = parse_qs(parsed_url.query)
|
||||
cursor_list = query_params.get("cursor", [])
|
||||
if cursor_list:
|
||||
cursor = cursor_list[0]
|
||||
else:
|
||||
cursor = None
|
||||
else:
|
||||
cursor = None
|
||||
break
|
||||
|
||||
return view_pages, cursor
|
||||
|
||||
def _fetch_page() -> tuple[list[dict[str, Any]], str | None]:
|
||||
if self.confluence_client is None:
|
||||
raise Exception("Confluence client is not initialized")
|
||||
|
||||
if self.recursive_indexer is None:
|
||||
self.recursive_indexer = RecursiveIndexer(
|
||||
origin_page_id=self.page_id,
|
||||
batch_size=self.batch_size,
|
||||
confluence_client=self.confluence_client,
|
||||
index_recursively=self.index_recursively,
|
||||
)
|
||||
|
||||
pages = self.recursive_indexer.get_pages()
|
||||
return pages, None # Since we fetched all pages, no cursor
|
||||
|
||||
try:
|
||||
pages, next_cursor = (
|
||||
_fetch_space(cursor, self.batch_size)
|
||||
if self.space_level_scan
|
||||
else _fetch_page()
|
||||
)
|
||||
return pages, next_cursor
|
||||
except Exception as e:
|
||||
if not self.continue_on_failure:
|
||||
raise e
|
||||
|
||||
logger.exception("Ran into exception when fetching pages from Confluence")
|
||||
return [], None
|
||||
|
||||
def _fetch_comments(self, confluence_client: Confluence, page_id: str) -> str:
|
||||
get_page_child_by_type = make_confluence_call_handle_rate_limit(
|
||||
confluence_client.get_page_child_by_type
|
||||
)
|
||||
|
||||
try:
|
||||
comment_pages = list(
|
||||
get_page_child_by_type(
|
||||
page_id,
|
||||
type="comment",
|
||||
start=None,
|
||||
limit=None,
|
||||
expand="body.storage.value",
|
||||
)
|
||||
)
|
||||
return _comment_dfs("", comment_pages, confluence_client)
|
||||
except Exception as e:
|
||||
if not self.continue_on_failure:
|
||||
raise e
|
||||
|
||||
logger.exception(
|
||||
"Ran into exception when fetching comments from Confluence"
|
||||
)
|
||||
return ""
|
||||
|
||||
def _fetch_labels(self, confluence_client: Confluence, page_id: str) -> list[str]:
|
||||
get_page_labels = make_confluence_call_handle_rate_limit(
|
||||
confluence_client.get_page_labels
|
||||
)
|
||||
try:
|
||||
labels_response = get_page_labels(page_id)
|
||||
return [label["name"] for label in labels_response["results"]]
|
||||
except Exception as e:
|
||||
if not self.continue_on_failure:
|
||||
raise e
|
||||
|
||||
logger.exception("Ran into exception when fetching labels from Confluence")
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def _attachment_to_download_link(
|
||||
cls, confluence_client: Confluence, attachment: dict[str, Any]
|
||||
) -> str:
|
||||
return confluence_client.url + attachment["_links"]["download"]
|
||||
|
||||
@classmethod
|
||||
def _attachment_to_content(
|
||||
cls,
|
||||
confluence_client: Confluence,
|
||||
attachment: dict[str, Any],
|
||||
) -> str | None:
|
||||
"""If it returns None, assume that we should skip this attachment."""
|
||||
if attachment["metadata"]["mediaType"] in [
|
||||
"image/jpeg",
|
||||
"image/png",
|
||||
"image/gif",
|
||||
"image/svg+xml",
|
||||
"video/mp4",
|
||||
"video/quicktime",
|
||||
]:
|
||||
return None
|
||||
|
||||
download_link = cls._attachment_to_download_link(confluence_client, attachment)
|
||||
|
||||
attachment_size = attachment["extensions"]["fileSize"]
|
||||
if attachment_size > CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD:
|
||||
logger.warning(
|
||||
f"Skipping {download_link} due to size. "
|
||||
f"size={attachment_size} "
|
||||
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD}"
|
||||
)
|
||||
return None
|
||||
|
||||
response = confluence_client._session.get(download_link)
|
||||
if response.status_code != 200:
|
||||
logger.warning(
|
||||
f"Failed to fetch {download_link} with invalid status code {response.status_code}"
|
||||
)
|
||||
return None
|
||||
|
||||
extracted_text = extract_file_text(
|
||||
io.BytesIO(response.content),
|
||||
file_name=attachment["title"],
|
||||
break_on_unprocessable=False,
|
||||
)
|
||||
if len(extracted_text) > CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD:
|
||||
logger.warning(
|
||||
f"Skipping {download_link} due to char count. "
|
||||
f"char count={len(extracted_text)} "
|
||||
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD}"
|
||||
)
|
||||
return None
|
||||
|
||||
return extracted_text
|
||||
|
||||
def _fetch_attachments(
|
||||
self, confluence_client: Confluence, page_id: str, files_in_use: list[str]
|
||||
) -> tuple[str, list[dict[str, Any]]]:
|
||||
unused_attachments: list[dict[str, Any]] = []
|
||||
files_attachment_content: list[str] = []
|
||||
|
||||
get_attachments_from_content = make_confluence_call_handle_rate_limit(
|
||||
confluence_client.get_attachments_from_content
|
||||
)
|
||||
|
||||
try:
|
||||
expand = "history.lastUpdated,metadata.labels"
|
||||
attachments_container = get_attachments_from_content(
|
||||
page_id, start=None, limit=None, expand=expand
|
||||
)
|
||||
for attachment in attachments_container.get("results", []):
|
||||
if attachment["title"] not in files_in_use:
|
||||
unused_attachments.append(attachment)
|
||||
continue
|
||||
|
||||
attachment_content = self._attachment_to_content(
|
||||
confluence_client, attachment
|
||||
)
|
||||
if attachment_content:
|
||||
files_attachment_content.append(attachment_content)
|
||||
|
||||
except Exception as e:
|
||||
if isinstance(
|
||||
e, HTTPError
|
||||
) and NO_PERMISSIONS_TO_VIEW_ATTACHMENTS_ERROR_STR in str(e):
|
||||
logger.warning(
|
||||
f"User does not have access to attachments on page '{page_id}'"
|
||||
)
|
||||
return "", []
|
||||
if not self.continue_on_failure:
|
||||
raise e
|
||||
logger.exception(
|
||||
f"Ran into exception when fetching attachments from Confluence: {e}"
|
||||
)
|
||||
|
||||
return "\n".join(files_attachment_content), unused_attachments
|
||||
|
||||
def _get_doc_batch(
|
||||
self, cursor: str | None
|
||||
) -> tuple[list[Any], str | None, list[dict[str, Any]]]:
|
||||
if self.confluence_client is None:
|
||||
raise Exception("Confluence client is not initialized")
|
||||
|
||||
doc_batch: list[Any] = []
|
||||
unused_attachments: list[dict[str, Any]] = []
|
||||
|
||||
batch, next_cursor = self._fetch_pages(cursor)
|
||||
|
||||
for page in batch:
|
||||
last_modified = _datetime_from_string(page["version"]["when"])
|
||||
author = page["version"].get("by", {}).get("email")
|
||||
|
||||
page_id = page["id"]
|
||||
|
||||
if self.labels_to_skip or not CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING:
|
||||
page_labels = self._fetch_labels(self.confluence_client, page_id)
|
||||
else:
|
||||
page_labels = []
|
||||
|
||||
# check disallowed labels
|
||||
if self.labels_to_skip:
|
||||
label_intersection = self.labels_to_skip.intersection(page_labels)
|
||||
if label_intersection:
|
||||
logger.info(
|
||||
f"Page with ID '{page_id}' has a label which has been "
|
||||
f"designated as disallowed: {label_intersection}. Skipping."
|
||||
)
|
||||
continue
|
||||
|
||||
page_html = (
|
||||
page["body"].get("storage", page["body"].get("view", {})).get("value")
|
||||
)
|
||||
# The url and the id are the same
|
||||
page_url = build_confluence_document_id(
|
||||
self.wiki_base, page["_links"]["webui"]
|
||||
)
|
||||
if not page_html:
|
||||
logger.debug("Page is empty, skipping: %s", page_url)
|
||||
continue
|
||||
page_text = parse_html_page(page_html, self.confluence_client)
|
||||
|
||||
files_in_use = get_used_attachments(page_html)
|
||||
attachment_text, unused_page_attachments = self._fetch_attachments(
|
||||
self.confluence_client, page_id, files_in_use
|
||||
)
|
||||
unused_attachments.extend(unused_page_attachments)
|
||||
|
||||
page_text += "\n" + attachment_text if attachment_text else ""
|
||||
comments_text = self._fetch_comments(self.confluence_client, page_id)
|
||||
page_text += comments_text
|
||||
doc_metadata: dict[str, str | list[str]] = {
|
||||
"Wiki Space Name": page["space"]["name"]
|
||||
}
|
||||
if not CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING and page_labels:
|
||||
doc_metadata["labels"] = page_labels
|
||||
|
||||
doc_batch.append(
|
||||
Document(
|
||||
id=page_url,
|
||||
sections=[Section(link=page_url, text=page_text)],
|
||||
source=DocumentSource.CONFLUENCE,
|
||||
semantic_identifier=page["title"],
|
||||
doc_updated_at=last_modified,
|
||||
primary_owners=(
|
||||
[BasicExpertInfo(email=author)] if author else None
|
||||
),
|
||||
metadata=doc_metadata,
|
||||
)
|
||||
)
|
||||
return (
|
||||
doc_batch,
|
||||
next_cursor,
|
||||
unused_attachments,
|
||||
)
|
||||
|
||||
def _get_attachment_batch(
|
||||
self,
|
||||
start_ind: int,
|
||||
attachments: list[dict[str, Any]],
|
||||
time_filter: Callable[[datetime], bool] | None = None,
|
||||
) -> tuple[list[Any], int]:
|
||||
doc_batch: list[Any] = []
|
||||
|
||||
if self.confluence_client is None:
|
||||
raise ConnectorMissingCredentialError("Confluence")
|
||||
|
||||
end_ind = min(start_ind + self.batch_size, len(attachments))
|
||||
|
||||
for attachment in attachments[start_ind:end_ind]:
|
||||
last_updated = _datetime_from_string(
|
||||
attachment["history"]["lastUpdated"]["when"]
|
||||
)
|
||||
|
||||
if time_filter and not time_filter(last_updated):
|
||||
continue
|
||||
|
||||
# The url and the id are the same
|
||||
attachment_url = build_confluence_document_id(
|
||||
self.wiki_base, attachment["_links"]["download"]
|
||||
)
|
||||
attachment_content = self._attachment_to_content(
|
||||
self.confluence_client, attachment
|
||||
)
|
||||
if attachment_content is None:
|
||||
continue
|
||||
|
||||
creator_email = attachment["history"]["createdBy"].get("email")
|
||||
|
||||
comment = attachment["metadata"].get("comment", "")
|
||||
doc_metadata: dict[str, Any] = {"comment": comment}
|
||||
|
||||
attachment_labels: list[str] = []
|
||||
if not CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING:
|
||||
for label in attachment["metadata"]["labels"]["results"]:
|
||||
attachment_labels.append(label["name"])
|
||||
|
||||
doc_metadata["labels"] = attachment_labels
|
||||
|
||||
doc_batch.append(
|
||||
Document(
|
||||
id=attachment_url,
|
||||
sections=[Section(link=attachment_url, text=attachment_content)],
|
||||
source=DocumentSource.CONFLUENCE,
|
||||
semantic_identifier=attachment["title"],
|
||||
doc_updated_at=last_updated,
|
||||
primary_owners=(
|
||||
[BasicExpertInfo(email=creator_email)]
|
||||
if creator_email
|
||||
else None
|
||||
),
|
||||
metadata=doc_metadata,
|
||||
)
|
||||
)
|
||||
|
||||
return doc_batch, end_ind - start_ind
|
||||
|
||||
def _handle_batch_retrieval(
|
||||
self,
|
||||
start: float | None = None,
|
||||
end: float | None = None,
|
||||
) -> GenerateDocumentsOutput:
|
||||
start_time = datetime.fromtimestamp(start, tz=timezone.utc) if start else None
|
||||
end_time = datetime.fromtimestamp(end, tz=timezone.utc) if end else None
|
||||
|
||||
unused_attachments: list[dict[str, Any]] = []
|
||||
cursor = None
|
||||
while True:
|
||||
doc_batch, cursor, new_unused_attachments = self._get_doc_batch(cursor)
|
||||
unused_attachments.extend(new_unused_attachments)
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
|
||||
if not cursor:
|
||||
break
|
||||
|
||||
# Process attachments if any
|
||||
start_ind = 0
|
||||
while True:
|
||||
attachment_batch, num_attachments = self._get_attachment_batch(
|
||||
start_ind=start_ind,
|
||||
attachments=unused_attachments,
|
||||
time_filter=(lambda t: start_time <= t <= end_time)
|
||||
if start_time and end_time
|
||||
else None,
|
||||
)
|
||||
|
||||
start_ind += num_attachments
|
||||
if attachment_batch:
|
||||
yield attachment_batch
|
||||
|
||||
if num_attachments < self.batch_size:
|
||||
break
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
return self._handle_batch_retrieval()
|
||||
|
||||
def poll_source(self, start: float, end: float) -> GenerateDocumentsOutput:
|
||||
return self._handle_batch_retrieval(start=start, end=end)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
connector = ConfluenceConnector(
|
||||
wiki_base=os.environ["CONFLUENCE_TEST_SPACE_URL"],
|
||||
space=os.environ["CONFLUENCE_TEST_SPACE"],
|
||||
is_cloud=os.environ.get("CONFLUENCE_IS_CLOUD", "true").lower() == "true",
|
||||
page_id=os.environ.get("CONFLUENCE_TEST_PAGE_ID", ""),
|
||||
index_recursively=True,
|
||||
)
|
||||
connector.load_credentials(
|
||||
{
|
||||
"confluence_username": os.environ["CONFLUENCE_USER_NAME"],
|
||||
"confluence_access_token": os.environ["CONFLUENCE_ACCESS_TOKEN"],
|
||||
}
|
||||
)
|
||||
document_batches = connector.load_from_state()
|
||||
print(next(document_batches))
|
||||
@@ -1,134 +0,0 @@
|
||||
import math
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import TypeVar
|
||||
|
||||
from redis.exceptions import ConnectionError
|
||||
from requests import HTTPError
|
||||
|
||||
from danswer.connectors.interfaces import BaseConnector
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
|
||||
RATE_LIMIT_MESSAGE_LOWERCASE = "Rate limit exceeded".lower()
|
||||
|
||||
|
||||
class ConfluenceRateLimitError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
# https://developer.atlassian.com/cloud/confluence/rate-limiting/
|
||||
def make_confluence_call_handle_rate_limit(confluence_call: F) -> F:
|
||||
def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
|
||||
max_retries = 5
|
||||
starting_delay = 5
|
||||
backoff = 2
|
||||
|
||||
# max_delay is used when the server doesn't hand back "Retry-After"
|
||||
# and we have to decide the retry delay ourselves
|
||||
max_delay = 30 # Atlassian uses max_delay = 30 in their examples
|
||||
|
||||
# max_retry_after is used when we do get a "Retry-After" header
|
||||
max_retry_after = 300 # should we really cap the maximum retry delay?
|
||||
|
||||
NEXT_RETRY_KEY = BaseConnector.REDIS_KEY_PREFIX + "confluence_next_retry"
|
||||
|
||||
# for testing purposes, rate limiting is written to fall back to a simpler
|
||||
# rate limiting approach when redis is not available
|
||||
r = get_redis_client()
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
# if multiple connectors are waiting for the next attempt, there could be an issue
|
||||
# where many connectors are "released" onto the server at the same time.
|
||||
# That's not ideal ... but coming up with a mechanism for queueing
|
||||
# all of these connectors is a bigger problem that we want to take on
|
||||
# right now
|
||||
try:
|
||||
next_attempt = r.get(NEXT_RETRY_KEY)
|
||||
if next_attempt is None:
|
||||
next_attempt = 0
|
||||
else:
|
||||
next_attempt = int(cast(int, next_attempt))
|
||||
|
||||
# TODO: all connectors need to be interruptible moving forward
|
||||
while time.monotonic() < next_attempt:
|
||||
time.sleep(1)
|
||||
except ConnectionError:
|
||||
pass
|
||||
|
||||
return confluence_call(*args, **kwargs)
|
||||
except HTTPError as e:
|
||||
# Check if the response or headers are None to avoid potential AttributeError
|
||||
if e.response is None or e.response.headers is None:
|
||||
logger.warning("HTTPError with `None` as response or as headers")
|
||||
raise e
|
||||
|
||||
retry_after_header = e.response.headers.get("Retry-After")
|
||||
if (
|
||||
e.response.status_code == 429
|
||||
or RATE_LIMIT_MESSAGE_LOWERCASE in e.response.text.lower()
|
||||
):
|
||||
retry_after = None
|
||||
if retry_after_header is not None:
|
||||
try:
|
||||
retry_after = int(retry_after_header)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
if retry_after is not None:
|
||||
if retry_after > max_retry_after:
|
||||
logger.warning(
|
||||
f"Clamping retry_after from {retry_after} to {max_delay} seconds..."
|
||||
)
|
||||
retry_after = max_delay
|
||||
|
||||
logger.warning(
|
||||
f"Rate limit hit. Retrying after {retry_after} seconds..."
|
||||
)
|
||||
try:
|
||||
r.set(
|
||||
NEXT_RETRY_KEY,
|
||||
math.ceil(time.monotonic() + retry_after),
|
||||
)
|
||||
except ConnectionError:
|
||||
pass
|
||||
else:
|
||||
logger.warning(
|
||||
"Rate limit hit. Retrying with exponential backoff..."
|
||||
)
|
||||
delay = min(starting_delay * (backoff**attempt), max_delay)
|
||||
delay_until = math.ceil(time.monotonic() + delay)
|
||||
|
||||
try:
|
||||
r.set(NEXT_RETRY_KEY, delay_until)
|
||||
except ConnectionError:
|
||||
while time.monotonic() < delay_until:
|
||||
time.sleep(1)
|
||||
else:
|
||||
# re-raise, let caller handle
|
||||
raise
|
||||
except AttributeError as e:
|
||||
# Some error within the Confluence library, unclear why it fails.
|
||||
# Users reported it to be intermittent, so just retry
|
||||
logger.warning(f"Confluence Internal Error, retrying... {e}")
|
||||
delay = min(starting_delay * (backoff**attempt), max_delay)
|
||||
delay_until = math.ceil(time.monotonic() + delay)
|
||||
try:
|
||||
r.set(NEXT_RETRY_KEY, delay_until)
|
||||
except ConnectionError:
|
||||
while time.monotonic() < delay_until:
|
||||
time.sleep(1)
|
||||
|
||||
if attempt == max_retries - 1:
|
||||
raise e
|
||||
|
||||
return cast(F, wrapped_call)
|
||||
@@ -1,321 +0,0 @@
|
||||
import os
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from jira import JIRA
|
||||
from jira.resources import Issue
|
||||
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.app_configs import JIRA_CONNECTOR_LABELS_TO_SKIP
|
||||
from danswer.configs.app_configs import JIRA_CONNECTOR_MAX_TICKET_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from danswer.connectors.models import BasicExpertInfo
|
||||
from danswer.connectors.models import ConnectorMissingCredentialError
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
PROJECT_URL_PAT = "projects"
|
||||
JIRA_API_VERSION = os.environ.get("JIRA_API_VERSION") or "2"
|
||||
|
||||
|
||||
def extract_jira_project(url: str) -> tuple[str, str]:
|
||||
parsed_url = urlparse(url)
|
||||
jira_base = parsed_url.scheme + "://" + parsed_url.netloc
|
||||
|
||||
# Split the path by '/' and find the position of 'projects' to get the project name
|
||||
split_path = parsed_url.path.split("/")
|
||||
if PROJECT_URL_PAT in split_path:
|
||||
project_pos = split_path.index(PROJECT_URL_PAT)
|
||||
if len(split_path) > project_pos + 1:
|
||||
jira_project = split_path[project_pos + 1]
|
||||
else:
|
||||
raise ValueError("No project name found in the URL")
|
||||
else:
|
||||
raise ValueError("'projects' not found in the URL")
|
||||
|
||||
return jira_base, jira_project
|
||||
|
||||
|
||||
def extract_text_from_adf(adf: dict | None) -> str:
|
||||
"""Extracts plain text from Atlassian Document Format:
|
||||
https://developer.atlassian.com/cloud/jira/platform/apis/document/structure/
|
||||
|
||||
WARNING: This function is incomplete and will e.g. skip lists!
|
||||
"""
|
||||
texts = []
|
||||
if adf is not None and "content" in adf:
|
||||
for block in adf["content"]:
|
||||
if "content" in block:
|
||||
for item in block["content"]:
|
||||
if item["type"] == "text":
|
||||
texts.append(item["text"])
|
||||
return " ".join(texts)
|
||||
|
||||
|
||||
def best_effort_get_field_from_issue(jira_issue: Issue, field: str) -> Any:
|
||||
if hasattr(jira_issue.fields, field):
|
||||
return getattr(jira_issue.fields, field)
|
||||
|
||||
try:
|
||||
return jira_issue.raw["fields"][field]
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _get_comment_strs(
|
||||
jira: Issue, comment_email_blacklist: tuple[str, ...] = ()
|
||||
) -> list[str]:
|
||||
comment_strs = []
|
||||
for comment in jira.fields.comment.comments:
|
||||
try:
|
||||
body_text = (
|
||||
comment.body
|
||||
if JIRA_API_VERSION == "2"
|
||||
else extract_text_from_adf(comment.raw["body"])
|
||||
)
|
||||
|
||||
if (
|
||||
hasattr(comment, "author")
|
||||
and hasattr(comment.author, "emailAddress")
|
||||
and comment.author.emailAddress in comment_email_blacklist
|
||||
):
|
||||
continue # Skip adding comment if author's email is in blacklist
|
||||
|
||||
comment_strs.append(body_text)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process comment due to an error: {e}")
|
||||
continue
|
||||
|
||||
return comment_strs
|
||||
|
||||
|
||||
def fetch_jira_issues_batch(
|
||||
jql: str,
|
||||
start_index: int,
|
||||
jira_client: JIRA,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
comment_email_blacklist: tuple[str, ...] = (),
|
||||
labels_to_skip: set[str] | None = None,
|
||||
) -> tuple[list[Document], int]:
|
||||
doc_batch = []
|
||||
|
||||
batch = jira_client.search_issues(
|
||||
jql,
|
||||
startAt=start_index,
|
||||
maxResults=batch_size,
|
||||
)
|
||||
|
||||
for jira in batch:
|
||||
if type(jira) != Issue:
|
||||
logger.warning(f"Found Jira object not of type Issue {jira}")
|
||||
continue
|
||||
|
||||
if labels_to_skip and any(
|
||||
label in jira.fields.labels for label in labels_to_skip
|
||||
):
|
||||
logger.info(
|
||||
f"Skipping {jira.key} because it has a label to skip. Found "
|
||||
f"labels: {jira.fields.labels}. Labels to skip: {labels_to_skip}."
|
||||
)
|
||||
continue
|
||||
|
||||
description = (
|
||||
jira.fields.description
|
||||
if JIRA_API_VERSION == "2"
|
||||
else extract_text_from_adf(jira.raw["fields"]["description"])
|
||||
)
|
||||
comments = _get_comment_strs(jira, comment_email_blacklist)
|
||||
ticket_content = f"{description}\n" + "\n".join(
|
||||
[f"Comment: {comment}" for comment in comments if comment]
|
||||
)
|
||||
|
||||
# Check ticket size
|
||||
if len(ticket_content.encode("utf-8")) > JIRA_CONNECTOR_MAX_TICKET_SIZE:
|
||||
logger.info(
|
||||
f"Skipping {jira.key} because it exceeds the maximum size of "
|
||||
f"{JIRA_CONNECTOR_MAX_TICKET_SIZE} bytes."
|
||||
)
|
||||
continue
|
||||
|
||||
page_url = f"{jira_client.client_info()}/browse/{jira.key}"
|
||||
|
||||
people = set()
|
||||
try:
|
||||
people.add(
|
||||
BasicExpertInfo(
|
||||
display_name=jira.fields.creator.displayName,
|
||||
email=jira.fields.creator.emailAddress,
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
# Author should exist but if not, doesn't matter
|
||||
pass
|
||||
|
||||
try:
|
||||
people.add(
|
||||
BasicExpertInfo(
|
||||
display_name=jira.fields.assignee.displayName, # type: ignore
|
||||
email=jira.fields.assignee.emailAddress, # type: ignore
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
# Author should exist but if not, doesn't matter
|
||||
pass
|
||||
|
||||
metadata_dict = {}
|
||||
priority = best_effort_get_field_from_issue(jira, "priority")
|
||||
if priority:
|
||||
metadata_dict["priority"] = priority.name
|
||||
status = best_effort_get_field_from_issue(jira, "status")
|
||||
if status:
|
||||
metadata_dict["status"] = status.name
|
||||
resolution = best_effort_get_field_from_issue(jira, "resolution")
|
||||
if resolution:
|
||||
metadata_dict["resolution"] = resolution.name
|
||||
labels = best_effort_get_field_from_issue(jira, "labels")
|
||||
if labels:
|
||||
metadata_dict["label"] = labels
|
||||
|
||||
doc_batch.append(
|
||||
Document(
|
||||
id=page_url,
|
||||
sections=[Section(link=page_url, text=ticket_content)],
|
||||
source=DocumentSource.JIRA,
|
||||
semantic_identifier=jira.fields.summary,
|
||||
doc_updated_at=time_str_to_utc(jira.fields.updated),
|
||||
primary_owners=list(people) or None,
|
||||
# TODO add secondary_owners (commenters) if needed
|
||||
metadata=metadata_dict,
|
||||
)
|
||||
)
|
||||
return doc_batch, len(batch)
|
||||
|
||||
|
||||
class JiraConnector(LoadConnector, PollConnector):
|
||||
def __init__(
|
||||
self,
|
||||
jira_project_url: str,
|
||||
comment_email_blacklist: list[str] | None = None,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
# if a ticket has one of the labels specified in this list, we will just
|
||||
# skip it. This is generally used to avoid indexing extra sensitive
|
||||
# tickets.
|
||||
labels_to_skip: list[str] = JIRA_CONNECTOR_LABELS_TO_SKIP,
|
||||
) -> None:
|
||||
self.batch_size = batch_size
|
||||
self.jira_base, self.jira_project = extract_jira_project(jira_project_url)
|
||||
self.jira_client: JIRA | None = None
|
||||
self._comment_email_blacklist = comment_email_blacklist or []
|
||||
|
||||
self.labels_to_skip = set(labels_to_skip)
|
||||
|
||||
@property
|
||||
def comment_email_blacklist(self) -> tuple:
|
||||
return tuple(email.strip() for email in self._comment_email_blacklist)
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
api_token = credentials["jira_api_token"]
|
||||
# if user provide an email we assume it's cloud
|
||||
if "jira_user_email" in credentials:
|
||||
email = credentials["jira_user_email"]
|
||||
self.jira_client = JIRA(
|
||||
basic_auth=(email, api_token),
|
||||
server=self.jira_base,
|
||||
options={"rest_api_version": JIRA_API_VERSION},
|
||||
)
|
||||
else:
|
||||
self.jira_client = JIRA(
|
||||
token_auth=api_token,
|
||||
server=self.jira_base,
|
||||
options={"rest_api_version": JIRA_API_VERSION},
|
||||
)
|
||||
return None
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
if self.jira_client is None:
|
||||
raise ConnectorMissingCredentialError("Jira")
|
||||
|
||||
# Quote the project name to handle reserved words
|
||||
quoted_project = f'"{self.jira_project}"'
|
||||
start_ind = 0
|
||||
while True:
|
||||
doc_batch, fetched_batch_size = fetch_jira_issues_batch(
|
||||
jql=f"project = {quoted_project}",
|
||||
start_index=start_ind,
|
||||
jira_client=self.jira_client,
|
||||
batch_size=self.batch_size,
|
||||
comment_email_blacklist=self.comment_email_blacklist,
|
||||
labels_to_skip=self.labels_to_skip,
|
||||
)
|
||||
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
|
||||
start_ind += fetched_batch_size
|
||||
if fetched_batch_size < self.batch_size:
|
||||
break
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
if self.jira_client is None:
|
||||
raise ConnectorMissingCredentialError("Jira")
|
||||
|
||||
start_date_str = datetime.fromtimestamp(start, tz=timezone.utc).strftime(
|
||||
"%Y-%m-%d %H:%M"
|
||||
)
|
||||
end_date_str = datetime.fromtimestamp(end, tz=timezone.utc).strftime(
|
||||
"%Y-%m-%d %H:%M"
|
||||
)
|
||||
|
||||
# Quote the project name to handle reserved words
|
||||
quoted_project = f'"{self.jira_project}"'
|
||||
jql = (
|
||||
f"project = {quoted_project} AND "
|
||||
f"updated >= '{start_date_str}' AND "
|
||||
f"updated <= '{end_date_str}'"
|
||||
)
|
||||
|
||||
start_ind = 0
|
||||
while True:
|
||||
doc_batch, fetched_batch_size = fetch_jira_issues_batch(
|
||||
jql=jql,
|
||||
start_index=start_ind,
|
||||
jira_client=self.jira_client,
|
||||
batch_size=self.batch_size,
|
||||
comment_email_blacklist=self.comment_email_blacklist,
|
||||
labels_to_skip=self.labels_to_skip,
|
||||
)
|
||||
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
|
||||
start_ind += fetched_batch_size
|
||||
if fetched_batch_size < self.batch_size:
|
||||
break
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
|
||||
connector = JiraConnector(
|
||||
os.environ["JIRA_PROJECT_URL"], comment_email_blacklist=[]
|
||||
)
|
||||
connector.load_credentials(
|
||||
{
|
||||
"jira_user_email": os.environ["JIRA_USER_EMAIL"],
|
||||
"jira_api_token": os.environ["JIRA_API_TOKEN"],
|
||||
}
|
||||
)
|
||||
document_batches = connector.load_from_state()
|
||||
print(next(document_batches))
|
||||
@@ -1,92 +0,0 @@
|
||||
"""Module with custom fields processing functions"""
|
||||
from typing import Any
|
||||
from typing import List
|
||||
|
||||
from jira import JIRA
|
||||
from jira.resources import CustomFieldOption
|
||||
from jira.resources import Issue
|
||||
from jira.resources import User
|
||||
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class CustomFieldExtractor:
|
||||
@staticmethod
|
||||
def _process_custom_field_value(value: Any) -> str:
|
||||
"""
|
||||
Process a custom field value to a string
|
||||
"""
|
||||
try:
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
elif isinstance(value, CustomFieldOption):
|
||||
return value.value
|
||||
elif isinstance(value, User):
|
||||
return value.displayName
|
||||
elif isinstance(value, List):
|
||||
return " ".join(
|
||||
[CustomFieldExtractor._process_custom_field_value(v) for v in value]
|
||||
)
|
||||
else:
|
||||
return str(value)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing custom field value {value}: {e}")
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def get_issue_custom_fields(
|
||||
jira: Issue, custom_fields: dict, max_value_length: int = 250
|
||||
) -> dict:
|
||||
"""
|
||||
Process all custom fields of an issue to a dictionary of strings
|
||||
:param jira: jira_issue, bug or similar
|
||||
:param custom_fields: custom fields dictionary
|
||||
:param max_value_length: maximum length of the value to be processed, if exceeded, it will be truncated
|
||||
"""
|
||||
|
||||
issue_custom_fields = {
|
||||
custom_fields[key]: value
|
||||
for key, value in jira.fields.__dict__.items()
|
||||
if value and key in custom_fields.keys()
|
||||
}
|
||||
|
||||
processed_fields = {}
|
||||
|
||||
if issue_custom_fields:
|
||||
for key, value in issue_custom_fields.items():
|
||||
processed = CustomFieldExtractor._process_custom_field_value(value)
|
||||
# We need max length parameter, because there are some plugins that often has very long description
|
||||
# and there is just a technical information so we just avoid long values
|
||||
if len(processed) < max_value_length:
|
||||
processed_fields[key] = processed
|
||||
|
||||
return processed_fields
|
||||
|
||||
@staticmethod
|
||||
def get_all_custom_fields(jira_client: JIRA) -> dict:
|
||||
"""Get all custom fields from Jira"""
|
||||
fields = jira_client.fields()
|
||||
fields_dct = {
|
||||
field["id"]: field["name"] for field in fields if field["custom"] is True
|
||||
}
|
||||
return fields_dct
|
||||
|
||||
|
||||
class CommonFieldExtractor:
|
||||
@staticmethod
|
||||
def get_issue_common_fields(jira: Issue) -> dict:
|
||||
return {
|
||||
"Priority": jira.fields.priority.name if jira.fields.priority else None,
|
||||
"Reporter": jira.fields.reporter.displayName
|
||||
if jira.fields.reporter
|
||||
else None,
|
||||
"Assignee": jira.fields.assignee.displayName
|
||||
if jira.fields.assignee
|
||||
else None,
|
||||
"Status": jira.fields.status.name if jira.fields.status else None,
|
||||
"Resolution": jira.fields.resolution.name
|
||||
if jira.fields.resolution
|
||||
else None,
|
||||
}
|
||||
@@ -1,221 +0,0 @@
|
||||
from base64 import urlsafe_b64decode
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Dict
|
||||
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
|
||||
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
|
||||
from googleapiclient import discovery # type: ignore
|
||||
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
|
||||
from danswer.connectors.gmail.connector_auth import (
|
||||
get_gmail_creds_for_authorized_user,
|
||||
)
|
||||
from danswer.connectors.gmail.connector_auth import (
|
||||
get_gmail_creds_for_service_account,
|
||||
)
|
||||
from danswer.connectors.gmail.constants import (
|
||||
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY,
|
||||
)
|
||||
from danswer.connectors.gmail.constants import DB_CREDENTIALS_DICT_TOKEN_KEY
|
||||
from danswer.connectors.gmail.constants import (
|
||||
GMAIL_DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
|
||||
)
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class GmailConnector(LoadConnector, PollConnector):
|
||||
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
|
||||
self.batch_size = batch_size
|
||||
self.creds: OAuthCredentials | ServiceAccountCredentials | None = None
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, str] | None:
|
||||
"""Checks for two different types of credentials.
|
||||
(1) A credential which holds a token acquired via a user going thorugh
|
||||
the Google OAuth flow.
|
||||
(2) A credential which holds a service account key JSON file, which
|
||||
can then be used to impersonate any user in the workspace.
|
||||
"""
|
||||
creds: OAuthCredentials | ServiceAccountCredentials | None = None
|
||||
new_creds_dict = None
|
||||
if DB_CREDENTIALS_DICT_TOKEN_KEY in credentials:
|
||||
access_token_json_str = cast(
|
||||
str, credentials[DB_CREDENTIALS_DICT_TOKEN_KEY]
|
||||
)
|
||||
creds = get_gmail_creds_for_authorized_user(
|
||||
token_json_str=access_token_json_str
|
||||
)
|
||||
|
||||
# tell caller to update token stored in DB if it has changed
|
||||
# (e.g. the token has been refreshed)
|
||||
new_creds_json_str = creds.to_json() if creds else ""
|
||||
if new_creds_json_str != access_token_json_str:
|
||||
new_creds_dict = {DB_CREDENTIALS_DICT_TOKEN_KEY: new_creds_json_str}
|
||||
|
||||
if GMAIL_DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY in credentials:
|
||||
service_account_key_json_str = credentials[
|
||||
GMAIL_DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY
|
||||
]
|
||||
creds = get_gmail_creds_for_service_account(
|
||||
service_account_key_json_str=service_account_key_json_str
|
||||
)
|
||||
|
||||
# "Impersonate" a user if one is specified
|
||||
delegated_user_email = cast(
|
||||
str | None, credentials.get(DB_CREDENTIALS_DICT_DELEGATED_USER_KEY)
|
||||
)
|
||||
if delegated_user_email:
|
||||
creds = creds.with_subject(delegated_user_email) if creds else None # type: ignore
|
||||
|
||||
if creds is None:
|
||||
raise PermissionError(
|
||||
"Unable to access Gmail - unknown credential structure."
|
||||
)
|
||||
|
||||
self.creds = creds
|
||||
return new_creds_dict
|
||||
|
||||
def _get_email_body(self, payload: dict[str, Any]) -> str:
|
||||
parts = payload.get("parts", [])
|
||||
email_body = ""
|
||||
for part in parts:
|
||||
mime_type = part.get("mimeType")
|
||||
body = part.get("body")
|
||||
if mime_type == "text/plain":
|
||||
data = body.get("data", "")
|
||||
text = urlsafe_b64decode(data).decode()
|
||||
email_body += text
|
||||
return email_body
|
||||
|
||||
def _email_to_document(self, full_email: Dict[str, Any]) -> Document:
|
||||
email_id = full_email["id"]
|
||||
payload = full_email["payload"]
|
||||
headers = payload.get("headers")
|
||||
labels = full_email.get("labelIds", [])
|
||||
metadata = {}
|
||||
if headers:
|
||||
for header in headers:
|
||||
name = header.get("name").lower()
|
||||
value = header.get("value")
|
||||
if name in ["from", "to", "subject", "date", "cc", "bcc"]:
|
||||
metadata[name] = value
|
||||
email_data = ""
|
||||
for name, value in metadata.items():
|
||||
email_data += f"{name}: {value}\n"
|
||||
metadata["labels"] = labels
|
||||
logger.debug(f"{email_data}")
|
||||
email_body_text: str = self._get_email_body(payload)
|
||||
date_str = metadata.get("date")
|
||||
email_updated_at = time_str_to_utc(date_str) if date_str else None
|
||||
link = f"https://mail.google.com/mail/u/0/#inbox/{email_id}"
|
||||
return Document(
|
||||
id=email_id,
|
||||
sections=[Section(link=link, text=email_data + email_body_text)],
|
||||
source=DocumentSource.GMAIL,
|
||||
title=metadata.get("subject"),
|
||||
semantic_identifier=metadata.get("subject", "Untitled Email"),
|
||||
doc_updated_at=email_updated_at,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _build_time_range_query(
|
||||
time_range_start: SecondsSinceUnixEpoch | None = None,
|
||||
time_range_end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> str | None:
|
||||
query = ""
|
||||
if time_range_start is not None and time_range_start != 0:
|
||||
query += f"after:{int(time_range_start)}"
|
||||
if time_range_end is not None and time_range_end != 0:
|
||||
query += f" before:{int(time_range_end)}"
|
||||
query = query.strip()
|
||||
|
||||
if len(query) == 0:
|
||||
return None
|
||||
|
||||
return query
|
||||
|
||||
def _fetch_mails_from_gmail(
|
||||
self,
|
||||
time_range_start: SecondsSinceUnixEpoch | None = None,
|
||||
time_range_end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> GenerateDocumentsOutput:
|
||||
if self.creds is None:
|
||||
raise PermissionError("Not logged into Gmail")
|
||||
page_token = ""
|
||||
query = GmailConnector._build_time_range_query(time_range_start, time_range_end)
|
||||
service = discovery.build("gmail", "v1", credentials=self.creds)
|
||||
while page_token is not None:
|
||||
result = (
|
||||
service.users()
|
||||
.messages()
|
||||
.list(
|
||||
userId="me",
|
||||
pageToken=page_token,
|
||||
q=query,
|
||||
maxResults=self.batch_size,
|
||||
)
|
||||
.execute()
|
||||
)
|
||||
page_token = result.get("nextPageToken")
|
||||
messages = result.get("messages", [])
|
||||
doc_batch = []
|
||||
for message in messages:
|
||||
message_id = message["id"]
|
||||
msg = (
|
||||
service.users()
|
||||
.messages()
|
||||
.get(userId="me", id=message_id, format="full")
|
||||
.execute()
|
||||
)
|
||||
doc = self._email_to_document(msg)
|
||||
doc_batch.append(doc)
|
||||
if len(doc_batch) > 0:
|
||||
yield doc_batch
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
yield from self._fetch_mails_from_gmail()
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
yield from self._fetch_mails_from_gmail(start, end)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import json
|
||||
import os
|
||||
|
||||
service_account_json_path = os.environ.get("GOOGLE_SERVICE_ACCOUNT_KEY_JSON_PATH")
|
||||
if not service_account_json_path:
|
||||
raise ValueError(
|
||||
"Please set GOOGLE_SERVICE_ACCOUNT_KEY_JSON_PATH environment variable"
|
||||
)
|
||||
with open(service_account_json_path) as f:
|
||||
creds = json.load(f)
|
||||
|
||||
credentials_dict = {
|
||||
DB_CREDENTIALS_DICT_TOKEN_KEY: json.dumps(creds),
|
||||
}
|
||||
delegated_user = os.environ.get("GMAIL_DELEGATED_USER")
|
||||
if delegated_user:
|
||||
credentials_dict[DB_CREDENTIALS_DICT_DELEGATED_USER_KEY] = delegated_user
|
||||
|
||||
connector = GmailConnector()
|
||||
connector.load_credentials(
|
||||
json.loads(credentials_dict[DB_CREDENTIALS_DICT_TOKEN_KEY])
|
||||
)
|
||||
document_batch_generator = connector.load_from_state()
|
||||
for document_batch in document_batch_generator:
|
||||
print(document_batch)
|
||||
break
|
||||
@@ -1,197 +0,0 @@
|
||||
import json
|
||||
from typing import cast
|
||||
from urllib.parse import parse_qs
|
||||
from urllib.parse import ParseResult
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from google.auth.transport.requests import Request # type: ignore
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
|
||||
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
|
||||
from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.app_configs import WEB_DOMAIN
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import KV_CRED_KEY
|
||||
from danswer.configs.constants import KV_GMAIL_CRED_KEY
|
||||
from danswer.configs.constants import KV_GMAIL_SERVICE_ACCOUNT_KEY
|
||||
from danswer.connectors.gmail.constants import (
|
||||
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY,
|
||||
)
|
||||
from danswer.connectors.gmail.constants import DB_CREDENTIALS_DICT_TOKEN_KEY
|
||||
from danswer.connectors.gmail.constants import (
|
||||
GMAIL_DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
|
||||
)
|
||||
from danswer.connectors.gmail.constants import SCOPES
|
||||
from danswer.db.credentials import update_credential_json
|
||||
from danswer.db.models import User
|
||||
from danswer.key_value_store.factory import get_kv_store
|
||||
from danswer.server.documents.models import CredentialBase
|
||||
from danswer.server.documents.models import GoogleAppCredentials
|
||||
from danswer.server.documents.models import GoogleServiceAccountKey
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _build_frontend_gmail_redirect() -> str:
|
||||
return f"{WEB_DOMAIN}/admin/connectors/gmail/auth/callback"
|
||||
|
||||
|
||||
def get_gmail_creds_for_authorized_user(
|
||||
token_json_str: str,
|
||||
) -> OAuthCredentials | None:
|
||||
creds_json = json.loads(token_json_str)
|
||||
creds = OAuthCredentials.from_authorized_user_info(creds_json, SCOPES)
|
||||
if creds.valid:
|
||||
return creds
|
||||
|
||||
if creds.expired and creds.refresh_token:
|
||||
try:
|
||||
creds.refresh(Request())
|
||||
if creds.valid:
|
||||
logger.notice("Refreshed Gmail tokens.")
|
||||
return creds
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to refresh gmail access token due to: {e}")
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_gmail_creds_for_service_account(
|
||||
service_account_key_json_str: str,
|
||||
) -> ServiceAccountCredentials | None:
|
||||
service_account_key = json.loads(service_account_key_json_str)
|
||||
creds = ServiceAccountCredentials.from_service_account_info(
|
||||
service_account_key, scopes=SCOPES
|
||||
)
|
||||
if not creds.valid or not creds.expired:
|
||||
creds.refresh(Request())
|
||||
return creds if creds.valid else None
|
||||
|
||||
|
||||
def verify_csrf(credential_id: int, state: str) -> None:
|
||||
csrf = get_kv_store().load(KV_CRED_KEY.format(str(credential_id)))
|
||||
if csrf != state:
|
||||
raise PermissionError(
|
||||
"State from Gmail Connector callback does not match expected"
|
||||
)
|
||||
|
||||
|
||||
def get_gmail_auth_url(credential_id: int) -> str:
|
||||
creds_str = str(get_kv_store().load(KV_GMAIL_CRED_KEY))
|
||||
credential_json = json.loads(creds_str)
|
||||
flow = InstalledAppFlow.from_client_config(
|
||||
credential_json,
|
||||
scopes=SCOPES,
|
||||
redirect_uri=_build_frontend_gmail_redirect(),
|
||||
)
|
||||
auth_url, _ = flow.authorization_url(prompt="consent")
|
||||
|
||||
parsed_url = cast(ParseResult, urlparse(auth_url))
|
||||
params = parse_qs(parsed_url.query)
|
||||
|
||||
get_kv_store().store(
|
||||
KV_CRED_KEY.format(credential_id), params.get("state", [None])[0], encrypt=True
|
||||
) # type: ignore
|
||||
return str(auth_url)
|
||||
|
||||
|
||||
def get_auth_url(credential_id: int) -> str:
|
||||
creds_str = str(get_kv_store().load(KV_GMAIL_CRED_KEY))
|
||||
credential_json = json.loads(creds_str)
|
||||
flow = InstalledAppFlow.from_client_config(
|
||||
credential_json,
|
||||
scopes=SCOPES,
|
||||
redirect_uri=_build_frontend_gmail_redirect(),
|
||||
)
|
||||
auth_url, _ = flow.authorization_url(prompt="consent")
|
||||
|
||||
parsed_url = cast(ParseResult, urlparse(auth_url))
|
||||
params = parse_qs(parsed_url.query)
|
||||
|
||||
get_kv_store().store(
|
||||
KV_CRED_KEY.format(credential_id), params.get("state", [None])[0], encrypt=True
|
||||
) # type: ignore
|
||||
return str(auth_url)
|
||||
|
||||
|
||||
def update_gmail_credential_access_tokens(
|
||||
auth_code: str,
|
||||
credential_id: int,
|
||||
user: User,
|
||||
db_session: Session,
|
||||
) -> OAuthCredentials | None:
|
||||
app_credentials = get_google_app_gmail_cred()
|
||||
flow = InstalledAppFlow.from_client_config(
|
||||
app_credentials.model_dump(),
|
||||
scopes=SCOPES,
|
||||
redirect_uri=_build_frontend_gmail_redirect(),
|
||||
)
|
||||
flow.fetch_token(code=auth_code)
|
||||
creds = flow.credentials
|
||||
token_json_str = creds.to_json()
|
||||
new_creds_dict = {DB_CREDENTIALS_DICT_TOKEN_KEY: token_json_str}
|
||||
|
||||
if not update_credential_json(credential_id, new_creds_dict, user, db_session):
|
||||
return None
|
||||
return creds
|
||||
|
||||
|
||||
def build_service_account_creds(
|
||||
delegated_user_email: str | None = None,
|
||||
) -> CredentialBase:
|
||||
service_account_key = get_gmail_service_account_key()
|
||||
|
||||
credential_dict = {
|
||||
GMAIL_DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY: service_account_key.json(),
|
||||
}
|
||||
if delegated_user_email:
|
||||
credential_dict[DB_CREDENTIALS_DICT_DELEGATED_USER_KEY] = delegated_user_email
|
||||
|
||||
return CredentialBase(
|
||||
source=DocumentSource.GMAIL,
|
||||
credential_json=credential_dict,
|
||||
admin_public=True,
|
||||
)
|
||||
|
||||
|
||||
def get_google_app_gmail_cred() -> GoogleAppCredentials:
|
||||
creds_str = str(get_kv_store().load(KV_GMAIL_CRED_KEY))
|
||||
return GoogleAppCredentials(**json.loads(creds_str))
|
||||
|
||||
|
||||
def upsert_google_app_gmail_cred(app_credentials: GoogleAppCredentials) -> None:
|
||||
get_kv_store().store(KV_GMAIL_CRED_KEY, app_credentials.json(), encrypt=True)
|
||||
|
||||
|
||||
def delete_google_app_gmail_cred() -> None:
|
||||
get_kv_store().delete(KV_GMAIL_CRED_KEY)
|
||||
|
||||
|
||||
def get_gmail_service_account_key() -> GoogleServiceAccountKey:
|
||||
creds_str = str(get_kv_store().load(KV_GMAIL_SERVICE_ACCOUNT_KEY))
|
||||
return GoogleServiceAccountKey(**json.loads(creds_str))
|
||||
|
||||
|
||||
def upsert_gmail_service_account_key(
|
||||
service_account_key: GoogleServiceAccountKey,
|
||||
) -> None:
|
||||
get_kv_store().store(
|
||||
KV_GMAIL_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True
|
||||
)
|
||||
|
||||
|
||||
def upsert_service_account_key(service_account_key: GoogleServiceAccountKey) -> None:
|
||||
get_kv_store().store(
|
||||
KV_GMAIL_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True
|
||||
)
|
||||
|
||||
|
||||
def delete_gmail_service_account_key() -> None:
|
||||
get_kv_store().delete(KV_GMAIL_SERVICE_ACCOUNT_KEY)
|
||||
|
||||
|
||||
def delete_service_account_key() -> None:
|
||||
get_kv_store().delete(KV_GMAIL_SERVICE_ACCOUNT_KEY)
|
||||
@@ -1,4 +0,0 @@
|
||||
DB_CREDENTIALS_DICT_TOKEN_KEY = "gmail_tokens"
|
||||
GMAIL_DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "gmail_service_account_key"
|
||||
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY = "gmail_delegated_user"
|
||||
SCOPES = ["https://www.googleapis.com/auth/gmail.readonly"]
|
||||
@@ -1,556 +0,0 @@
|
||||
import io
|
||||
from collections.abc import Iterator
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from enum import Enum
|
||||
from itertools import chain
|
||||
from typing import Any
|
||||
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
|
||||
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
|
||||
from googleapiclient import discovery # type: ignore
|
||||
from googleapiclient.errors import HttpError # type: ignore
|
||||
|
||||
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
|
||||
from danswer.configs.app_configs import GOOGLE_DRIVE_FOLLOW_SHORTCUTS
|
||||
from danswer.configs.app_configs import GOOGLE_DRIVE_INCLUDE_SHARED
|
||||
from danswer.configs.app_configs import GOOGLE_DRIVE_ONLY_ORG_PUBLIC
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import IGNORE_FOR_QA
|
||||
from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
|
||||
from danswer.connectors.google_drive.connector_auth import get_google_drive_creds
|
||||
from danswer.connectors.google_drive.constants import (
|
||||
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY,
|
||||
)
|
||||
from danswer.connectors.google_drive.constants import (
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
|
||||
)
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.file_processing.extract_file_text import docx_to_text
|
||||
from danswer.file_processing.extract_file_text import pptx_to_text
|
||||
from danswer.file_processing.extract_file_text import read_pdf_file
|
||||
from danswer.file_processing.unstructured import get_unstructured_api_key
|
||||
from danswer.file_processing.unstructured import unstructured_to_text
|
||||
from danswer.utils.batching import batch_generator
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
DRIVE_FOLDER_TYPE = "application/vnd.google-apps.folder"
|
||||
DRIVE_SHORTCUT_TYPE = "application/vnd.google-apps.shortcut"
|
||||
UNSUPPORTED_FILE_TYPE_CONTENT = "" # keep empty for now
|
||||
|
||||
|
||||
class GDriveMimeType(str, Enum):
|
||||
DOC = "application/vnd.google-apps.document"
|
||||
SPREADSHEET = "application/vnd.google-apps.spreadsheet"
|
||||
PDF = "application/pdf"
|
||||
WORD_DOC = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
|
||||
PPT = "application/vnd.google-apps.presentation"
|
||||
POWERPOINT = (
|
||||
"application/vnd.openxmlformats-officedocument.presentationml.presentation"
|
||||
)
|
||||
PLAIN_TEXT = "text/plain"
|
||||
MARKDOWN = "text/markdown"
|
||||
|
||||
|
||||
GoogleDriveFileType = dict[str, Any]
|
||||
|
||||
# Google Drive APIs are quite flakey and may 500 for an
|
||||
# extended period of time. Trying to combat here by adding a very
|
||||
# long retry period (~20 minutes of trying every minute)
|
||||
add_retries = retry_builder(tries=50, max_delay=30)
|
||||
|
||||
|
||||
def _run_drive_file_query(
|
||||
service: discovery.Resource,
|
||||
query: str,
|
||||
continue_on_failure: bool,
|
||||
include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED,
|
||||
follow_shortcuts: bool = GOOGLE_DRIVE_FOLLOW_SHORTCUTS,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
) -> Iterator[GoogleDriveFileType]:
|
||||
next_page_token = ""
|
||||
while next_page_token is not None:
|
||||
logger.debug(f"Running Google Drive fetch with query: {query}")
|
||||
results = add_retries(
|
||||
lambda: (
|
||||
service.files()
|
||||
.list(
|
||||
corpora="allDrives"
|
||||
if include_shared
|
||||
else "user", # needed to search through shared drives
|
||||
pageSize=batch_size,
|
||||
supportsAllDrives=include_shared,
|
||||
includeItemsFromAllDrives=include_shared,
|
||||
fields=(
|
||||
"nextPageToken, files(mimeType, id, name, permissions, "
|
||||
"modifiedTime, webViewLink, shortcutDetails)"
|
||||
),
|
||||
pageToken=next_page_token,
|
||||
q=query,
|
||||
)
|
||||
.execute()
|
||||
)
|
||||
)()
|
||||
next_page_token = results.get("nextPageToken")
|
||||
files = results["files"]
|
||||
for file in files:
|
||||
if follow_shortcuts and "shortcutDetails" in file:
|
||||
try:
|
||||
file_shortcut_points_to = add_retries(
|
||||
lambda: (
|
||||
service.files()
|
||||
.get(
|
||||
fileId=file["shortcutDetails"]["targetId"],
|
||||
supportsAllDrives=include_shared,
|
||||
fields="mimeType, id, name, modifiedTime, webViewLink, permissions, shortcutDetails",
|
||||
)
|
||||
.execute()
|
||||
)
|
||||
)()
|
||||
yield file_shortcut_points_to
|
||||
except HttpError:
|
||||
logger.error(
|
||||
f"Failed to follow shortcut with details: {file['shortcutDetails']}"
|
||||
)
|
||||
if continue_on_failure:
|
||||
continue
|
||||
raise
|
||||
else:
|
||||
yield file
|
||||
|
||||
|
||||
def _get_folder_id(
|
||||
service: discovery.Resource,
|
||||
parent_id: str,
|
||||
folder_name: str,
|
||||
include_shared: bool,
|
||||
follow_shortcuts: bool,
|
||||
) -> str | None:
|
||||
"""
|
||||
Get the ID of a folder given its name and the ID of its parent folder.
|
||||
"""
|
||||
query = f"'{parent_id}' in parents and name='{folder_name}' and "
|
||||
if follow_shortcuts:
|
||||
query += f"(mimeType='{DRIVE_FOLDER_TYPE}' or mimeType='{DRIVE_SHORTCUT_TYPE}')"
|
||||
else:
|
||||
query += f"mimeType='{DRIVE_FOLDER_TYPE}'"
|
||||
|
||||
# TODO: support specifying folder path in shared drive rather than just `My Drive`
|
||||
results = add_retries(
|
||||
lambda: (
|
||||
service.files()
|
||||
.list(
|
||||
q=query,
|
||||
spaces="drive",
|
||||
fields="nextPageToken, files(id, name, shortcutDetails)",
|
||||
supportsAllDrives=include_shared,
|
||||
includeItemsFromAllDrives=include_shared,
|
||||
)
|
||||
.execute()
|
||||
)
|
||||
)()
|
||||
items = results.get("files", [])
|
||||
|
||||
folder_id = None
|
||||
if items:
|
||||
if follow_shortcuts and "shortcutDetails" in items[0]:
|
||||
folder_id = items[0]["shortcutDetails"]["targetId"]
|
||||
else:
|
||||
folder_id = items[0]["id"]
|
||||
return folder_id
|
||||
|
||||
|
||||
def _get_folders(
|
||||
service: discovery.Resource,
|
||||
continue_on_failure: bool,
|
||||
folder_id: str | None = None, # if specified, only fetches files within this folder
|
||||
include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED,
|
||||
follow_shortcuts: bool = GOOGLE_DRIVE_FOLLOW_SHORTCUTS,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
) -> Iterator[GoogleDriveFileType]:
|
||||
query = f"mimeType = '{DRIVE_FOLDER_TYPE}' "
|
||||
if follow_shortcuts:
|
||||
query = "(" + query + f" or mimeType = '{DRIVE_SHORTCUT_TYPE}'" + ") "
|
||||
|
||||
if folder_id:
|
||||
query += f"and '{folder_id}' in parents "
|
||||
query = query.rstrip() # remove the trailing space(s)
|
||||
|
||||
for file in _run_drive_file_query(
|
||||
service=service,
|
||||
query=query,
|
||||
continue_on_failure=continue_on_failure,
|
||||
include_shared=include_shared,
|
||||
follow_shortcuts=follow_shortcuts,
|
||||
batch_size=batch_size,
|
||||
):
|
||||
# Need to check this since file may have been a target of a shortcut
|
||||
# and not necessarily a folder
|
||||
if file["mimeType"] == DRIVE_FOLDER_TYPE:
|
||||
yield file
|
||||
else:
|
||||
pass
|
||||
|
||||
|
||||
def _get_files(
|
||||
service: discovery.Resource,
|
||||
continue_on_failure: bool,
|
||||
time_range_start: SecondsSinceUnixEpoch | None = None,
|
||||
time_range_end: SecondsSinceUnixEpoch | None = None,
|
||||
folder_id: str | None = None, # if specified, only fetches files within this folder
|
||||
include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED,
|
||||
follow_shortcuts: bool = GOOGLE_DRIVE_FOLLOW_SHORTCUTS,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
) -> Iterator[GoogleDriveFileType]:
|
||||
query = f"mimeType != '{DRIVE_FOLDER_TYPE}' "
|
||||
if time_range_start is not None:
|
||||
time_start = datetime.utcfromtimestamp(time_range_start).isoformat() + "Z"
|
||||
query += f"and modifiedTime >= '{time_start}' "
|
||||
if time_range_end is not None:
|
||||
time_stop = datetime.utcfromtimestamp(time_range_end).isoformat() + "Z"
|
||||
query += f"and modifiedTime <= '{time_stop}' "
|
||||
if folder_id:
|
||||
query += f"and '{folder_id}' in parents "
|
||||
query = query.rstrip() # remove the trailing space(s)
|
||||
|
||||
files = _run_drive_file_query(
|
||||
service=service,
|
||||
query=query,
|
||||
continue_on_failure=continue_on_failure,
|
||||
include_shared=include_shared,
|
||||
follow_shortcuts=follow_shortcuts,
|
||||
batch_size=batch_size,
|
||||
)
|
||||
|
||||
return files
|
||||
|
||||
|
||||
def get_all_files_batched(
|
||||
service: discovery.Resource,
|
||||
continue_on_failure: bool,
|
||||
include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED,
|
||||
follow_shortcuts: bool = GOOGLE_DRIVE_FOLLOW_SHORTCUTS,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
time_range_start: SecondsSinceUnixEpoch | None = None,
|
||||
time_range_end: SecondsSinceUnixEpoch | None = None,
|
||||
folder_id: str | None = None, # if specified, only fetches files within this folder
|
||||
# if True, will fetch files in sub-folders of the specified folder ID.
|
||||
# Only applies if folder_id is specified.
|
||||
traverse_subfolders: bool = True,
|
||||
folder_ids_traversed: list[str] | None = None,
|
||||
) -> Iterator[list[GoogleDriveFileType]]:
|
||||
"""Gets all files matching the criteria specified by the args from Google Drive
|
||||
in batches of size `batch_size`.
|
||||
"""
|
||||
found_files = _get_files(
|
||||
service=service,
|
||||
continue_on_failure=continue_on_failure,
|
||||
time_range_start=time_range_start,
|
||||
time_range_end=time_range_end,
|
||||
folder_id=folder_id,
|
||||
include_shared=include_shared,
|
||||
follow_shortcuts=follow_shortcuts,
|
||||
batch_size=batch_size,
|
||||
)
|
||||
yield from batch_generator(
|
||||
items=found_files,
|
||||
batch_size=batch_size,
|
||||
pre_batch_yield=lambda batch_files: logger.debug(
|
||||
f"Parseable Documents in batch: {[file['name'] for file in batch_files]}"
|
||||
),
|
||||
)
|
||||
|
||||
if traverse_subfolders and folder_id is not None:
|
||||
folder_ids_traversed = folder_ids_traversed or []
|
||||
subfolders = _get_folders(
|
||||
service=service,
|
||||
folder_id=folder_id,
|
||||
continue_on_failure=continue_on_failure,
|
||||
include_shared=include_shared,
|
||||
follow_shortcuts=follow_shortcuts,
|
||||
batch_size=batch_size,
|
||||
)
|
||||
for subfolder in subfolders:
|
||||
if subfolder["id"] not in folder_ids_traversed:
|
||||
logger.info("Fetching all files in subfolder: " + subfolder["name"])
|
||||
folder_ids_traversed.append(subfolder["id"])
|
||||
yield from get_all_files_batched(
|
||||
service=service,
|
||||
continue_on_failure=continue_on_failure,
|
||||
include_shared=include_shared,
|
||||
follow_shortcuts=follow_shortcuts,
|
||||
batch_size=batch_size,
|
||||
time_range_start=time_range_start,
|
||||
time_range_end=time_range_end,
|
||||
folder_id=subfolder["id"],
|
||||
traverse_subfolders=traverse_subfolders,
|
||||
folder_ids_traversed=folder_ids_traversed,
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"Skipping subfolder since already traversed: " + subfolder["name"]
|
||||
)
|
||||
|
||||
|
||||
def extract_text(file: dict[str, str], service: discovery.Resource) -> str:
|
||||
mime_type = file["mimeType"]
|
||||
|
||||
if mime_type not in set(item.value for item in GDriveMimeType):
|
||||
# Unsupported file types can still have a title, finding this way is still useful
|
||||
return UNSUPPORTED_FILE_TYPE_CONTENT
|
||||
|
||||
if mime_type in [
|
||||
GDriveMimeType.DOC.value,
|
||||
GDriveMimeType.PPT.value,
|
||||
GDriveMimeType.SPREADSHEET.value,
|
||||
]:
|
||||
export_mime_type = (
|
||||
"text/plain"
|
||||
if mime_type != GDriveMimeType.SPREADSHEET.value
|
||||
else "text/csv"
|
||||
)
|
||||
return (
|
||||
service.files()
|
||||
.export(fileId=file["id"], mimeType=export_mime_type)
|
||||
.execute()
|
||||
.decode("utf-8")
|
||||
)
|
||||
elif mime_type in [
|
||||
GDriveMimeType.PLAIN_TEXT.value,
|
||||
GDriveMimeType.MARKDOWN.value,
|
||||
]:
|
||||
return service.files().get_media(fileId=file["id"]).execute().decode("utf-8")
|
||||
if mime_type in [
|
||||
GDriveMimeType.WORD_DOC.value,
|
||||
GDriveMimeType.POWERPOINT.value,
|
||||
GDriveMimeType.PDF.value,
|
||||
]:
|
||||
response = service.files().get_media(fileId=file["id"]).execute()
|
||||
if get_unstructured_api_key():
|
||||
return unstructured_to_text(
|
||||
file=io.BytesIO(response), file_name=file.get("name", file["id"])
|
||||
)
|
||||
|
||||
if mime_type == GDriveMimeType.WORD_DOC.value:
|
||||
return docx_to_text(file=io.BytesIO(response))
|
||||
elif mime_type == GDriveMimeType.PDF.value:
|
||||
text, _ = read_pdf_file(file=io.BytesIO(response))
|
||||
return text
|
||||
elif mime_type == GDriveMimeType.POWERPOINT.value:
|
||||
return pptx_to_text(file=io.BytesIO(response))
|
||||
|
||||
return UNSUPPORTED_FILE_TYPE_CONTENT
|
||||
|
||||
|
||||
class GoogleDriveConnector(LoadConnector, PollConnector):
|
||||
def __init__(
|
||||
self,
|
||||
# optional list of folder paths e.g. "[My Folder/My Subfolder]"
|
||||
# if specified, will only index files in these folders
|
||||
folder_paths: list[str] | None = None,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED,
|
||||
follow_shortcuts: bool = GOOGLE_DRIVE_FOLLOW_SHORTCUTS,
|
||||
only_org_public: bool = GOOGLE_DRIVE_ONLY_ORG_PUBLIC,
|
||||
continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE,
|
||||
) -> None:
|
||||
self.folder_paths = folder_paths or []
|
||||
self.batch_size = batch_size
|
||||
self.include_shared = include_shared
|
||||
self.follow_shortcuts = follow_shortcuts
|
||||
self.only_org_public = only_org_public
|
||||
self.continue_on_failure = continue_on_failure
|
||||
self.creds: OAuthCredentials | ServiceAccountCredentials | None = None
|
||||
|
||||
@staticmethod
|
||||
def _process_folder_paths(
|
||||
service: discovery.Resource,
|
||||
folder_paths: list[str],
|
||||
include_shared: bool,
|
||||
follow_shortcuts: bool,
|
||||
) -> list[str]:
|
||||
"""['Folder/Sub Folder'] -> ['<FOLDER_ID>']"""
|
||||
folder_ids: list[str] = []
|
||||
for path in folder_paths:
|
||||
folder_names = path.split("/")
|
||||
parent_id = "root"
|
||||
for folder_name in folder_names:
|
||||
found_parent_id = _get_folder_id(
|
||||
service=service,
|
||||
parent_id=parent_id,
|
||||
folder_name=folder_name,
|
||||
include_shared=include_shared,
|
||||
follow_shortcuts=follow_shortcuts,
|
||||
)
|
||||
if found_parent_id is None:
|
||||
raise ValueError(
|
||||
(
|
||||
f"Folder '{folder_name}' in path '{path}' "
|
||||
"not found in Google Drive"
|
||||
)
|
||||
)
|
||||
parent_id = found_parent_id
|
||||
folder_ids.append(parent_id)
|
||||
|
||||
return folder_ids
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, str] | None:
|
||||
"""Checks for two different types of credentials.
|
||||
(1) A credential which holds a token acquired via a user going thorough
|
||||
the Google OAuth flow.
|
||||
(2) A credential which holds a service account key JSON file, which
|
||||
can then be used to impersonate any user in the workspace.
|
||||
"""
|
||||
creds, new_creds_dict = get_google_drive_creds(credentials)
|
||||
self.creds = creds
|
||||
return new_creds_dict
|
||||
|
||||
def _fetch_docs_from_drive(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> GenerateDocumentsOutput:
|
||||
if self.creds is None:
|
||||
raise PermissionError("Not logged into Google Drive")
|
||||
|
||||
service = discovery.build("drive", "v3", credentials=self.creds)
|
||||
folder_ids: Sequence[str | None] = self._process_folder_paths(
|
||||
service, self.folder_paths, self.include_shared, self.follow_shortcuts
|
||||
)
|
||||
if not folder_ids:
|
||||
folder_ids = [None]
|
||||
|
||||
file_batches = chain(
|
||||
*[
|
||||
get_all_files_batched(
|
||||
service=service,
|
||||
continue_on_failure=self.continue_on_failure,
|
||||
include_shared=self.include_shared,
|
||||
follow_shortcuts=self.follow_shortcuts,
|
||||
batch_size=self.batch_size,
|
||||
time_range_start=start,
|
||||
time_range_end=end,
|
||||
folder_id=folder_id,
|
||||
traverse_subfolders=True,
|
||||
)
|
||||
for folder_id in folder_ids
|
||||
]
|
||||
)
|
||||
for files_batch in file_batches:
|
||||
doc_batch = []
|
||||
for file in files_batch:
|
||||
try:
|
||||
# Skip files that are shortcuts
|
||||
if file.get("mimeType") == DRIVE_SHORTCUT_TYPE:
|
||||
logger.info("Ignoring Drive Shortcut Filetype")
|
||||
continue
|
||||
|
||||
if self.only_org_public:
|
||||
if "permissions" not in file:
|
||||
continue
|
||||
if not any(
|
||||
permission["type"] == "domain"
|
||||
for permission in file["permissions"]
|
||||
):
|
||||
continue
|
||||
try:
|
||||
text_contents = extract_text(file, service) or ""
|
||||
except HttpError as e:
|
||||
reason = (
|
||||
e.error_details[0]["reason"]
|
||||
if e.error_details
|
||||
else e.reason
|
||||
)
|
||||
message = (
|
||||
e.error_details[0]["message"]
|
||||
if e.error_details
|
||||
else e.reason
|
||||
)
|
||||
|
||||
# these errors don't represent a failure in the connector, but simply files
|
||||
# that can't / shouldn't be indexed
|
||||
ERRORS_TO_CONTINUE_ON = [
|
||||
"cannotExportFile",
|
||||
"exportSizeLimitExceeded",
|
||||
"cannotDownloadFile",
|
||||
]
|
||||
if e.status_code == 403 and reason in ERRORS_TO_CONTINUE_ON:
|
||||
logger.warning(
|
||||
f"Could not export file '{file['name']}' due to '{message}', skipping..."
|
||||
)
|
||||
continue
|
||||
|
||||
raise
|
||||
|
||||
doc_batch.append(
|
||||
Document(
|
||||
id=file["webViewLink"],
|
||||
sections=[
|
||||
Section(link=file["webViewLink"], text=text_contents)
|
||||
],
|
||||
source=DocumentSource.GOOGLE_DRIVE,
|
||||
semantic_identifier=file["name"],
|
||||
doc_updated_at=datetime.fromisoformat(
|
||||
file["modifiedTime"]
|
||||
).astimezone(timezone.utc),
|
||||
metadata={} if text_contents else {IGNORE_FOR_QA: "True"},
|
||||
additional_info=file.get("id"),
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
if not self.continue_on_failure:
|
||||
raise e
|
||||
|
||||
logger.exception(
|
||||
"Ran into exception when pulling a file from Google Drive"
|
||||
)
|
||||
|
||||
yield doc_batch
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
yield from self._fetch_docs_from_drive()
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
# need to subtract 10 minutes from start time to account for modifiedTime
|
||||
# propogation if a document is modified, it takes some time for the API to
|
||||
# reflect these changes if we do not have an offset, then we may "miss" the
|
||||
# update when polling
|
||||
yield from self._fetch_docs_from_drive(start, end)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import json
|
||||
import os
|
||||
|
||||
service_account_json_path = os.environ.get("GOOGLE_SERVICE_ACCOUNT_KEY_JSON_PATH")
|
||||
if not service_account_json_path:
|
||||
raise ValueError(
|
||||
"Please set GOOGLE_SERVICE_ACCOUNT_KEY_JSON_PATH environment variable"
|
||||
)
|
||||
with open(service_account_json_path) as f:
|
||||
creds = json.load(f)
|
||||
|
||||
credentials_dict = {
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY: json.dumps(creds),
|
||||
}
|
||||
delegated_user = os.environ.get("GOOGLE_DRIVE_DELEGATED_USER")
|
||||
if delegated_user:
|
||||
credentials_dict[DB_CREDENTIALS_DICT_DELEGATED_USER_KEY] = delegated_user
|
||||
|
||||
connector = GoogleDriveConnector(include_shared=True, follow_shortcuts=True)
|
||||
connector.load_credentials(credentials_dict)
|
||||
document_batch_generator = connector.load_from_state()
|
||||
for document_batch in document_batch_generator:
|
||||
print(document_batch)
|
||||
break
|
||||
@@ -1,229 +0,0 @@
|
||||
import json
|
||||
from typing import cast
|
||||
from urllib.parse import parse_qs
|
||||
from urllib.parse import ParseResult
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from google.auth.transport.requests import Request # type: ignore
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
|
||||
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
|
||||
from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.app_configs import ENTERPRISE_EDITION_ENABLED
|
||||
from danswer.configs.app_configs import WEB_DOMAIN
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import KV_CRED_KEY
|
||||
from danswer.configs.constants import KV_GOOGLE_DRIVE_CRED_KEY
|
||||
from danswer.configs.constants import KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY
|
||||
from danswer.connectors.google_drive.constants import BASE_SCOPES
|
||||
from danswer.connectors.google_drive.constants import (
|
||||
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY,
|
||||
)
|
||||
from danswer.connectors.google_drive.constants import (
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
|
||||
)
|
||||
from danswer.connectors.google_drive.constants import DB_CREDENTIALS_DICT_TOKEN_KEY
|
||||
from danswer.connectors.google_drive.constants import FETCH_GROUPS_SCOPES
|
||||
from danswer.connectors.google_drive.constants import FETCH_PERMISSIONS_SCOPES
|
||||
from danswer.db.credentials import update_credential_json
|
||||
from danswer.db.models import User
|
||||
from danswer.key_value_store.factory import get_kv_store
|
||||
from danswer.server.documents.models import CredentialBase
|
||||
from danswer.server.documents.models import GoogleAppCredentials
|
||||
from danswer.server.documents.models import GoogleServiceAccountKey
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def build_gdrive_scopes() -> list[str]:
|
||||
base_scopes: list[str] = BASE_SCOPES
|
||||
permissions_scopes: list[str] = FETCH_PERMISSIONS_SCOPES
|
||||
groups_scopes: list[str] = FETCH_GROUPS_SCOPES
|
||||
|
||||
if ENTERPRISE_EDITION_ENABLED:
|
||||
return base_scopes + permissions_scopes + groups_scopes
|
||||
return base_scopes + permissions_scopes
|
||||
|
||||
|
||||
def _build_frontend_google_drive_redirect() -> str:
|
||||
return f"{WEB_DOMAIN}/admin/connectors/google-drive/auth/callback"
|
||||
|
||||
|
||||
def get_google_drive_creds_for_authorized_user(
|
||||
token_json_str: str, scopes: list[str] = build_gdrive_scopes()
|
||||
) -> OAuthCredentials | None:
|
||||
creds_json = json.loads(token_json_str)
|
||||
creds = OAuthCredentials.from_authorized_user_info(creds_json, scopes)
|
||||
if creds.valid:
|
||||
return creds
|
||||
|
||||
if creds.expired and creds.refresh_token:
|
||||
try:
|
||||
creds.refresh(Request())
|
||||
if creds.valid:
|
||||
logger.notice("Refreshed Google Drive tokens.")
|
||||
return creds
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to refresh google drive access token due to: {e}")
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _get_google_drive_creds_for_service_account(
|
||||
service_account_key_json_str: str, scopes: list[str] = build_gdrive_scopes()
|
||||
) -> ServiceAccountCredentials | None:
|
||||
service_account_key = json.loads(service_account_key_json_str)
|
||||
creds = ServiceAccountCredentials.from_service_account_info(
|
||||
service_account_key, scopes=scopes
|
||||
)
|
||||
if not creds.valid or not creds.expired:
|
||||
creds.refresh(Request())
|
||||
return creds if creds.valid else None
|
||||
|
||||
|
||||
def get_google_drive_creds(
|
||||
credentials: dict[str, str], scopes: list[str] = build_gdrive_scopes()
|
||||
) -> tuple[ServiceAccountCredentials | OAuthCredentials, dict[str, str] | None]:
|
||||
oauth_creds = None
|
||||
service_creds = None
|
||||
new_creds_dict = None
|
||||
if DB_CREDENTIALS_DICT_TOKEN_KEY in credentials:
|
||||
access_token_json_str = cast(str, credentials[DB_CREDENTIALS_DICT_TOKEN_KEY])
|
||||
oauth_creds = get_google_drive_creds_for_authorized_user(
|
||||
token_json_str=access_token_json_str, scopes=scopes
|
||||
)
|
||||
|
||||
# tell caller to update token stored in DB if it has changed
|
||||
# (e.g. the token has been refreshed)
|
||||
new_creds_json_str = oauth_creds.to_json() if oauth_creds else ""
|
||||
if new_creds_json_str != access_token_json_str:
|
||||
new_creds_dict = {DB_CREDENTIALS_DICT_TOKEN_KEY: new_creds_json_str}
|
||||
|
||||
elif DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY in credentials:
|
||||
service_account_key_json_str = credentials[
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY
|
||||
]
|
||||
service_creds = _get_google_drive_creds_for_service_account(
|
||||
service_account_key_json_str=service_account_key_json_str,
|
||||
scopes=scopes,
|
||||
)
|
||||
|
||||
# "Impersonate" a user if one is specified
|
||||
delegated_user_email = cast(
|
||||
str | None, credentials.get(DB_CREDENTIALS_DICT_DELEGATED_USER_KEY)
|
||||
)
|
||||
if delegated_user_email:
|
||||
service_creds = (
|
||||
service_creds.with_subject(delegated_user_email)
|
||||
if service_creds
|
||||
else None
|
||||
)
|
||||
|
||||
creds: ServiceAccountCredentials | OAuthCredentials | None = (
|
||||
oauth_creds or service_creds
|
||||
)
|
||||
if creds is None:
|
||||
raise PermissionError(
|
||||
"Unable to access Google Drive - unknown credential structure."
|
||||
)
|
||||
|
||||
return creds, new_creds_dict
|
||||
|
||||
|
||||
def verify_csrf(credential_id: int, state: str) -> None:
|
||||
csrf = get_kv_store().load(KV_CRED_KEY.format(str(credential_id)))
|
||||
if csrf != state:
|
||||
raise PermissionError(
|
||||
"State from Google Drive Connector callback does not match expected"
|
||||
)
|
||||
|
||||
|
||||
def get_auth_url(credential_id: int) -> str:
|
||||
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
|
||||
credential_json = json.loads(creds_str)
|
||||
flow = InstalledAppFlow.from_client_config(
|
||||
credential_json,
|
||||
scopes=build_gdrive_scopes(),
|
||||
redirect_uri=_build_frontend_google_drive_redirect(),
|
||||
)
|
||||
auth_url, _ = flow.authorization_url(prompt="consent")
|
||||
|
||||
parsed_url = cast(ParseResult, urlparse(auth_url))
|
||||
params = parse_qs(parsed_url.query)
|
||||
|
||||
get_kv_store().store(
|
||||
KV_CRED_KEY.format(credential_id), params.get("state", [None])[0], encrypt=True
|
||||
) # type: ignore
|
||||
return str(auth_url)
|
||||
|
||||
|
||||
def update_credential_access_tokens(
|
||||
auth_code: str,
|
||||
credential_id: int,
|
||||
user: User,
|
||||
db_session: Session,
|
||||
) -> OAuthCredentials | None:
|
||||
app_credentials = get_google_app_cred()
|
||||
flow = InstalledAppFlow.from_client_config(
|
||||
app_credentials.model_dump(),
|
||||
scopes=build_gdrive_scopes(),
|
||||
redirect_uri=_build_frontend_google_drive_redirect(),
|
||||
)
|
||||
flow.fetch_token(code=auth_code)
|
||||
creds = flow.credentials
|
||||
token_json_str = creds.to_json()
|
||||
new_creds_dict = {DB_CREDENTIALS_DICT_TOKEN_KEY: token_json_str}
|
||||
|
||||
if not update_credential_json(credential_id, new_creds_dict, user, db_session):
|
||||
return None
|
||||
return creds
|
||||
|
||||
|
||||
def build_service_account_creds(
|
||||
source: DocumentSource,
|
||||
delegated_user_email: str | None = None,
|
||||
) -> CredentialBase:
|
||||
service_account_key = get_service_account_key()
|
||||
|
||||
credential_dict = {
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY: service_account_key.json(),
|
||||
}
|
||||
if delegated_user_email:
|
||||
credential_dict[DB_CREDENTIALS_DICT_DELEGATED_USER_KEY] = delegated_user_email
|
||||
|
||||
return CredentialBase(
|
||||
credential_json=credential_dict,
|
||||
admin_public=True,
|
||||
source=DocumentSource.GOOGLE_DRIVE,
|
||||
)
|
||||
|
||||
|
||||
def get_google_app_cred() -> GoogleAppCredentials:
|
||||
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
|
||||
return GoogleAppCredentials(**json.loads(creds_str))
|
||||
|
||||
|
||||
def upsert_google_app_cred(app_credentials: GoogleAppCredentials) -> None:
|
||||
get_kv_store().store(KV_GOOGLE_DRIVE_CRED_KEY, app_credentials.json(), encrypt=True)
|
||||
|
||||
|
||||
def delete_google_app_cred() -> None:
|
||||
get_kv_store().delete(KV_GOOGLE_DRIVE_CRED_KEY)
|
||||
|
||||
|
||||
def get_service_account_key() -> GoogleServiceAccountKey:
|
||||
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY))
|
||||
return GoogleServiceAccountKey(**json.loads(creds_str))
|
||||
|
||||
|
||||
def upsert_service_account_key(service_account_key: GoogleServiceAccountKey) -> None:
|
||||
get_kv_store().store(
|
||||
KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True
|
||||
)
|
||||
|
||||
|
||||
def delete_service_account_key() -> None:
|
||||
get_kv_store().delete(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY)
|
||||
@@ -1,7 +0,0 @@
|
||||
DB_CREDENTIALS_DICT_TOKEN_KEY = "google_drive_tokens"
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "google_drive_service_account_key"
|
||||
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY = "google_drive_delegated_user"
|
||||
|
||||
BASE_SCOPES = ["https://www.googleapis.com/auth/drive.readonly"]
|
||||
FETCH_PERMISSIONS_SCOPES = ["https://www.googleapis.com/auth/drive.metadata.readonly"]
|
||||
FETCH_GROUPS_SCOPES = ["https://www.googleapis.com/auth/cloud-identity.groups.readonly"]
|
||||
@@ -1,153 +0,0 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from logging import DEBUG as LOG_LVL_DEBUG
|
||||
from typing import Any
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
|
||||
from rt.rest1 import ALL_QUEUES
|
||||
from rt.rest1 import Rt
|
||||
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from danswer.connectors.models import ConnectorMissingCredentialError
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class RequestTrackerError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class RequestTrackerConnector(PollConnector):
|
||||
def __init__(
|
||||
self,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
) -> None:
|
||||
self.batch_size = batch_size
|
||||
|
||||
def txn_link(self, tid: int, txn: int) -> str:
|
||||
return f"{self.rt_base_url}/Ticket/Display.html?id={tid}&txn={txn}"
|
||||
|
||||
def build_doc_sections_from_txn(
|
||||
self, connection: Rt, ticket_id: int
|
||||
) -> List[Section]:
|
||||
Sections: List[Section] = []
|
||||
|
||||
get_history_resp = connection.get_history(ticket_id)
|
||||
|
||||
if get_history_resp is None:
|
||||
raise RequestTrackerError(f"Ticket {ticket_id} cannot be found")
|
||||
|
||||
for tx in get_history_resp:
|
||||
Sections.append(
|
||||
Section(
|
||||
link=self.txn_link(ticket_id, int(tx["id"])),
|
||||
text="\n".join(
|
||||
[
|
||||
f"{k}:\n{v}\n" if k != "Attachments" else ""
|
||||
for (k, v) in tx.items()
|
||||
]
|
||||
),
|
||||
)
|
||||
)
|
||||
return Sections
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> Optional[dict[str, Any]]:
|
||||
self.rt_username = credentials.get("requesttracker_username")
|
||||
self.rt_password = credentials.get("requesttracker_password")
|
||||
self.rt_base_url = credentials.get("requesttracker_base_url")
|
||||
return None
|
||||
|
||||
# This does not include RT file attachments yet.
|
||||
def _process_tickets(
|
||||
self, start: datetime, end: datetime
|
||||
) -> GenerateDocumentsOutput:
|
||||
if any([self.rt_username, self.rt_password, self.rt_base_url]) is None:
|
||||
raise ConnectorMissingCredentialError("requesttracker")
|
||||
|
||||
Rt0 = Rt(
|
||||
f"{self.rt_base_url}/REST/1.0/",
|
||||
self.rt_username,
|
||||
self.rt_password,
|
||||
)
|
||||
|
||||
Rt0.login()
|
||||
|
||||
d0 = start.strftime("%Y-%m-%d %H:%M:%S")
|
||||
d1 = end.strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
tickets = Rt0.search(
|
||||
Queue=ALL_QUEUES,
|
||||
raw_query=f"Updated > '{d0}' AND Updated < '{d1}'",
|
||||
)
|
||||
|
||||
doc_batch: List[Document] = []
|
||||
|
||||
for ticket in tickets:
|
||||
ticket_keys_to_omit = ["id", "Subject"]
|
||||
tid: int = int(ticket["numerical_id"])
|
||||
ticketLink: str = f"{self.rt_base_url}/Ticket/Display.html?id={tid}"
|
||||
logger.info(f"Processing ticket {tid}")
|
||||
doc = Document(
|
||||
id=ticket["id"],
|
||||
# Will add title to the first section later in processing
|
||||
sections=[Section(link=ticketLink, text="")]
|
||||
+ self.build_doc_sections_from_txn(Rt0, tid),
|
||||
source=DocumentSource.REQUESTTRACKER,
|
||||
semantic_identifier=ticket["Subject"],
|
||||
metadata={
|
||||
key: value
|
||||
for key, value in ticket.items()
|
||||
if key not in ticket_keys_to_omit
|
||||
},
|
||||
)
|
||||
|
||||
doc_batch.append(doc)
|
||||
|
||||
if len(doc_batch) >= self.batch_size:
|
||||
yield doc_batch
|
||||
doc_batch = []
|
||||
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
# Keep query short, only look behind 1 day at maximum
|
||||
one_day_ago: float = end - (24 * 60 * 60)
|
||||
_start: float = start if start > one_day_ago else one_day_ago
|
||||
start_datetime = datetime.fromtimestamp(_start, tz=timezone.utc)
|
||||
end_datetime = datetime.fromtimestamp(end, tz=timezone.utc)
|
||||
yield from self._process_tickets(start_datetime, end_datetime)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import time
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
logger.setLevel(LOG_LVL_DEBUG)
|
||||
rt_connector = RequestTrackerConnector()
|
||||
rt_connector.load_credentials(
|
||||
{
|
||||
"requesttracker_username": os.getenv("RT_USERNAME"),
|
||||
"requesttracker_password": os.getenv("RT_PASSWORD"),
|
||||
"requesttracker_base_url": os.getenv("RT_BASE_URL"),
|
||||
}
|
||||
)
|
||||
|
||||
current = time.time()
|
||||
one_day_ago = current - (24 * 60 * 60) # 1 days
|
||||
latest_docs = rt_connector.poll_source(one_day_ago, current)
|
||||
|
||||
for doc in latest_docs:
|
||||
print(doc)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user