Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-02-17 15:55:45 +00:00)

Compare commits: 495 commits (anonymous_ … bookstack_)
[Commit table: 495 commits, from 6fb85d53c9 through 962240031f; the Author and Date cells were not captured in this snapshot, so only the SHA1 values survived and the table is omitted.]
.github/pull_request_template.md (vendored; 7 lines changed)

@@ -1,11 +1,14 @@
 ## Description

 [Provide a brief description of the changes in this PR]

 ## How Has This Been Tested?

 [Describe the tests you ran to verify your changes]

 ## Backporting (check the box to trigger backport action)

 Note: You have to check that the action passes, otherwise resolve the conflicts manually and tag the patches.

 - [ ] This PR should be backported (make sure to check that the backport attempt succeeds)
+- [ ] [Optional] Override Linear Check
(Hunk from a web-image build workflow; the file header was not captured.)

@@ -65,8 +65,11 @@ jobs:
 NEXT_PUBLIC_POSTHOG_KEY=${{ secrets.POSTHOG_KEY }}
 NEXT_PUBLIC_POSTHOG_HOST=${{ secrets.POSTHOG_HOST }}
 NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
 NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY=${{ secrets.STRIPE_PUBLISHABLE_KEY }}
 NEXT_PUBLIC_GTM_ENABLED=true
 NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
 NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
 NODE_OPTIONS=--max-old-space-size=8192
 # needed due to weird interactions with the builds for different platforms
 no-cache: true
 labels: ${{ steps.meta.outputs.labels }}
(Hunks from the model-server image build workflow; the file header was not captured.)

@@ -12,7 +12,32 @@ env:
   BUILDKIT_PROGRESS: plain

 jobs:
+  # 1) Preliminary job to check if the changed files are relevant
+  check_model_server_changes:
+    runs-on: ubuntu-latest
+    outputs:
+      changed: ${{ steps.check.outputs.changed }}
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v4
+
+      - name: Check if relevant files changed
+        id: check
+        run: |
+          # Default to "false"
+          echo "changed=false" >> $GITHUB_OUTPUT
+
+          # Compare the previous commit (github.event.before) to the current one (github.sha)
+          # If any file in backend/model_server/** or backend/Dockerfile.model_server is changed,
+          # set changed=true
+          if git diff --name-only ${{ github.event.before }} ${{ github.sha }} \
+              | grep -E '^backend/model_server/|^backend/Dockerfile.model_server'; then
+            echo "changed=true" >> $GITHUB_OUTPUT
+          fi

   build-amd64:
+    needs: [check_model_server_changes]
+    if: needs.check_model_server_changes.outputs.changed == 'true'
     runs-on:
       [runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}-amd64"]
     steps:

@@ -52,6 +77,8 @@ jobs:
     provenance: false

   build-arm64:
+    needs: [check_model_server_changes]
+    if: needs.check_model_server_changes.outputs.changed == 'true'
     runs-on:
       [runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}-arm64"]
     steps:

@@ -91,7 +118,8 @@ jobs:
     provenance: false

   merge-and-scan:
-    needs: [build-amd64, build-arm64]
+    needs: [build-amd64, build-arm64, check_model_server_changes]
+    if: needs.check_model_server_changes.outputs.changed == 'true'
     runs-on: ubuntu-latest
     steps:
       - name: Login to Docker Hub
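The gating pattern above — diff the pushed range, grep for relevant paths, and publish a job output — is easy to reproduce outside CI. A minimal Python sketch of the same filter (hypothetical helper; assumes it runs inside a git checkout where both commits are available):

```python
import re
import subprocess

# Paths that should trigger a model-server image rebuild (mirrors the grep above).
RELEVANT = re.compile(r"^backend/model_server/|^backend/Dockerfile\.model_server")

def model_server_changed(before: str, sha: str) -> bool:
    """Return True if any file touched between `before` and `sha` matches."""
    files = subprocess.run(
        ["git", "diff", "--name-only", before, sha],
        capture_output=True, text=True, check=True,
    ).stdout.splitlines()
    return any(RELEVANT.search(f) for f in files)

if __name__ == "__main__":
    print(model_server_changed("HEAD~1", "HEAD"))
```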
(Hunk from another image build workflow; the file header was not captured.)

@@ -60,6 +60,8 @@ jobs:
 push: true
 build-args: |
   ONYX_VERSION=${{ github.ref_name }}
+  NODE_OPTIONS=--max-old-space-size=8192

 # needed due to weird interactions with the builds for different platforms
 no-cache: true
 labels: ${{ steps.meta.outputs.labels }}
.github/workflows/pr-helm-chart-testing.yml (vendored; 22 lines changed)

@@ -21,10 +21,10 @@ jobs:
       - name: Set up Helm
         uses: azure/setup-helm@v4.2.0
         with:
-          version: v3.14.4
+          version: v3.17.0

       - name: Set up chart-testing
-        uses: helm/chart-testing-action@v2.6.1
+        uses: helm/chart-testing-action@v2.7.0

       # even though we specify chart-dirs in ct.yaml, it isn't used by ct for the list-changed command...
       - name: Run chart-testing (list-changed)

@@ -37,22 +37,6 @@ jobs:
             echo "changed=true" >> "$GITHUB_OUTPUT"
           fi

-      # rkuo: I don't think we need python?
-      # - name: Set up Python
-      #   uses: actions/setup-python@v5
-      #   with:
-      #     python-version: '3.11'
-      #     cache: 'pip'
-      #     cache-dependency-path: |
-      #       backend/requirements/default.txt
-      #       backend/requirements/dev.txt
-      #       backend/requirements/model_server.txt
-      # - run: |
-      #     python -m pip install --upgrade pip
-      #     pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
-      #     pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
-      #     pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt

       # lint all charts if any changes were detected
       - name: Run chart-testing (lint)
         if: steps.list-changed.outputs.changed == 'true'

@@ -62,7 +46,7 @@ jobs:
       - name: Create kind cluster
         if: steps.list-changed.outputs.changed == 'true'
-        uses: helm/kind-action@v1.10.0
+        uses: helm/kind-action@v1.12.0

       - name: Run chart-testing (install)
         if: steps.list-changed.outputs.changed == 'true'
.github/workflows/pr-integration-tests.yml (vendored; 67 lines changed)

@@ -94,23 +94,27 @@ jobs:
           cd deployment/docker_compose
           ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
           MULTI_TENANT=true \
-          AUTH_TYPE=basic \
+          AUTH_TYPE=cloud \
           REQUIRE_EMAIL_VERIFICATION=false \
           DISABLE_TELEMETRY=true \
           IMAGE_TAG=test \
-          docker compose -f docker-compose.dev.yml -p danswer-stack up -d
+          DEV_MODE=true \
+          docker compose -f docker-compose.multitenant-dev.yml -p onyx-stack up -d
         id: start_docker_multi_tenant

+      # In practice, `cloud` Auth type would require OAUTH credentials to be set.
       - name: Run Multi-Tenant Integration Tests
         run: |
           echo "Waiting for 3 minutes to ensure API server is ready..."
           sleep 180
           echo "Running integration tests..."
-          docker run --rm --network danswer-stack_default \
+          docker run --rm --network onyx-stack_default \
             --name test-runner \
             -e POSTGRES_HOST=relational_db \
             -e POSTGRES_USER=postgres \
             -e POSTGRES_PASSWORD=password \
             -e POSTGRES_DB=postgres \
+            -e POSTGRES_USE_NULL_POOL=true \
             -e VESPA_HOST=index \
             -e REDIS_HOST=cache \
             -e API_SERVER_HOST=api_server \

@@ -119,6 +123,10 @@ jobs:
             -e TEST_WEB_HOSTNAME=test-runner \
             -e AUTH_TYPE=cloud \
             -e MULTI_TENANT=true \
+            -e REQUIRE_EMAIL_VERIFICATION=false \
+            -e DISABLE_TELEMETRY=true \
+            -e IMAGE_TAG=test \
+            -e DEV_MODE=true \
             onyxdotapp/onyx-integration:test \
             /app/tests/integration/multitenant_tests
         continue-on-error: true

@@ -126,34 +134,37 @@ jobs:
       - name: Check multi-tenant test results
         run: |
-          if [ ${{ steps.run_tests.outcome }} == 'failure' ]; then
-            echo "Integration tests failed. Exiting with error."
+          if [ ${{ steps.run_multitenant_tests.outcome }} == 'failure' ]; then
+            echo "Multi-tenant integration tests failed. Exiting with error."
             exit 1
           else
-            echo "All integration tests passed successfully."
+            echo "All multi-tenant integration tests passed successfully."
           fi

       - name: Stop multi-tenant Docker containers
         run: |
           cd deployment/docker_compose
-          docker compose -f docker-compose.dev.yml -p danswer-stack down -v
+          docker compose -f docker-compose.multitenant-dev.yml -p onyx-stack down -v

+      # NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
       - name: Start Docker containers
         run: |
           cd deployment/docker_compose
           ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
           AUTH_TYPE=basic \
+          POSTGRES_POOL_PRE_PING=true \
+          POSTGRES_USE_NULL_POOL=true \
           REQUIRE_EMAIL_VERIFICATION=false \
           DISABLE_TELEMETRY=true \
           IMAGE_TAG=test \
-          docker compose -f docker-compose.dev.yml -p danswer-stack up -d
+          docker compose -f docker-compose.dev.yml -p onyx-stack up -d
         id: start_docker

       - name: Wait for service to be ready
         run: |
           echo "Starting wait-for-service script..."
-          docker logs -f danswer-stack-api_server-1 &
+          docker logs -f onyx-stack-api_server-1 &

           start_time=$(date +%s)
           timeout=300 # 5 minutes in seconds

@@ -183,15 +194,24 @@ jobs:
           done
           echo "Finished waiting for service."

+      - name: Start Mock Services
+        run: |
+          cd backend/tests/integration/mock_services
+          docker compose -f docker-compose.mock-it-services.yml \
+            -p mock-it-services-stack up -d

+      # NOTE: Use pre-ping/null to reduce flakiness due to dropped connections
       - name: Run Standard Integration Tests
         run: |
           echo "Running integration tests..."
-          docker run --rm --network danswer-stack_default \
+          docker run --rm --network onyx-stack_default \
             --name test-runner \
             -e POSTGRES_HOST=relational_db \
             -e POSTGRES_USER=postgres \
             -e POSTGRES_PASSWORD=password \
             -e POSTGRES_DB=postgres \
+            -e POSTGRES_POOL_PRE_PING=true \
+            -e POSTGRES_USE_NULL_POOL=true \
             -e VESPA_HOST=index \
             -e REDIS_HOST=cache \
             -e API_SERVER_HOST=api_server \

@@ -201,6 +221,8 @@ jobs:
             -e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
             -e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
             -e TEST_WEB_HOSTNAME=test-runner \
+            -e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
+            -e MOCK_CONNECTOR_SERVER_PORT=8001 \
             onyxdotapp/onyx-integration:test \
             /app/tests/integration/tests \
             /app/tests/integration/connector_job_tests

@@ -216,27 +238,30 @@ jobs:
             echo "All integration tests passed successfully."
           fi

-      # save before stopping the containers so the logs can be captured
-      - name: Save Docker logs
-        if: success() || failure()
+      # ------------------------------------------------------------
+      # Always gather logs BEFORE "down":
+      - name: Dump API server logs
+        if: always()
         run: |
           cd deployment/docker_compose
-          docker compose -f docker-compose.dev.yml -p danswer-stack logs > docker-compose.log
-          mv docker-compose.log ${{ github.workspace }}/docker-compose.log
+          docker compose -f docker-compose.dev.yml -p onyx-stack logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true

-      - name: Stop Docker containers
+      - name: Dump all-container logs (optional)
+        if: always()
         run: |
           cd deployment/docker_compose
-          docker compose -f docker-compose.dev.yml -p danswer-stack down -v
+          docker compose -f docker-compose.dev.yml -p onyx-stack logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true

       - name: Upload logs
-        if: success() || failure()
+        if: always()
         uses: actions/upload-artifact@v4
         with:
-          name: docker-logs
+          name: docker-all-logs
           path: ${{ github.workspace }}/docker-compose.log
+      # ------------------------------------------------------------

       - name: Stop Docker containers
+        if: always()
         run: |
           cd deployment/docker_compose
-          docker compose -f docker-compose.dev.yml -p danswer-stack down -v
+          docker compose -f docker-compose.dev.yml -p onyx-stack down -v
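The `POSTGRES_POOL_PRE_PING` / `POSTGRES_USE_NULL_POOL` flags introduced above exist to reduce flakiness from dropped database connections. A minimal sketch of what such flags typically toggle on the SQLAlchemy side — the env var names come from the workflow, but the wiring shown here is an assumption for illustration, not Onyx's actual engine setup:

```python
import os

from sqlalchemy import create_engine
from sqlalchemy.pool import NullPool

def build_engine(url: str):
    # pool_pre_ping issues a lightweight "ping" before reusing a pooled
    # connection, so stale/dropped connections are replaced transparently.
    pre_ping = os.environ.get("POSTGRES_POOL_PRE_PING", "").lower() == "true"
    # NullPool opens a fresh connection per checkout instead of pooling,
    # trading throughput for immunity to server-side disconnects.
    use_null_pool = os.environ.get("POSTGRES_USE_NULL_POOL", "").lower() == "true"

    kwargs = {"pool_pre_ping": pre_ping}
    if use_null_pool:
        kwargs["poolclass"] = NullPool
    return create_engine(url, **kwargs)
```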
.github/workflows/pr-linear-check.yml (vendored; new file, 29 lines)

@@ -0,0 +1,29 @@
name: Ensure PR references Linear

on:
  pull_request:
    types: [opened, edited, reopened, synchronize]

jobs:
  linear-check:
    runs-on: ubuntu-latest
    steps:
      - name: Check PR body for Linear link or override
        env:
          PR_BODY: ${{ github.event.pull_request.body }}
        run: |
          # Looking for "https://linear.app" in the body
          if echo "$PR_BODY" | grep -qE "https://linear\.app"; then
            echo "Found a Linear link. Check passed."
            exit 0
          fi

          # Looking for a checked override: "[x] Override Linear Check"
          if echo "$PR_BODY" | grep -q "\[x\].*Override Linear Check"; then
            echo "Override box is checked. Check passed."
            exit 0
          fi

          # Otherwise, fail the run
          echo "No Linear link or override found in the PR description."
          exit 1
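The two grep checks above are simple substring/regex tests, so the gating logic can be exercised without a workflow run. A small Python sketch of the same predicate (a hypothetical helper that mirrors the patterns in the step):

```python
import re

LINEAR_LINK = re.compile(r"https://linear\.app")
OVERRIDE_BOX = re.compile(r"\[x\].*Override Linear Check")

def pr_references_linear(body: str) -> bool:
    """True if the PR body links a Linear issue or checks the override box."""
    return bool(LINEAR_LINK.search(body) or OVERRIDE_BOX.search(body))

assert pr_references_linear("Fixes https://linear.app/onyx/issue/ABC-123")
assert pr_references_linear("- [x] [Optional] Override Linear Check")
assert not pr_references_linear("- [ ] [Optional] Override Linear Check")
```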
(Hunks from the Playwright/Chromatic test workflow; the file header was not captured.)

@@ -1,6 +1,6 @@
-name: Run Chromatic Tests
+name: Run Playwright Tests
 concurrency:
-  group: Run-Chromatic-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
+  group: Run-Playwright-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
 cancel-in-progress: true

 on: push

@@ -8,6 +8,8 @@ on: push
 env:
   OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
   SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
+  GEN_AI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
+  MOCK_LLM_RESPONSE: true

 jobs:
   playwright-tests:

@@ -196,43 +198,47 @@ jobs:
         cd deployment/docker_compose
         docker compose -f docker-compose.dev.yml -p danswer-stack down -v

-  chromatic-tests:
-    name: Chromatic Tests
-    needs: playwright-tests
-    runs-on:
-      [
-        runs-on,
-        runner=32cpu-linux-x64,
-        disk=large,
-        "run-id=${{ github.run_id }}",
-      ]
-    steps:
-      - name: Checkout code
-        uses: actions/checkout@v4
-        with:
-          fetch-depth: 0
-      - name: Setup node
-        uses: actions/setup-node@v4
-        with:
-          node-version: 22
-      - name: Install node dependencies
-        working-directory: ./web
-        run: npm ci
-      - name: Download Playwright test results
-        uses: actions/download-artifact@v4
-        with:
-          name: test-results
-          path: ./web/test-results
-      - name: Run Chromatic
-        uses: chromaui/action@latest
-        with:
-          playwright: true
-          projectToken: ${{ secrets.CHROMATIC_PROJECT_TOKEN }}
-          workingDir: ./web
-        env:
-          CHROMATIC_ARCHIVE_LOCATION: ./test-results
+  # NOTE: Chromatic UI diff testing is currently disabled.
+  # We are using Playwright for local and CI testing without visual regression checks.
+  # Chromatic may be reintroduced in the future for UI diff testing if needed.
+  # chromatic-tests:
+  #   name: Chromatic Tests
+  #   needs: playwright-tests
+  #   runs-on:
+  #     [
+  #       runs-on,
+  #       runner=32cpu-linux-x64,
+  #       disk=large,
+  #       "run-id=${{ github.run_id }}",
+  #     ]
+  #   steps:
+  #     - name: Checkout code
+  #       uses: actions/checkout@v4
+  #       with:
+  #         fetch-depth: 0
+  #     - name: Setup node
+  #       uses: actions/setup-node@v4
+  #       with:
+  #         node-version: 22
+  #     - name: Install node dependencies
+  #       working-directory: ./web
+  #       run: npm ci
+  #     - name: Download Playwright test results
+  #       uses: actions/download-artifact@v4
+  #       with:
+  #         name: test-results
+  #         path: ./web/test-results
+  #     - name: Run Chromatic
+  #       uses: chromaui/action@latest
+  #       with:
+  #         playwright: true
+  #         projectToken: ${{ secrets.CHROMATIC_PROJECT_TOKEN }}
+  #         workingDir: ./web
+  #       env:
+  #         CHROMATIC_ARCHIVE_LOCATION: ./test-results
.github/workflows/pr-python-connector-tests.yml (vendored; 13 lines changed)

@@ -39,6 +39,15 @@ env:
   AIRTABLE_TEST_TABLE_ID: ${{ secrets.AIRTABLE_TEST_TABLE_ID }}
   AIRTABLE_TEST_TABLE_NAME: ${{ secrets.AIRTABLE_TEST_TABLE_NAME }}
   AIRTABLE_ACCESS_TOKEN: ${{ secrets.AIRTABLE_ACCESS_TOKEN }}
+  # Sharepoint
+  SHAREPOINT_CLIENT_ID: ${{ secrets.SHAREPOINT_CLIENT_ID }}
+  SHAREPOINT_CLIENT_SECRET: ${{ secrets.SHAREPOINT_CLIENT_SECRET }}
+  SHAREPOINT_CLIENT_DIRECTORY_ID: ${{ secrets.SHAREPOINT_CLIENT_DIRECTORY_ID }}
+  SHAREPOINT_SITE: ${{ secrets.SHAREPOINT_SITE }}
+  # Gitbook
+  GITBOOK_SPACE_ID: ${{ secrets.GITBOOK_SPACE_ID }}
+  GITBOOK_API_KEY: ${{ secrets.GITBOOK_API_KEY }}

 jobs:
   connectors-check:
     # See https://runs-on.com/runners/linux/

@@ -65,7 +74,9 @@ jobs:
           python -m pip install --upgrade pip
           pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
           pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
+          playwright install chromium
+          playwright install-deps chromium

       - name: Run Tests
         shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
         run: py.test -o junit_family=xunit2 -xv --ff backend/tests/daily/connectors
.gitignore (vendored; 4 lines changed)

@@ -7,4 +7,6 @@
 .vscode/
 *.sw?
 /backend/tests/regression/answer_quality/search_test_config.yaml
 /web/test-results/
+backend/onyx/agent_search/main/test_data.json
+backend/tests/regression/answer_quality/test_data.json
.vscode/env_template.txt (vendored; 7 lines changed)

@@ -29,6 +29,7 @@ REQUIRE_EMAIL_VERIFICATION=False

 # Set these so if you wipe the DB, you don't end up having to go through the UI every time
 GEN_AI_API_KEY=<REPLACE THIS>
+OPENAI_API_KEY=<REPLACE THIS>
 # If answer quality isn't important for dev, use gpt-4o-mini since it's cheaper
 GEN_AI_MODEL_VERSION=gpt-4o
 FAST_GEN_AI_MODEL_VERSION=gpt-4o

@@ -51,3 +52,9 @@ BING_API_KEY=<REPLACE THIS>
 # Enable the full set of Danswer Enterprise Edition features
 # NOTE: DO NOT ENABLE THIS UNLESS YOU HAVE A PAID ENTERPRISE LICENSE (or if you are using this for local testing/development)
 ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=False
+
+# Agent Search configs # TODO: Remove give proper namings
+AGENT_RETRIEVAL_STATS=False # Note: This setting will incur substantial re-ranking effort
+AGENT_RERANKING_STATS=True
+AGENT_MAX_QUERY_RETRIEVAL_RESULTS=20
+AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS=20
.vscode/launch.template.jsonc (vendored; 31 lines changed)

@@ -28,6 +28,7 @@
         "Celery heavy",
         "Celery indexing",
         "Celery beat",
+        "Celery monitoring",
       ],
       "presentation": {
         "group": "1",

@@ -51,7 +52,8 @@
         "Celery light",
         "Celery heavy",
         "Celery indexing",
-        "Celery beat"
+        "Celery beat",
+        "Celery monitoring",
       ],
       "presentation": {
         "group": "1",

@@ -203,7 +205,7 @@
         "--loglevel=INFO",
         "--hostname=light@%n",
         "-Q",
-        "vespa_metadata_sync,connector_deletion,doc_permissions_upsert",
+        "vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup",
       ],
       "presentation": {
         "group": "2",

@@ -269,6 +271,31 @@
       },
       "consoleTitle": "Celery indexing Console"
     },
+    {
+      "name": "Celery monitoring",
+      "type": "debugpy",
+      "request": "launch",
+      "module": "celery",
+      "cwd": "${workspaceFolder}/backend",
+      "envFile": "${workspaceFolder}/.vscode/.env",
+      "env": {},
+      "args": [
+        "-A",
+        "onyx.background.celery.versioned_apps.monitoring",
+        "worker",
+        "--pool=solo",
+        "--concurrency=1",
+        "--prefetch-multiplier=1",
+        "--loglevel=INFO",
+        "--hostname=monitoring@%n",
+        "-Q",
+        "monitoring",
+      ],
+      "presentation": {
+        "group": "2",
+      },
+      "consoleTitle": "Celery monitoring Console"
+    },
     {
       "name": "Celery beat",
       "type": "debugpy",
       "request": "launch",
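For reference, the new launch entry is equivalent to running, from `backend/`: `celery -A onyx.background.celery.versioned_apps.monitoring worker --pool=solo --concurrency=1 --prefetch-multiplier=1 --loglevel=INFO --hostname=monitoring@%n -Q monitoring` (assembled directly from the `args` list above).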
(Hunk from a development setup guide; the file header was not captured.)

@@ -17,9 +17,10 @@ Before starting, make sure the Docker Daemon is running.
 1. Open the Debug view in VSCode (Cmd+Shift+D on macOS)
 2. From the dropdown at the top, select "Clear and Restart External Volumes and Containers" and press the green play button
 3. From the dropdown at the top, select "Run All Onyx Services" and press the green play button
-4. Now, you can navigate to onyx in your browser (default is http://localhost:3000) and start using the app
-5. You can set breakpoints by clicking to the left of line numbers to help debug while the app is running
-6. Use the debug toolbar to step through code, inspect variables, etc.
+4. CD into web, run "npm i" followed by npm run dev.
+5. Now, you can navigate to onyx in your browser (default is http://localhost:3000) and start using the app
+6. You can set breakpoints by clicking to the left of line numbers to help debug while the app is running
+7. Use the debug toolbar to step through code, inspect variables, etc.

 ## Features
README.md (122 lines changed)

@@ -24,112 +24,94 @@
 </a>
 </p>

-<strong>[Onyx](https://www.onyx.app/)</strong> (formerly Danswer) is the AI Assistant connected to your company's docs, apps, and people.
-Onyx provides a Chat interface and plugs into any LLM of your choice. Onyx can be deployed anywhere and for any
-scale - on a laptop, on-premise, or to cloud. Since you own the deployment, your user data and chats are fully in your
-own control. Onyx is dual Licensed with most of it under MIT license and designed to be modular and easily extensible. The system also comes fully ready
-for production usage with user authentication, role management (admin/basic users), chat persistence, and a UI for
-configuring AI Assistants.
+<strong>[Onyx](https://www.onyx.app/)</strong> (formerly Danswer) is the AI platform connected to your company's docs, apps, and people.
+Onyx provides a feature rich Chat interface and plugs into any LLM of your choice.
+There are over 40 supported connectors such as Google Drive, Slack, Confluence, Salesforce, etc. which keep knowledge and permissions up to date.
+Create custom AI agents with unique prompts, knowledge, and actions the agents can take.
+Onyx can be deployed securely anywhere and for any scale - on a laptop, on-premise, or to cloud.

 Onyx also serves as a Enterprise Search across all common workplace tools such as Slack, Google Drive, Confluence, etc.
 By combining LLMs and team specific knowledge, Onyx becomes a subject matter expert for the team. Imagine ChatGPT if
 it had access to your team's unique knowledge! It enables questions such as "A customer wants feature X, is this already
 supported?" or "Where's the pull request for feature Y?"

-<h3>Usage</h3>
+<h3>Feature Showcase</h3>

-Onyx Web App:
+**Deep research over your team's knowledge:**

-https://github.com/onyx-dot-app/onyx/assets/32520769/563be14c-9304-47b5-bf0a-9049c2b6f410
+https://private-user-images.githubusercontent.com/32520769/414509312-48392e83-95d0-4fb5-8650-a396e05e0a32.mp4?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3Mzk5Mjg2MzYsIm5iZiI6MTczOTkyODMzNiwicGF0aCI6Ii8zMjUyMDc2OS80MTQ1MDkzMTItNDgzOTJlODMtOTVkMC00ZmI1LTg2NTAtYTM5NmUwNWUwYTMyLm1wND9YLUFtei1BbGdvcml0aG09QVdTNC1ITUFDLVNIQTI1NiZYLUFtei1DcmVkZW50aWFsPUFLSUFWQ09EWUxTQTUzUFFLNFpBJTJGMjAyNTAyMTklMkZ1cy1lYXN0LTElMkZzMyUyRmF3czRfcmVxdWVzdCZYLUFtei1EYXRlPTIwMjUwMjE5VDAxMjUzNlomWC1BbXotRXhwaXJlcz0zMDAmWC1BbXotU2lnbmF0dXJlPWFhMzk5Njg2Y2Y5YjFmNDNiYTQ2YzM5ZTg5YWJiYTU2NWMyY2YwNmUyODE2NWUxMDRiMWQxZWJmODI4YTA0MTUmWC1BbXotU2lnbmVkSGVhZGVycz1ob3N0In0.a9D8A0sgKE9AoaoE-mfFbJ6_OKYeqaf7TZ4Han2JfW8

-Or, plug Onyx into your existing Slack workflows (more integrations to come 😁):
-
-https://github.com/onyx-dot-app/onyx/assets/25087905/3e19739b-d178-4371-9a38-011430bdec1b
+**Use Onyx as a secure AI Chat with any LLM:**
+
+
+
+**Easily set up connectors to your apps:**
+
+
+
+**Access Onyx where your team already works:**
+
+

 For more details on the Admin UI to manage connectors and users, check out our
 <strong><a href="https://www.youtube.com/watch?v=geNzY1nbCnU">Full Video Demo</a></strong>!

 ## Deployment
+
+**To try it out for free and get started in seconds, check out [Onyx Cloud](https://cloud.onyx.app/signup)**.

-Onyx can easily be run locally (even on a laptop) or deployed on a virtual machine with a single
+Onyx can also be run locally (even on a laptop) or deployed on a virtual machine with a single
 `docker compose` command. Checkout our [docs](https://docs.onyx.app/quickstart) to learn more.

-We also have built-in support for deployment on Kubernetes. Files for that can be found [here](https://github.com/onyx-dot-app/onyx/tree/main/deployment/kubernetes).
+We also have built-in support for high-availability/scalable deployment on Kubernetes.
+References [here](https://github.com/onyx-dot-app/onyx/tree/main/deployment).

 ## 💃 Main Features

 - Chat UI with the ability to select documents to chat with.
 - Create custom AI Assistants with different prompts and backing knowledge sets.
 - Connect Onyx with LLM of your choice (self-host for a fully airgapped solution).
 - Document Search + AI Answers for natural language queries.
 - Connectors to all common workplace tools like Google Drive, Confluence, Slack, etc.
 - Slack integration to get answers and search results directly in Slack.

 ## 🚧 Roadmap

 - Chat/Prompt sharing with specific teammates and user groups.
 - Multimodal model support, chat with images, video etc.
 - Choosing between LLMs and parameters during chat session.
 - Tool calling and agent configurations options.
 - Extensions to the Chrome Plugin
 - Latest methods in information retrieval (StructRAG, LightGraphRAG, etc.)
 - Personalized Search
 - Organizational understanding and ability to locate and suggest experts from your team.
 - Code Search
 - SQL and Structured Query Language

-## Other Notable Benefits of Onyx
-
-- User Authentication with document level access management.
-- Best in class Hybrid Search across all sources (BM-25 + prefix aware embedding models).
-- Admin Dashboard to configure connectors, document-sets, access, etc.
-- Custom deep learning models + learn from user feedback.
-- Easy deployment and ability to host Onyx anywhere of your choosing.
+## 🔍 Other Notable Benefits of Onyx
+
+- Custom deep learning models only through Onyx + learn from user feedback.
+- Flexible security features like SSO (OIDC/SAML/OAuth2), RBAC, encryption of credentials, etc.
+- Knowledge curation features like document-sets, query history, usage analytics, etc.
+- Scalable deployment options tested up to many tens of thousands users and hundreds of millions of documents.

 ## 🔌 Connectors

-Efficiently pulls the latest changes from:
+Keep knowledge and access up to sync across 40+ connectors:

 - Slack
 - GitHub
 - Google Drive
 - Confluence
 - Gmail
 - Salesforce
 - Microsoft Sharepoint
 - Jira
 - Zendesk
 - Notion
 - Gong
 - Slab
 - Linear
 - Productboard
 - Guru
 - Bookstack
 - Document360
 - Hubspot
 - Microsoft Teams
 - Dropbox
 - Local Files
 - Websites
 - And more ...

+See the full list [here](https://docs.onyx.app/connectors).

-## 📚 Editions
+## 📚 Licensing

 There are two editions of Onyx:

-- Onyx Community Edition (CE) is available freely under the MIT Expat license. This version has ALL the core features discussed above. This is the version of Onyx you will get if you follow the Deployment guide above.
-- Onyx Enterprise Edition (EE) includes extra features that are primarily useful for larger organizations. Specifically, this includes:
-  - Single Sign-On (SSO), with support for both SAML and OIDC
-  - Role-based access control
-  - Document permission inheritance from connected sources
-  - Usage analytics and query history accessible to admins
-  - Whitelabeling
-  - API key authentication
-  - Encryption of secrets
-  - Any many more! Checkout [our website](https://www.onyx.app/) for the latest.
+- Onyx Community Edition (CE) is available freely under the MIT Expat license. Simply follow the Deployment guide above.
+- Onyx Enterprise Edition (EE) includes extra features that are primarily useful for larger organizations.
+  For feature details, check out [our website](https://www.onyx.app/pricing).

 To try the Onyx Enterprise Edition:

-1. Checkout our [Cloud product](https://cloud.onyx.app/signup).
-2. For self-hosting, contact us at [founders@onyx.app](mailto:founders@onyx.app) or book a call with us on our [Cal](https://cal.com/team/danswer/founders).
+1. Checkout [Onyx Cloud](https://cloud.onyx.app/signup).
+2. For self-hosting the Enterprise Edition, contact us at [founders@onyx.app](mailto:founders@onyx.app) or book a call with us on our [Cal](https://cal.com/team/onyx/founders).

 ## 💡 Contributing

 Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md) for more details.

 ## ⭐Star History

 [](https://star-history.com/#onyx-dot-app/onyx&Date)
(Hunks from a Dockerfile — apparently the backend server image; the file header was not captured.)

@@ -9,8 +9,10 @@ founders@onyx.app for more information. Please visit https://github.com/onyx-dot-app/onyx

 # Default ONYX_VERSION, typically overriden during builds by GitHub Actions.
 ARG ONYX_VERSION=0.8-dev
+# DO_NOT_TRACK is used to disable telemetry for Unstructured
 ENV ONYX_VERSION=${ONYX_VERSION} \
-    DANSWER_RUNNING_IN_DOCKER="true"
+    DANSWER_RUNNING_IN_DOCKER="true" \
+    DO_NOT_TRACK="true"

 RUN echo "ONYX_VERSION: ${ONYX_VERSION}"

@@ -26,14 +28,16 @@ RUN apt-get update && \
     curl \
     zip \
     ca-certificates \
-    libgnutls30=3.7.9-2+deb12u3 \
-    libblkid1=2.38.1-5+deb12u1 \
-    libmount1=2.38.1-5+deb12u1 \
-    libsmartcols1=2.38.1-5+deb12u1 \
-    libuuid1=2.38.1-5+deb12u1 \
+    libgnutls30 \
+    libblkid1 \
+    libmount1 \
+    libsmartcols1 \
+    libuuid1 \
     libxmlsec1-dev \
     pkg-config \
-    gcc && \
+    gcc \
+    nano \
+    vim && \
     rm -rf /var/lib/apt/lists/* && \
     apt-get clean

@@ -99,7 +103,8 @@ COPY ./alembic_tenants /app/alembic_tenants
 COPY ./alembic.ini /app/alembic.ini
 COPY supervisord.conf /usr/etc/supervisord.conf

-# Escape hatch
+# Escape hatch scripts
+COPY ./scripts/debugging /app/scripts/debugging
 COPY ./scripts/force_delete_connector_by_id.py /app/scripts/force_delete_connector_by_id.py

 # Put logo in assets
(New file: Alembic migration "add shortcut option for users", revision 027381bce97c.)

@@ -0,0 +1,29 @@
"""add shortcut option for users

Revision ID: 027381bce97c
Revises: 6fc7886d665d
Create Date: 2025-01-14 12:14:00.814390

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "027381bce97c"
down_revision = "6fc7886d665d"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.add_column(
        "user",
        sa.Column(
            "shortcut_enabled", sa.Boolean(), nullable=False, server_default="false"
        ),
    )


def downgrade() -> None:
    op.drop_column("user", "shortcut_enabled")
(New file: Alembic migration "add index to index_attempt.time_created", revision 0f7ff6d75b57.)

@@ -0,0 +1,36 @@
"""add index to index_attempt.time_created

Revision ID: 0f7ff6d75b57
Revises: 369644546676
Create Date: 2025-01-10 14:01:14.067144

"""
from alembic import op

# revision identifiers, used by Alembic.
revision = "0f7ff6d75b57"
down_revision = "fec3db967bf7"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.create_index(
        op.f("ix_index_attempt_status"),
        "index_attempt",
        ["status"],
        unique=False,
    )

    op.create_index(
        op.f("ix_index_attempt_time_created"),
        "index_attempt",
        ["time_created"],
        unique=False,
    )


def downgrade() -> None:
    op.drop_index(op.f("ix_index_attempt_time_created"), table_name="index_attempt")

    op.drop_index(op.f("ix_index_attempt_status"), table_name="index_attempt")
(New file: Alembic migration "Add indexes to document__tag", revision 1a03d2c2856b.)

@@ -0,0 +1,27 @@
"""Add indexes to document__tag

Revision ID: 1a03d2c2856b
Revises: 9c00a2bccb83
Create Date: 2025-02-18 10:45:13.957807

"""
from alembic import op

# revision identifiers, used by Alembic.
revision = "1a03d2c2856b"
down_revision = "9c00a2bccb83"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.create_index(
        op.f("ix_document__tag_tag_id"),
        "document__tag",
        ["tag_id"],
        unique=False,
    )


def downgrade() -> None:
    op.drop_index(op.f("ix_document__tag_tag_id"), table_name="document__tag")
(New file: Alembic migration "set built in to default", revision 2cdeff6d8c93.)

@@ -0,0 +1,32 @@
"""set built in to default

Revision ID: 2cdeff6d8c93
Revises: f5437cc136c5
Create Date: 2025-02-11 14:57:51.308775

"""
from alembic import op


# revision identifiers, used by Alembic.
revision = "2cdeff6d8c93"
down_revision = "f5437cc136c5"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Prior to this migration / point in the codebase history,
    # built in personas were implicitly treated as default personas (with no option to change this)
    # This migration makes that explicit
    op.execute(
        """
        UPDATE persona
        SET is_default_persona = TRUE
        WHERE builtin_persona = TRUE
        """
    )


def downgrade() -> None:
    pass
(New file: Alembic migration "add chat session specific temperature override", revision 2f80c6a2550f.)

@@ -0,0 +1,36 @@
"""add chat session specific temperature override

Revision ID: 2f80c6a2550f
Revises: 33ea50e88f24
Create Date: 2025-01-31 10:30:27.289646

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "2f80c6a2550f"
down_revision = "33ea50e88f24"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.add_column(
        "chat_session", sa.Column("temperature_override", sa.Float(), nullable=True)
    )
    op.add_column(
        "user",
        sa.Column(
            "temperature_override_enabled",
            sa.Boolean(),
            nullable=False,
            server_default=sa.false(),
        ),
    )


def downgrade() -> None:
    op.drop_column("chat_session", "temperature_override")
    op.drop_column("user", "temperature_override_enabled")
(New file: Alembic migration "foreign key input prompts", revision 33ea50e88f24.)

@@ -0,0 +1,80 @@
"""foreign key input prompts

Revision ID: 33ea50e88f24
Revises: a6df6b88ef81
Create Date: 2025-01-29 10:54:22.141765

"""
from alembic import op


# revision identifiers, used by Alembic.
revision = "33ea50e88f24"
down_revision = "a6df6b88ef81"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Safely drop constraints if exists
    op.execute(
        """
        ALTER TABLE inputprompt__user
        DROP CONSTRAINT IF EXISTS inputprompt__user_input_prompt_id_fkey
        """
    )
    op.execute(
        """
        ALTER TABLE inputprompt__user
        DROP CONSTRAINT IF EXISTS inputprompt__user_user_id_fkey
        """
    )

    # Recreate with ON DELETE CASCADE
    op.create_foreign_key(
        "inputprompt__user_input_prompt_id_fkey",
        "inputprompt__user",
        "inputprompt",
        ["input_prompt_id"],
        ["id"],
        ondelete="CASCADE",
    )

    op.create_foreign_key(
        "inputprompt__user_user_id_fkey",
        "inputprompt__user",
        "user",
        ["user_id"],
        ["id"],
        ondelete="CASCADE",
    )


def downgrade() -> None:
    # Drop the new FKs with ondelete
    op.drop_constraint(
        "inputprompt__user_input_prompt_id_fkey",
        "inputprompt__user",
        type_="foreignkey",
    )
    op.drop_constraint(
        "inputprompt__user_user_id_fkey",
        "inputprompt__user",
        type_="foreignkey",
    )

    # Recreate them without cascading
    op.create_foreign_key(
        "inputprompt__user_input_prompt_id_fkey",
        "inputprompt__user",
        "inputprompt",
        ["input_prompt_id"],
        ["id"],
    )
    op.create_foreign_key(
        "inputprompt__user_user_id_fkey",
        "inputprompt__user",
        "user",
        ["user_id"],
        ["id"],
    )
(New file: Alembic migration "add back input prompts", revision 3c6531f32351.)

@@ -0,0 +1,59 @@
"""add back input prompts

Revision ID: 3c6531f32351
Revises: aeda5f2df4f6
Create Date: 2025-01-13 12:49:51.705235

"""
from alembic import op
import sqlalchemy as sa
import fastapi_users_db_sqlalchemy

# revision identifiers, used by Alembic.
revision = "3c6531f32351"
down_revision = "aeda5f2df4f6"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.create_table(
        "inputprompt",
        sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
        sa.Column("prompt", sa.String(), nullable=False),
        sa.Column("content", sa.String(), nullable=False),
        sa.Column("active", sa.Boolean(), nullable=False),
        sa.Column("is_public", sa.Boolean(), nullable=False),
        sa.Column(
            "user_id",
            fastapi_users_db_sqlalchemy.generics.GUID(),
            nullable=True,
        ),
        sa.ForeignKeyConstraint(
            ["user_id"],
            ["user.id"],
        ),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_table(
        "inputprompt__user",
        sa.Column("input_prompt_id", sa.Integer(), nullable=False),
        sa.Column(
            "user_id", fastapi_users_db_sqlalchemy.generics.GUID(), nullable=False
        ),
        sa.Column("disabled", sa.Boolean(), nullable=False, default=False),
        sa.ForeignKeyConstraint(
            ["input_prompt_id"],
            ["inputprompt.id"],
        ),
        sa.ForeignKeyConstraint(
            ["user_id"],
            ["user.id"],
        ),
        sa.PrimaryKeyConstraint("input_prompt_id", "user_id"),
    )


def downgrade() -> None:
    op.drop_table("inputprompt__user")
    op.drop_table("inputprompt")
(Hunk from an existing migration file; the file header was not captured.)

@@ -40,6 +40,6 @@ def upgrade() -> None:


 def downgrade() -> None:
-    op.drop_constraint("fk_persona_category", "persona", type_="foreignkey")
+    op.drop_constraint("persona_category_id_fkey", "persona", type_="foreignkey")
     op.drop_column("persona", "category_id")
     op.drop_table("persona_category")
(New file: Alembic migration "lowercase_user_emails", revision 4d58345da04a.)

@@ -0,0 +1,37 @@
"""lowercase_user_emails

Revision ID: 4d58345da04a
Revises: f1ca58b2f2ec
Create Date: 2025-01-29 07:48:46.784041

"""
from alembic import op
from sqlalchemy.sql import text


# revision identifiers, used by Alembic.
revision = "4d58345da04a"
down_revision = "f1ca58b2f2ec"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Get database connection
    connection = op.get_bind()

    # Update all user emails to lowercase
    connection.execute(
        text(
            """
            UPDATE "user"
            SET email = LOWER(email)
            WHERE email != LOWER(email)
            """
        )
    )


def downgrade() -> None:
    # Cannot restore original case of emails
    pass
||||
@@ -5,7 +5,6 @@ Revises: 47e5bef3a1d7
|
||||
Create Date: 2024-11-06 13:15:53.302644
|
||||
|
||||
"""
|
||||
import logging
|
||||
from typing import cast
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
@@ -20,13 +19,8 @@ down_revision = "47e5bef3a1d7"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
# Configure logging
|
||||
logger = logging.getLogger("alembic.runtime.migration")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
logger.info(f"{revision}: create_table: slack_bot")
|
||||
# Create new slack_bot table
|
||||
op.create_table(
|
||||
"slack_bot",
|
||||
@@ -63,7 +57,6 @@ def upgrade() -> None:
|
||||
)
|
||||
|
||||
# Handle existing Slack bot tokens first
|
||||
logger.info(f"{revision}: Checking for existing Slack bot.")
|
||||
bot_token = None
|
||||
app_token = None
|
||||
first_row_id = None
|
||||
@@ -71,15 +64,12 @@ def upgrade() -> None:
|
||||
try:
|
||||
tokens = cast(dict, get_kv_store().load("slack_bot_tokens_config_key"))
|
||||
except Exception:
|
||||
logger.warning("No existing Slack bot tokens found.")
|
||||
tokens = {}
|
||||
|
||||
bot_token = tokens.get("bot_token")
|
||||
app_token = tokens.get("app_token")
|
||||
|
||||
if bot_token and app_token:
|
||||
logger.info(f"{revision}: Found bot and app tokens.")
|
||||
|
||||
session = Session(bind=op.get_bind())
|
||||
new_slack_bot = SlackBot(
|
||||
name="Slack Bot (Migrated)",
|
||||
@@ -170,10 +160,9 @@ def upgrade() -> None:
|
||||
# Clean up old tokens if they existed
|
||||
try:
|
||||
if bot_token and app_token:
|
||||
logger.info(f"{revision}: Removing old bot and app tokens.")
|
||||
get_kv_store().delete("slack_bot_tokens_config_key")
|
||||
except Exception:
|
||||
logger.warning("tried to delete tokens in dynamic config but failed")
|
||||
pass
|
||||
# Rename the table
|
||||
op.rename_table(
|
||||
"slack_bot_config__standard_answer_category",
|
||||
@@ -190,8 +179,6 @@ def upgrade() -> None:
|
||||
# Drop the table with CASCADE to handle dependent objects
|
||||
op.execute("DROP TABLE slack_bot_config CASCADE")
|
||||
|
||||
logger.info(f"{revision}: Migration complete.")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Recreate the old slack_bot_config table
|
||||
@@ -273,7 +260,7 @@ def downgrade() -> None:
|
||||
}
|
||||
get_kv_store().store("slack_bot_tokens_config_key", tokens)
|
||||
except Exception:
|
||||
logger.warning("Failed to save tokens back to KV store")
|
||||
pass
|
||||
|
||||
# Drop the new tables in reverse order
|
||||
op.drop_table("slack_channel_config")
|
||||
|
||||
(New file: Alembic migration "make categories labels and many to many", revision 6fc7886d665d.)

@@ -0,0 +1,80 @@
"""make categories labels and many to many

Revision ID: 6fc7886d665d
Revises: 3c6531f32351
Create Date: 2025-01-13 18:12:18.029112

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "6fc7886d665d"
down_revision = "3c6531f32351"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Rename persona_category table to persona_label
    op.rename_table("persona_category", "persona_label")

    # Create the new association table
    op.create_table(
        "persona__persona_label",
        sa.Column("persona_id", sa.Integer(), nullable=False),
        sa.Column("persona_label_id", sa.Integer(), nullable=False),
        sa.ForeignKeyConstraint(
            ["persona_id"],
            ["persona.id"],
        ),
        sa.ForeignKeyConstraint(
            ["persona_label_id"],
            ["persona_label.id"],
            ondelete="CASCADE",
        ),
        sa.PrimaryKeyConstraint("persona_id", "persona_label_id"),
    )

    # Copy existing relationships to the new table
    op.execute(
        """
        INSERT INTO persona__persona_label (persona_id, persona_label_id)
        SELECT id, category_id FROM persona WHERE category_id IS NOT NULL
        """
    )

    # Remove the old category_id column from persona table
    op.drop_column("persona", "category_id")


def downgrade() -> None:
    # Rename persona_label table back to persona_category
    op.rename_table("persona_label", "persona_category")

    # Add back the category_id column to persona table
    op.add_column("persona", sa.Column("category_id", sa.Integer(), nullable=True))
    op.create_foreign_key(
        "persona_category_id_fkey",
        "persona",
        "persona_category",
        ["category_id"],
        ["id"],
    )

    # Copy the first label relationship back to the persona table
    op.execute(
        """
        UPDATE persona
        SET category_id = (
            SELECT persona_label_id
            FROM persona__persona_label
            WHERE persona__persona_label.persona_id = persona.id
            LIMIT 1
        )
        """
    )

    # Drop the association table
    op.drop_table("persona__persona_label")
72
backend/alembic/versions/97dbb53fa8c8_add_syncrecord.py
Normal file
72
backend/alembic/versions/97dbb53fa8c8_add_syncrecord.py
Normal file
@@ -0,0 +1,72 @@
"""Add SyncRecord

Revision ID: 97dbb53fa8c8
Revises: be2ab2aa50ee
Create Date: 2025-01-11 19:39:50.426302

"""
from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "97dbb53fa8c8"
down_revision = "be2ab2aa50ee"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.create_table(
        "sync_record",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("entity_id", sa.Integer(), nullable=False),
        sa.Column(
            "sync_type",
            sa.Enum(
                "DOCUMENT_SET",
                "USER_GROUP",
                "CONNECTOR_DELETION",
                name="synctype",
                native_enum=False,
                length=40,
            ),
            nullable=False,
        ),
        sa.Column(
            "sync_status",
            sa.Enum(
                "IN_PROGRESS",
                "SUCCESS",
                "FAILED",
                "CANCELED",
                name="syncstatus",
                native_enum=False,
                length=40,
            ),
            nullable=False,
        ),
        sa.Column("num_docs_synced", sa.Integer(), nullable=False),
        sa.Column("sync_start_time", sa.DateTime(timezone=True), nullable=False),
        sa.Column("sync_end_time", sa.DateTime(timezone=True), nullable=True),
        sa.PrimaryKeyConstraint("id"),
    )

    # Add index for fetch_latest_sync_record query
    op.create_index(
        "ix_sync_record_entity_id_sync_type_sync_start_time",
        "sync_record",
        ["entity_id", "sync_type", "sync_start_time"],
    )

    # Add index for cleanup_sync_records query
    op.create_index(
        "ix_sync_record_entity_id_sync_type_sync_status",
        "sync_record",
        ["entity_id", "sync_type", "sync_status"],
    )


def downgrade() -> None:
    op.drop_index("ix_sync_record_entity_id_sync_type_sync_status")
    op.drop_index("ix_sync_record_entity_id_sync_type_sync_start_time")
    op.drop_table("sync_record")
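For context, a hedged sketch of the fetch_latest_sync_record-style lookup the first composite index is built for (reflection-based so the sketch is self-contained; the app itself presumably goes through its ORM models):

from sqlalchemy import MetaData, Table, create_engine, desc, select

engine = create_engine("postgresql://localhost/onyx")  # assumed URL
sync_record = Table("sync_record", MetaData(), autoload_with=engine)

# Latest record for one entity/type: the (entity_id, sync_type,
# sync_start_time) index covers both the filter and the ORDER BY, so this
# resolves as an index-ordered scan with LIMIT 1.
stmt = (
    select(sync_record)
    .where(
        sync_record.c.entity_id == 42,  # made-up entity id
        sync_record.c.sync_type == "USER_GROUP",
    )
    .order_by(desc(sync_record.c.sync_start_time))
    .limit(1)
)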
backend/alembic/versions/98a5008d8711_agent_tracking.py
@@ -0,0 +1,107 @@
"""agent_tracking

Revision ID: 98a5008d8711
Revises: 2f80c6a2550f
Create Date: 2025-01-29 17:00:00.000001

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from sqlalchemy.dialects.postgresql import UUID

# revision identifiers, used by Alembic.
revision = "98a5008d8711"
down_revision = "2f80c6a2550f"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.create_table(
        "agent__search_metrics",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=True),
        sa.Column("persona_id", sa.Integer(), nullable=True),
        sa.Column("agent_type", sa.String(), nullable=False),
        sa.Column("start_time", sa.DateTime(timezone=True), nullable=False),
        sa.Column("base_duration_s", sa.Float(), nullable=False),
        sa.Column("full_duration_s", sa.Float(), nullable=False),
        sa.Column("base_metrics", postgresql.JSONB(), nullable=True),
        sa.Column("refined_metrics", postgresql.JSONB(), nullable=True),
        sa.Column("all_metrics", postgresql.JSONB(), nullable=True),
        sa.ForeignKeyConstraint(
            ["persona_id"],
            ["persona.id"],
        ),
        sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("id"),
    )

    # Create sub_question table
    op.create_table(
        "agent__sub_question",
        sa.Column("id", sa.Integer, primary_key=True),
        sa.Column("primary_question_id", sa.Integer, sa.ForeignKey("chat_message.id")),
        sa.Column(
            "chat_session_id", UUID(as_uuid=True), sa.ForeignKey("chat_session.id")
        ),
        sa.Column("sub_question", sa.Text),
        sa.Column(
            "time_created", sa.DateTime(timezone=True), server_default=sa.func.now()
        ),
        sa.Column("sub_answer", sa.Text),
        sa.Column("sub_question_doc_results", postgresql.JSONB(), nullable=True),
        sa.Column("level", sa.Integer(), nullable=False),
        sa.Column("level_question_num", sa.Integer(), nullable=False),
    )

    # Create sub_query table
    op.create_table(
        "agent__sub_query",
        sa.Column("id", sa.Integer, primary_key=True),
        sa.Column(
            "parent_question_id", sa.Integer, sa.ForeignKey("agent__sub_question.id")
        ),
        sa.Column(
            "chat_session_id", UUID(as_uuid=True), sa.ForeignKey("chat_session.id")
        ),
        sa.Column("sub_query", sa.Text),
        sa.Column(
            "time_created", sa.DateTime(timezone=True), server_default=sa.func.now()
        ),
    )

    # Create sub_query__search_doc association table
    op.create_table(
        "agent__sub_query__search_doc",
        sa.Column(
            "sub_query_id",
            sa.Integer,
            sa.ForeignKey("agent__sub_query.id"),
            primary_key=True,
        ),
        sa.Column(
            "search_doc_id",
            sa.Integer,
            sa.ForeignKey("search_doc.id"),
            primary_key=True,
        ),
    )

    op.add_column(
        "chat_message",
        sa.Column(
            "refined_answer_improvement",
            sa.Boolean(),
            nullable=True,
        ),
    )


def downgrade() -> None:
    op.drop_column("chat_message", "refined_answer_improvement")
    op.drop_table("agent__sub_query__search_doc")
    op.drop_table("agent__sub_query")
    op.drop_table("agent__sub_question")
    op.drop_table("agent__search_metrics")
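The new tables form a chat message → sub-question → sub-query → search doc chain; a hedged sketch of walking it (made-up message id and connection URL):

from sqlalchemy import create_engine, text

engine = create_engine("postgresql://localhost/onyx")  # assumed URL
with engine.connect() as conn:
    # For one chat message, list its sub-questions, their sub-queries, and
    # the search docs each sub-query retrieved.
    rows = conn.execute(
        text(
            """
            SELECT sq.sub_question, q.sub_query, link.search_doc_id
            FROM agent__sub_question sq
            JOIN agent__sub_query q ON q.parent_question_id = sq.id
            JOIN agent__sub_query__search_doc link ON link.sub_query_id = q.id
            WHERE sq.primary_question_id = :message_id
            ORDER BY sq.level, sq.level_question_num
            """
        ),
        {"message_id": 1},  # made-up chat_message id
    )
    for sub_question, sub_query, doc_id in rows:
        print(sub_question, sub_query, doc_id)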
@@ -0,0 +1,43 @@
"""chat_message_agentic

Revision ID: 9c00a2bccb83
Revises: b7a7eee5aa15
Create Date: 2025-02-17 11:15:43.081150

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "9c00a2bccb83"
down_revision = "b7a7eee5aa15"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # First add the column as nullable
    op.add_column("chat_message", sa.Column("is_agentic", sa.Boolean(), nullable=True))

    # Update existing rows based on presence of SubQuestions
    op.execute(
        """
        UPDATE chat_message
        SET is_agentic = EXISTS (
            SELECT 1
            FROM agent__sub_question
            WHERE agent__sub_question.primary_question_id = chat_message.id
        )
        WHERE is_agentic IS NULL
        """
    )

    # Make the column non-nullable with a default value of False
    op.alter_column(
        "chat_message", "is_agentic", nullable=False, server_default=sa.text("false")
    )


def downgrade() -> None:
    op.drop_column("chat_message", "is_agentic")
@@ -0,0 +1,29 @@
"""remove recent assistants

Revision ID: a6df6b88ef81
Revises: 4d58345da04a
Create Date: 2025-01-29 10:25:52.790407

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "a6df6b88ef81"
down_revision = "4d58345da04a"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.drop_column("user", "recent_assistants")


def downgrade() -> None:
    op.add_column(
        "user",
        sa.Column(
            "recent_assistants", postgresql.JSONB(), server_default="[]", nullable=False
        ),
    )
@@ -0,0 +1,29 @@
"""remove inactive ccpair status on downgrade

Revision ID: acaab4ef4507
Revises: b388730a2899
Create Date: 2025-02-16 18:21:41.330212

"""
from alembic import op
from onyx.db.models import ConnectorCredentialPair
from onyx.db.enums import ConnectorCredentialPairStatus
from sqlalchemy import update

# revision identifiers, used by Alembic.
revision = "acaab4ef4507"
down_revision = "b388730a2899"
branch_labels = None
depends_on = None


def upgrade() -> None:
    pass


def downgrade() -> None:
    op.execute(
        update(ConnectorCredentialPair)
        .where(ConnectorCredentialPair.status == ConnectorCredentialPairStatus.INVALID)
        .values(status=ConnectorCredentialPairStatus.ACTIVE)
    )
@@ -0,0 +1,27 @@
"""add pinned assistants

Revision ID: aeda5f2df4f6
Revises: c5eae4a75a1b
Create Date: 2025-01-09 16:04:10.770636

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "aeda5f2df4f6"
down_revision = "c5eae4a75a1b"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.add_column(
        "user", sa.Column("pinned_assistants", postgresql.JSONB(), nullable=True)
    )
    op.execute('UPDATE "user" SET pinned_assistants = chosen_assistants')


def downgrade() -> None:
    op.drop_column("user", "pinned_assistants")
@@ -0,0 +1,31 @@
"""nullable preferences

Revision ID: b388730a2899
Revises: 1a03d2c2856b
Create Date: 2025-02-17 18:49:22.643902

"""
from alembic import op


# revision identifiers, used by Alembic.
revision = "b388730a2899"
down_revision = "1a03d2c2856b"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.alter_column("user", "temperature_override_enabled", nullable=True)
    op.alter_column("user", "auto_scroll", nullable=True)


def downgrade() -> None:
    # Ensure no null values before making columns non-nullable
    op.execute(
        'UPDATE "user" SET temperature_override_enabled = false WHERE temperature_override_enabled IS NULL'
    )
    op.execute('UPDATE "user" SET auto_scroll = false WHERE auto_scroll IS NULL')

    op.alter_column("user", "temperature_override_enabled", nullable=False)
    op.alter_column("user", "auto_scroll", nullable=False)
@@ -0,0 +1,124 @@
"""Add checkpointing/failure handling

Revision ID: b7a7eee5aa15
Revises: f39c5794c10a
Create Date: 2025-01-24 15:17:36.763172

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "b7a7eee5aa15"
down_revision = "f39c5794c10a"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.add_column(
        "index_attempt",
        sa.Column("checkpoint_pointer", sa.String(), nullable=True),
    )
    op.add_column(
        "index_attempt",
        sa.Column("poll_range_start", sa.DateTime(timezone=True), nullable=True),
    )
    op.add_column(
        "index_attempt",
        sa.Column("poll_range_end", sa.DateTime(timezone=True), nullable=True),
    )

    op.create_index(
        "ix_index_attempt_cc_pair_settings_poll",
        "index_attempt",
        [
            "connector_credential_pair_id",
            "search_settings_id",
            "status",
            sa.text("time_updated DESC"),
        ],
    )

    # Drop the old IndexAttemptError table
    op.drop_index("index_attempt_id", table_name="index_attempt_errors")
    op.drop_table("index_attempt_errors")

    # Create the new version of the table
    op.create_table(
        "index_attempt_errors",
        sa.Column("id", sa.Integer(), primary_key=True),
        sa.Column("index_attempt_id", sa.Integer(), nullable=False),
        sa.Column("connector_credential_pair_id", sa.Integer(), nullable=False),
        sa.Column("document_id", sa.String(), nullable=True),
        sa.Column("document_link", sa.String(), nullable=True),
        sa.Column("entity_id", sa.String(), nullable=True),
        sa.Column("failed_time_range_start", sa.DateTime(timezone=True), nullable=True),
        sa.Column("failed_time_range_end", sa.DateTime(timezone=True), nullable=True),
        sa.Column("failure_message", sa.Text(), nullable=False),
        sa.Column("is_resolved", sa.Boolean(), nullable=False, default=False),
        sa.Column(
            "time_created",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.ForeignKeyConstraint(
            ["index_attempt_id"],
            ["index_attempt.id"],
        ),
        sa.ForeignKeyConstraint(
            ["connector_credential_pair_id"],
            ["connector_credential_pair.id"],
        ),
    )


def downgrade() -> None:
    op.execute("SET lock_timeout = '5s'")

    # try a few times to drop the table, this has been observed to fail due to other locks
    # blocking the drop
    NUM_TRIES = 10
    for i in range(NUM_TRIES):
        try:
            op.drop_table("index_attempt_errors")
            break
        except Exception as e:
            if i == NUM_TRIES - 1:
                raise e
            print(f"Error dropping table: {e}. Retrying...")

    op.execute("SET lock_timeout = DEFAULT")

    # Recreate the old IndexAttemptError table
    op.create_table(
        "index_attempt_errors",
        sa.Column("id", sa.Integer(), primary_key=True),
        sa.Column("index_attempt_id", sa.Integer(), nullable=True),
        sa.Column("batch", sa.Integer(), nullable=True),
        sa.Column("doc_summaries", postgresql.JSONB(), nullable=False),
        sa.Column("error_msg", sa.Text(), nullable=True),
        sa.Column("traceback", sa.Text(), nullable=True),
        sa.Column(
            "time_created",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
        ),
        sa.ForeignKeyConstraint(
            ["index_attempt_id"],
            ["index_attempt.id"],
        ),
    )

    op.create_index(
        "index_attempt_id",
        "index_attempt_errors",
        ["time_created"],
    )

    op.drop_index("ix_index_attempt_cc_pair_settings_poll")
    op.drop_column("index_attempt", "checkpoint_pointer")
    op.drop_column("index_attempt", "poll_range_start")
    op.drop_column("index_attempt", "poll_range_end")
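The lock_timeout-plus-retry loop in the downgrade above generalizes to any DDL that can lose a lock race. A hedged standalone sketch of the same idea (helper name and connection URL are assumptions, not part of the codebase):

import time

from sqlalchemy import create_engine, text


def drop_table_with_retries(url: str, table: str, tries: int = 10) -> None:
    # AUTOCOMMIT keeps each statement independent, so a failed DROP does not
    # leave the connection in an aborted-transaction state.
    engine = create_engine(url, isolation_level="AUTOCOMMIT")
    with engine.connect() as conn:
        # Fail fast instead of queueing indefinitely behind long-lived locks.
        conn.execute(text("SET lock_timeout = '5s'"))
        for attempt in range(tries):
            try:
                conn.execute(text(f'DROP TABLE IF EXISTS "{table}"'))
                break
            except Exception:
                if attempt == tries - 1:
                    raise
                time.sleep(1)  # give the conflicting lock holder time to finish
        conn.execute(text("SET lock_timeout = DEFAULT"))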
backend/alembic/versions/be2ab2aa50ee_fix_capitalization.py
@@ -0,0 +1,38 @@
"""fix_capitalization

Revision ID: be2ab2aa50ee
Revises: 369644546676
Create Date: 2025-01-10 13:13:26.228960

"""
from alembic import op

# revision identifiers, used by Alembic.
revision = "be2ab2aa50ee"
down_revision = "369644546676"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.execute(
        """
        UPDATE document
        SET
            external_user_group_ids = ARRAY(
                SELECT LOWER(unnest(external_user_group_ids))
            ),
            last_modified = NOW()
        WHERE
            external_user_group_ids IS NOT NULL
            AND external_user_group_ids::text[] <> ARRAY(
                SELECT LOWER(unnest(external_user_group_ids))
            )::text[]
        """
    )


def downgrade() -> None:
    # No way to cleanly persist the bad state through an upgrade/downgrade
    # cycle, so we just pass
    pass
@@ -0,0 +1,36 @@
"""Add chat_message__standard_answer table

Revision ID: c5eae4a75a1b
Revises: 0f7ff6d75b57
Create Date: 2025-01-15 14:08:49.688998

"""
from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "c5eae4a75a1b"
down_revision = "0f7ff6d75b57"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.create_table(
        "chat_message__standard_answer",
        sa.Column("chat_message_id", sa.Integer(), nullable=False),
        sa.Column("standard_answer_id", sa.Integer(), nullable=False),
        sa.ForeignKeyConstraint(
            ["chat_message_id"],
            ["chat_message.id"],
        ),
        sa.ForeignKeyConstraint(
            ["standard_answer_id"],
            ["standard_answer.id"],
        ),
        sa.PrimaryKeyConstraint("chat_message_id", "standard_answer_id"),
    )


def downgrade() -> None:
    op.drop_table("chat_message__standard_answer")
@@ -0,0 +1,48 @@
"""Add has_been_indexed to DocumentByConnectorCredentialPair

Revision ID: c7bf5721733e
Revises: 027381bce97c
Create Date: 2025-01-13 12:39:05.831693

"""
from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "c7bf5721733e"
down_revision = "027381bce97c"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # assume all existing rows have been indexed, no better approach
    op.add_column(
        "document_by_connector_credential_pair",
        sa.Column("has_been_indexed", sa.Boolean(), nullable=True),
    )
    op.execute(
        "UPDATE document_by_connector_credential_pair SET has_been_indexed = TRUE"
    )
    op.alter_column(
        "document_by_connector_credential_pair",
        "has_been_indexed",
        nullable=False,
    )

    # Add index to optimize get_document_counts_for_cc_pairs query pattern
    op.create_index(
        "idx_document_cc_pair_counts",
        "document_by_connector_credential_pair",
        ["connector_id", "credential_id", "has_been_indexed"],
        unique=False,
    )


def downgrade() -> None:
    # Remove the index first before removing the column
    op.drop_index(
        "idx_document_cc_pair_counts",
        table_name="document_by_connector_credential_pair",
    )
    op.drop_column("document_by_connector_credential_pair", "has_been_indexed")
@@ -0,0 +1,80 @@
"""add default slack channel config

Revision ID: eaa3b5593925
Revises: 98a5008d8711
Create Date: 2025-02-03 18:07:56.552526

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "eaa3b5593925"
down_revision = "98a5008d8711"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Add is_default column
    op.add_column(
        "slack_channel_config",
        sa.Column("is_default", sa.Boolean(), nullable=False, server_default="false"),
    )

    op.create_index(
        "ix_slack_channel_config_slack_bot_id_default",
        "slack_channel_config",
        ["slack_bot_id", "is_default"],
        unique=True,
        postgresql_where=sa.text("is_default IS TRUE"),
    )

    # Create default channel configs for existing slack bots without one
    conn = op.get_bind()
    slack_bots = conn.execute(sa.text("SELECT id FROM slack_bot")).fetchall()

    for slack_bot in slack_bots:
        slack_bot_id = slack_bot[0]
        existing_default = conn.execute(
            sa.text(
                "SELECT id FROM slack_channel_config WHERE slack_bot_id = :bot_id AND is_default = TRUE"
            ),
            {"bot_id": slack_bot_id},
        ).fetchone()

        if not existing_default:
            conn.execute(
                sa.text(
                    """
                    INSERT INTO slack_channel_config (
                        slack_bot_id, persona_id, channel_config, enable_auto_filters, is_default
                    ) VALUES (
                        :bot_id, NULL,
                        '{"channel_name": null, '
                        '"respond_member_group_list": [], '
                        '"answer_filters": [], '
                        '"follow_up_tags": [], '
                        '"respond_tag_only": true}',
                        FALSE, TRUE
                    )
                    """
                ),
                {"bot_id": slack_bot_id},
            )


def downgrade() -> None:
    # Delete default slack channel configs
    conn = op.get_bind()
    conn.execute(sa.text("DELETE FROM slack_channel_config WHERE is_default = TRUE"))

    # Remove index
    op.drop_index(
        "ix_slack_channel_config_slack_bot_id_default",
        table_name="slack_channel_config",
    )

    # Remove is_default column
    op.drop_column("slack_channel_config", "is_default")
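The postgresql_where argument makes that a partial unique index: uniqueness is enforced only over rows where is_default IS TRUE, so each bot can have at most one default config while non-default rows stay unconstrained. A hedged demonstration (assumed URL and bot id; the second default insert is what the index rejects):

from sqlalchemy import create_engine, text

engine = create_engine("postgresql://localhost/onyx", isolation_level="AUTOCOMMIT")
with engine.connect() as conn:
    # Any number of non-default rows per bot is fine; a second row with
    # is_default = TRUE for the same bot raises a unique-violation error.
    conn.execute(
        text(
            "INSERT INTO slack_channel_config "
            "(slack_bot_id, channel_config, enable_auto_filters, is_default) "
            "VALUES (:bot_id, '{}', FALSE, TRUE)"
        ),
        {"bot_id": 1},  # made-up bot id
    )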
@@ -0,0 +1,33 @@
"""add passthrough auth to tool

Revision ID: f1ca58b2f2ec
Revises: c7bf5721733e
Create Date: 2024-03-19

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = "f1ca58b2f2ec"
down_revision: Union[str, None] = "c7bf5721733e"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # Add passthrough_auth column to tool table with default value of False
    op.add_column(
        "tool",
        sa.Column(
            "passthrough_auth", sa.Boolean(), nullable=False, server_default=sa.false()
        ),
    )


def downgrade() -> None:
    # Remove passthrough_auth column from tool table
    op.drop_column("tool", "passthrough_auth")
@@ -0,0 +1,40 @@
"""Add background errors table

Revision ID: f39c5794c10a
Revises: 2cdeff6d8c93
Create Date: 2025-02-12 17:11:14.527876

"""
from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "f39c5794c10a"
down_revision = "2cdeff6d8c93"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.create_table(
        "background_error",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("message", sa.String(), nullable=False),
        sa.Column(
            "time_created",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.Column("cc_pair_id", sa.Integer(), nullable=True),
        sa.PrimaryKeyConstraint("id"),
        sa.ForeignKeyConstraint(
            ["cc_pair_id"],
            ["connector_credential_pair.id"],
            ondelete="CASCADE",
        ),
    )


def downgrade() -> None:
    op.drop_table("background_error")
@@ -0,0 +1,53 @@
"""delete non-search assistants

Revision ID: f5437cc136c5
Revises: eaa3b5593925
Create Date: 2025-02-04 16:17:15.677256

"""
from alembic import op


# revision identifiers, used by Alembic.
revision = "f5437cc136c5"
down_revision = "eaa3b5593925"
branch_labels = None
depends_on = None


def upgrade() -> None:
    pass


def downgrade() -> None:
    # Fix: split the statements into multiple op.execute() calls
    op.execute(
        """
        WITH personas_without_search AS (
            SELECT p.id
            FROM persona p
            LEFT JOIN persona__tool pt ON p.id = pt.persona_id
            LEFT JOIN tool t ON pt.tool_id = t.id
            GROUP BY p.id
            HAVING COUNT(CASE WHEN t.in_code_tool_id = 'run_search' THEN 1 END) = 0
        )
        UPDATE slack_channel_config
        SET persona_id = NULL
        WHERE is_default = TRUE AND persona_id IN (SELECT id FROM personas_without_search)
        """
    )

    op.execute(
        """
        WITH personas_without_search AS (
            SELECT p.id
            FROM persona p
            LEFT JOIN persona__tool pt ON p.id = pt.persona_id
            LEFT JOIN tool t ON pt.tool_id = t.id
            GROUP BY p.id
            HAVING COUNT(CASE WHEN t.in_code_tool_id = 'run_search' THEN 1 END) = 0
        )
        DELETE FROM slack_channel_config
        WHERE is_default = FALSE AND persona_id IN (SELECT id FROM personas_without_search)
        """
    )
@@ -0,0 +1,41 @@
"""Add time_updated to UserGroup and DocumentSet

Revision ID: fec3db967bf7
Revises: 97dbb53fa8c8
Create Date: 2025-01-12 15:49:02.289100

"""
from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "fec3db967bf7"
down_revision = "97dbb53fa8c8"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.add_column(
        "document_set",
        sa.Column(
            "time_last_modified_by_user",
            sa.DateTime(timezone=True),
            nullable=False,
            server_default=sa.func.now(),
        ),
    )
    op.add_column(
        "user_group",
        sa.Column(
            "time_last_modified_by_user",
            sa.DateTime(timezone=True),
            nullable=False,
            server_default=sa.func.now(),
        ),
    )


def downgrade() -> None:
    op.drop_column("user_group", "time_last_modified_by_user")
    op.drop_column("document_set", "time_last_modified_by_user")
@@ -21,7 +21,7 @@ logger = setup_logger()
def perform_ttl_management_task(
    retention_limit_days: int, *, tenant_id: str | None
) -> None:
    with get_session_with_tenant(tenant_id) as db_session:
    with get_session_with_tenant(tenant_id=tenant_id) as db_session:
        delete_chat_sessions_older_than(retention_limit_days, db_session)


@@ -32,6 +32,7 @@ def perform_ttl_management_task(

@celery_app.task(
    name="check_ttl_management_task",
    ignore_result=True,
    soft_time_limit=JOB_TIMEOUT,
)
def check_ttl_management_task(*, tenant_id: str | None) -> None:

@@ -43,7 +44,7 @@ def check_ttl_management_task(*, tenant_id: str | None) -> None:

    settings = load_settings()
    retention_limit_days = settings.maximum_chat_retention_days
    with get_session_with_tenant(tenant_id) as db_session:
    with get_session_with_tenant(tenant_id=tenant_id) as db_session:
        if should_perform_chat_ttl_check(retention_limit_days, db_session):
            perform_ttl_management_task.apply_async(
                kwargs=dict(

@@ -56,11 +57,12 @@ def check_ttl_management_task(*, tenant_id: str | None) -> None:

@celery_app.task(
    name="autogenerate_usage_report_task",
    ignore_result=True,
    soft_time_limit=JOB_TIMEOUT,
)
def autogenerate_usage_report_task(*, tenant_id: str | None) -> None:
    """This generates usage report under the /admin/generate-usage/report endpoint"""
    with get_session_with_tenant(tenant_id) as db_session:
    with get_session_with_tenant(tenant_id=tenant_id) as db_session:
        create_new_usage_report(
            db_session=db_session,
            user_id=None,
@@ -2,23 +2,79 @@ from datetime import timedelta
from typing import Any

from onyx.background.celery.tasks.beat_schedule import (
    tasks_to_schedule as base_tasks_to_schedule,
    beat_cloud_tasks as base_beat_system_tasks,
)
from onyx.background.celery.tasks.beat_schedule import BEAT_EXPIRES_DEFAULT
from onyx.background.celery.tasks.beat_schedule import (
    beat_task_templates as base_beat_task_templates,
)
from onyx.background.celery.tasks.beat_schedule import generate_cloud_tasks
from onyx.background.celery.tasks.beat_schedule import (
    get_tasks_to_schedule as base_get_tasks_to_schedule,
)
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryTask
from shared_configs.configs import MULTI_TENANT

ee_tasks_to_schedule = [
    {
        "name": "autogenerate_usage_report",
        "task": OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
        "schedule": timedelta(days=30),  # TODO: change this to config flag
    },
    {
        "name": "check-ttl-management",
        "task": OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
        "schedule": timedelta(hours=1),
    },
]
ee_beat_system_tasks: list[dict] = []

ee_beat_task_templates: list[dict] = []
ee_beat_task_templates.extend(
    [
        {
            "name": "autogenerate-usage-report",
            "task": OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
            "schedule": timedelta(days=30),
            "options": {
                "priority": OnyxCeleryPriority.MEDIUM,
                "expires": BEAT_EXPIRES_DEFAULT,
            },
        },
        {
            "name": "check-ttl-management",
            "task": OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
            "schedule": timedelta(hours=1),
            "options": {
                "priority": OnyxCeleryPriority.MEDIUM,
                "expires": BEAT_EXPIRES_DEFAULT,
            },
        },
    ]
)

ee_tasks_to_schedule: list[dict] = []

if not MULTI_TENANT:
    ee_tasks_to_schedule = [
        {
            "name": "autogenerate-usage-report",
            "task": OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
            "schedule": timedelta(days=30),  # TODO: change this to config flag
            "options": {
                "priority": OnyxCeleryPriority.MEDIUM,
                "expires": BEAT_EXPIRES_DEFAULT,
            },
        },
        {
            "name": "check-ttl-management",
            "task": OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
            "schedule": timedelta(hours=1),
            "options": {
                "priority": OnyxCeleryPriority.MEDIUM,
                "expires": BEAT_EXPIRES_DEFAULT,
            },
        },
    ]


def get_cloud_tasks_to_schedule(beat_multiplier: float) -> list[dict[str, Any]]:
    beat_system_tasks = ee_beat_system_tasks + base_beat_system_tasks
    beat_task_templates = ee_beat_task_templates + base_beat_task_templates
    cloud_tasks = generate_cloud_tasks(
        beat_system_tasks, beat_task_templates, beat_multiplier
    )
    return cloud_tasks


def get_tasks_to_schedule() -> list[dict[str, Any]]:
    return ee_tasks_to_schedule + base_tasks_to_schedule
    return ee_tasks_to_schedule + base_get_tasks_to_schedule()
@@ -8,6 +8,9 @@ from ee.onyx.db.user_group import fetch_user_group
from ee.onyx.db.user_group import mark_user_group_as_synced
from ee.onyx.db.user_group import prepare_user_group_for_deletion
from onyx.background.celery.apps.app_base import task_logger
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.sync_record import update_sync_record_status
from onyx.redis.redis_usergroup import RedisUserGroup
from onyx.utils.logger import setup_logger

@@ -43,24 +46,59 @@ def monitor_usergroup_taskset(
        f"User group sync progress: usergroup_id={usergroup_id} remaining={count} initial={initial_count}"
    )
    if count > 0:
        update_sync_record_status(
            db_session=db_session,
            entity_id=usergroup_id,
            sync_type=SyncType.USER_GROUP,
            sync_status=SyncStatus.IN_PROGRESS,
            num_docs_synced=count,
        )
        return

    user_group = fetch_user_group(db_session=db_session, user_group_id=usergroup_id)
    if user_group:
        usergroup_name = user_group.name
        if user_group.is_up_for_deletion:
            # this prepare should have been run when the deletion was scheduled,
            # but run it again to be sure we're ready to go
            mark_user_group_as_synced(db_session, user_group)
            prepare_user_group_for_deletion(db_session, usergroup_id)
            delete_user_group(db_session=db_session, user_group=user_group)
            task_logger.info(
                f"Deleted usergroup: name={usergroup_name} id={usergroup_id}"
            )
        else:
            mark_user_group_as_synced(db_session=db_session, user_group=user_group)
            task_logger.info(
                f"Synced usergroup. name={usergroup_name} id={usergroup_id}"
        try:
            if user_group.is_up_for_deletion:
                # this prepare should have been run when the deletion was scheduled,
                # but run it again to be sure we're ready to go
                mark_user_group_as_synced(db_session, user_group)
                prepare_user_group_for_deletion(db_session, usergroup_id)
                delete_user_group(db_session=db_session, user_group=user_group)

                update_sync_record_status(
                    db_session=db_session,
                    entity_id=usergroup_id,
                    sync_type=SyncType.USER_GROUP,
                    sync_status=SyncStatus.SUCCESS,
                    num_docs_synced=initial_count,
                )

                task_logger.info(
                    f"Deleted usergroup: name={usergroup_name} id={usergroup_id}"
                )
            else:
                mark_user_group_as_synced(db_session=db_session, user_group=user_group)

                update_sync_record_status(
                    db_session=db_session,
                    entity_id=usergroup_id,
                    sync_type=SyncType.USER_GROUP,
                    sync_status=SyncStatus.SUCCESS,
                    num_docs_synced=initial_count,
                )

                task_logger.info(
                    f"Synced usergroup. name={usergroup_name} id={usergroup_id}"
                )
        except Exception as e:
            update_sync_record_status(
                db_session=db_session,
                entity_id=usergroup_id,
                sync_type=SyncType.USER_GROUP,
                sync_status=SyncStatus.FAILED,
                num_docs_synced=initial_count,
            )
            raise e

    rug.reset()
@@ -4,6 +4,20 @@ import os
# Applicable for OIDC Auth
OPENID_CONFIG_URL = os.environ.get("OPENID_CONFIG_URL", "")

# Applicable for OIDC Auth, allows you to override the scopes that
# are requested from the OIDC provider. Currently used when passing
# over access tokens to tool calls and the tool needs more scopes
OIDC_SCOPE_OVERRIDE: list[str] | None = None
_OIDC_SCOPE_OVERRIDE = os.environ.get("OIDC_SCOPE_OVERRIDE")

if _OIDC_SCOPE_OVERRIDE:
    try:
        OIDC_SCOPE_OVERRIDE = [
            scope.strip() for scope in _OIDC_SCOPE_OVERRIDE.split(",")
        ]
    except Exception:
        pass

# Applicable for SAML Auth
SAML_CONF_DIR = os.environ.get("SAML_CONF_DIR") or "/app/ee/onyx/configs/saml_config"

@@ -63,3 +77,5 @@ POSTHOG_HOST = os.environ.get("POSTHOG_HOST") or "https://us.i.posthog.com"
HUBSPOT_TRACKING_URL = os.environ.get("HUBSPOT_TRACKING_URL")

ANONYMOUS_USER_COOKIE_NAME = "onyx_anonymous_user"

GATED_TENANTS_KEY = "gated_tenants"
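A quick illustration of the OIDC_SCOPE_OVERRIDE parsing above (the value is made up):

import os

os.environ["OIDC_SCOPE_OVERRIDE"] = "openid, profile, email"  # example value

raw = os.environ.get("OIDC_SCOPE_OVERRIDE")
scopes = [scope.strip() for scope in raw.split(",")] if raw else None
print(scopes)  # ['openid', 'profile', 'email']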
@@ -5,7 +5,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session

from onyx.access.models import ExternalAccess
from onyx.access.utils import prefix_group_w_source
from onyx.access.utils import build_ext_group_name_for_onyx
from onyx.configs.constants import DocumentSource
from onyx.db.models import Document as DbDocument

@@ -25,7 +25,7 @@ def upsert_document_external_perms__no_commit(
    ).first()

    prefixed_external_groups = [
        prefix_group_w_source(
        build_ext_group_name_for_onyx(
            ext_group_name=group_id,
            source=source_type,
        )
@@ -66,7 +66,7 @@ def upsert_document_external_perms(
    ).first()

    prefixed_external_groups: set[str] = {
        prefix_group_w_source(
        build_ext_group_name_for_onyx(
            ext_group_name=group_id,
            source=source_type,
        )
@@ -6,8 +6,9 @@ from sqlalchemy import delete
from sqlalchemy import select
from sqlalchemy.orm import Session

from onyx.access.utils import prefix_group_w_source
from onyx.access.utils import build_ext_group_name_for_onyx
from onyx.configs.constants import DocumentSource
from onyx.db.models import User
from onyx.db.models import User__ExternalUserGroupId
from onyx.db.users import batch_add_ext_perm_user_if_not_exists
from onyx.db.users import get_user_by_email

@@ -61,8 +62,10 @@ def replace_user__ext_group_for_cc_pair(
        all_group_member_emails.add(user_email)

    # batch add users if they don't exist and get their ids
    all_group_members = batch_add_ext_perm_user_if_not_exists(
        db_session=db_session, emails=list(all_group_member_emails)
    all_group_members: list[User] = batch_add_ext_perm_user_if_not_exists(
        db_session=db_session,
        # NOTE: this function handles case sensitivity for emails
        emails=list(all_group_member_emails),
    )

    delete_user__ext_group_for_cc_pair__no_commit(

@@ -84,12 +87,14 @@ def replace_user__ext_group_for_cc_pair(
                f" with email {user_email} not found"
            )
            continue
        external_group_id = build_ext_group_name_for_onyx(
            ext_group_name=external_group.id,
            source=source,
        )
        new_external_permissions.append(
            User__ExternalUserGroupId(
                user_id=user_id,
                external_user_group_id=prefix_group_w_source(
                    external_group.id, source
                ),
                external_user_group_id=external_group_id,
                cc_pair_id=cc_pair_id,
            )
        )
@@ -2,8 +2,11 @@ from uuid import UUID

from sqlalchemy.orm import Session

from onyx.configs.constants import NotificationType
from onyx.db.models import Persona__User
from onyx.db.models import Persona__UserGroup
from onyx.db.notification import create_notification
from onyx.server.features.persona.models import PersonaSharedNotificationData


def make_persona_private(

@@ -12,6 +15,9 @@ def make_persona_private(
    group_ids: list[int] | None,
    db_session: Session,
) -> None:
    """NOTE(rkuo): This function batches all updates into a single commit. If we don't
    dedupe the inputs, the commit will exception."""

    db_session.query(Persona__User).filter(
        Persona__User.persona_id == persona_id
    ).delete(synchronize_session="fetch")

@@ -20,11 +26,22 @@ def make_persona_private(
    ).delete(synchronize_session="fetch")

    if user_ids:
        for user_uuid in user_ids:
            db_session.add(Persona__User(persona_id=persona_id, user_id=user_uuid))
        user_ids_set = set(user_ids)
        for user_id in user_ids_set:
            db_session.add(Persona__User(persona_id=persona_id, user_id=user_id))

            create_notification(
                user_id=user_id,
                notif_type=NotificationType.PERSONA_SHARED,
                db_session=db_session,
                additional_data=PersonaSharedNotificationData(
                    persona_id=persona_id,
                ).model_dump(),
            )

    if group_ids:
        for group_id in group_ids:
        group_ids_set = set(group_ids)
        for group_id in group_ids_set:
            db_session.add(
                Persona__UserGroup(persona_id=persona_id, user_group_id=group_id)
            )
@@ -1,27 +1,138 @@
import datetime
from typing import Literal
from collections.abc import Sequence
from datetime import datetime

from sqlalchemy import asc
from sqlalchemy import BinaryExpression
from sqlalchemy import ColumnElement
from sqlalchemy import desc
from sqlalchemy import distinct
from sqlalchemy.orm import contains_eager
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
from sqlalchemy.sql import case
from sqlalchemy.sql import func
from sqlalchemy.sql import select
from sqlalchemy.sql.expression import literal
from sqlalchemy.sql.expression import UnaryExpression

from onyx.configs.constants import QAFeedbackType
from onyx.db.models import ChatMessage
from onyx.db.models import ChatMessageFeedback
from onyx.db.models import ChatSession

SortByOptions = Literal["time_sent"]

def _build_filter_conditions(
    start_time: datetime | None,
    end_time: datetime | None,
    feedback_filter: QAFeedbackType | None,
) -> list[ColumnElement]:
    """
    Helper function to build all filter conditions for chat sessions.
    Filters by start and end time, feedback type, and any sessions without messages.
    start_time: Date from which to filter
    end_time: Date to which to filter
    feedback_filter: Feedback type to filter by
    Returns: List of filter conditions
    """
    conditions = []

    if start_time is not None:
        conditions.append(ChatSession.time_created >= start_time)
    if end_time is not None:
        conditions.append(ChatSession.time_created <= end_time)

    if feedback_filter is not None:
        feedback_subq = (
            select(ChatMessage.chat_session_id)
            .join(ChatMessageFeedback)
            .group_by(ChatMessage.chat_session_id)
            .having(
                case(
                    (
                        case(
                            {literal(feedback_filter == QAFeedbackType.LIKE): True},
                            else_=False,
                        ),
                        func.bool_and(ChatMessageFeedback.is_positive),
                    ),
                    (
                        case(
                            {literal(feedback_filter == QAFeedbackType.DISLIKE): True},
                            else_=False,
                        ),
                        func.bool_and(func.not_(ChatMessageFeedback.is_positive)),
                    ),
                    else_=func.bool_or(ChatMessageFeedback.is_positive)
                    & func.bool_or(func.not_(ChatMessageFeedback.is_positive)),
                )
            )
        )
        conditions.append(ChatSession.id.in_(feedback_subq))

    return conditions


def get_total_filtered_chat_sessions_count(
    db_session: Session,
    start_time: datetime | None,
    end_time: datetime | None,
    feedback_filter: QAFeedbackType | None,
) -> int:
    conditions = _build_filter_conditions(start_time, end_time, feedback_filter)
    stmt = (
        select(func.count(distinct(ChatSession.id)))
        .select_from(ChatSession)
        .filter(*conditions)
    )
    return db_session.scalar(stmt) or 0


def get_page_of_chat_sessions(
    start_time: datetime | None,
    end_time: datetime | None,
    db_session: Session,
    page_num: int,
    page_size: int,
    feedback_filter: QAFeedbackType | None = None,
) -> Sequence[ChatSession]:
    conditions = _build_filter_conditions(start_time, end_time, feedback_filter)

    subquery = (
        select(ChatSession.id)
        .filter(*conditions)
        .order_by(desc(ChatSession.time_created), ChatSession.id)
        .limit(page_size)
        .offset(page_num * page_size)
        .subquery()
    )

    stmt = (
        select(ChatSession)
        .join(subquery, ChatSession.id == subquery.c.id)
        .outerjoin(ChatMessage, ChatSession.id == ChatMessage.chat_session_id)
        .options(
            joinedload(ChatSession.user),
            joinedload(ChatSession.persona),
            contains_eager(ChatSession.messages).joinedload(
                ChatMessage.chat_message_feedbacks
            ),
        )
        .order_by(
            desc(ChatSession.time_created),
            ChatSession.id,
            asc(ChatMessage.id),  # Ensure chronological message order
        )
    )

    return db_session.scalars(stmt).unique().all()


def fetch_chat_sessions_eagerly_by_time(
    start: datetime.datetime,
    end: datetime.datetime,
    start: datetime,
    end: datetime,
    db_session: Session,
    limit: int | None = 500,
    initial_time: datetime.datetime | None = None,
    initial_time: datetime | None = None,
) -> list[ChatSession]:
    time_order: UnaryExpression = desc(ChatSession.time_created)
    message_order: UnaryExpression = asc(ChatMessage.id)
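The nested case in _build_filter_conditions acts as a HAVING selector: LIKE keeps sessions whose feedback is uniformly positive (bool_and), DISLIKE keeps uniformly negative ones, and the fallback branch keeps mixed sessions (at least one of each). A hedged sketch of the plain SQL the LIKE branch should roughly compile to (the feedback join column is assumed):

# Rough SQL equivalent of the LIKE branch of the feedback subquery; the
# actual generated SQL will differ in detail.
like_branch_sql = """
SELECT chat_message.chat_session_id
FROM chat_message
JOIN chat_message_feedback
  ON chat_message_feedback.chat_message_id = chat_message.id  -- assumed FK column
GROUP BY chat_message.chat_session_id
HAVING bool_and(chat_message_feedback.is_positive)
"""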
@@ -7,7 +7,6 @@ from sqlalchemy import select
from sqlalchemy.orm import aliased
from sqlalchemy.orm import Session

from onyx.auth.users import anonymous_user_enabled
from onyx.configs.app_configs import DISABLE_AUTH
from onyx.configs.constants import TokenRateLimitScope
from onyx.db.models import TokenRateLimit
@@ -52,11 +51,8 @@ def _add_user_filters(

    # If user is None, this is an anonymous user and we should only show public token_rate_limits
    if user is None:
        if anonymous_user_enabled():
            where_clause = TokenRateLimit.scope == TokenRateLimitScope.GLOBAL
            return stmt.where(where_clause)
        else:
            raise ValueError("User not authenticated")
        where_clause = TokenRateLimit.scope == TokenRateLimitScope.GLOBAL
        return stmt.where(where_clause)

    where_clause = User__UG.user_id == user.id
    if user.role == UserRole.CURATOR and get_editable:
@@ -115,10 +111,10 @@ def insert_user_group_token_rate_limit(
    return token_limit


def fetch_user_group_token_rate_limits(
def fetch_user_group_token_rate_limits_for_user(
    db_session: Session,
    group_id: int,
    user: User | None = None,
    user: User | None,
    enabled_only: bool = False,
    ordered: bool = True,
    get_editable: bool = True,

@@ -218,14 +218,14 @@ def fetch_user_groups_for_user(
    return db_session.scalars(stmt).all()


def construct_document_select_by_usergroup(
def construct_document_id_select_by_usergroup(
    user_group_id: int,
) -> Select:
    """This returns a statement that should be executed using
    .yield_per() to minimize overhead. The primary consumers of this function
    are background processing task generators."""
    stmt = (
        select(Document)
        select(Document.id)
        .join(
            DocumentByConnectorCredentialPair,
            Document.id == DocumentByConnectorCredentialPair.id,
@@ -374,7 +374,9 @@ def _add_user_group__cc_pair_relationships__no_commit(


def insert_user_group(db_session: Session, user_group: UserGroupCreate) -> UserGroup:
    db_user_group = UserGroup(name=user_group.name)
    db_user_group = UserGroup(
        name=user_group.name, time_last_modified_by_user=func.now()
    )
    db_session.add(db_user_group)
    db_session.flush()  # give the group an ID

@@ -630,6 +632,10 @@ def update_user_group(
        select(User).where(User.id.in_(removed_user_ids))  # type: ignore
    ).unique()
    _validate_curator_status__no_commit(db_session, list(removed_users))

    # update "time_updated" to now
    db_user_group.time_last_modified_by_user = func.now()

    db_session.commit()
    return db_user_group

@@ -699,7 +705,10 @@ def delete_user_group_cc_pair_relationship__no_commit(
    connector_credential_pair_id matches the given cc_pair_id.

    Should be used very carefully (only for connectors that are being deleted)."""
    cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
    cc_pair = get_connector_credential_pair_from_id(
        db_session=db_session,
        cc_pair_id=cc_pair_id,
    )
    if not cc_pair:
        raise ValueError(f"Connector Credential Pair '{cc_pair_id}' does not exist")
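The docstring above prescribes .yield_per(); a hedged sketch of how a background task would stream the IDs without materializing them all at once (assumed session setup):

# Hedged consumer sketch for construct_document_id_select_by_usergroup;
# db_session is assumed to be an open SQLAlchemy Session.
stmt = construct_document_id_select_by_usergroup(user_group_id=1)  # made-up id
for (document_id,) in db_session.execute(stmt).yield_per(1000):
    ...  # hand each document_id to the background processing queue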
@@ -13,6 +13,7 @@ from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
from onyx.connectors.confluence.utils import get_user_email_from_username__server
from onyx.connectors.models import SlimDocument
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger

logger = setup_logger()

@@ -257,6 +258,7 @@ def _fetch_all_page_restrictions(
    slim_docs: list[SlimDocument],
    space_permissions_by_space_key: dict[str, ExternalAccess],
    is_cloud: bool,
    callback: IndexingHeartbeatInterface | None,
) -> list[DocExternalAccess]:
    """
    For all pages, if a page has restrictions, then use those restrictions.
@@ -265,6 +267,12 @@ def _fetch_all_page_restrictions(
    document_restrictions: list[DocExternalAccess] = []

    for slim_doc in slim_docs:
        if callback:
            if callback.should_stop():
                raise RuntimeError("confluence_doc_sync: Stop signal detected")

            callback.progress("confluence_doc_sync:fetch_all_page_restrictions", 1)

        if slim_doc.perm_sync_data is None:
            raise ValueError(
                f"No permission sync data found for document {slim_doc.id}"
@@ -334,7 +342,7 @@ def _fetch_all_page_restrictions(


def confluence_doc_sync(
    cc_pair: ConnectorCredentialPair,
    cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
) -> list[DocExternalAccess]:
    """
    Adds the external permissions to the documents in postgres
@@ -357,8 +365,16 @@ def confluence_doc_sync(

    slim_docs = []
    logger.debug("Fetching all slim documents from confluence")
    for doc_batch in confluence_connector.retrieve_all_slim_documents():
    for doc_batch in confluence_connector.retrieve_all_slim_documents(
        callback=callback
    ):
        logger.debug(f"Got {len(doc_batch)} slim documents from confluence")
        if callback:
            if callback.should_stop():
                raise RuntimeError("confluence_doc_sync: Stop signal detected")

            callback.progress("confluence_doc_sync", 1)

        slim_docs.extend(doc_batch)

    logger.debug("Fetching all page restrictions for space")
@@ -367,4 +383,5 @@ def confluence_doc_sync(
        slim_docs=slim_docs,
        space_permissions_by_space_key=space_permissions_by_space_key,
        is_cloud=is_cloud,
        callback=callback,
    )
@@ -1,5 +1,6 @@
from ee.onyx.db.external_perm import ExternalUserGroup
from ee.onyx.external_permissions.confluence.constants import ALL_CONF_EMAILS_GROUP_NAME
from onyx.background.error_logging import emit_background_error
from onyx.connectors.confluence.onyx_confluence import build_confluence_client
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
from onyx.connectors.confluence.utils import get_user_email_from_username__server
@@ -10,33 +11,51 @@ logger = setup_logger()


def _build_group_member_email_map(
    confluence_client: OnyxConfluence,
    confluence_client: OnyxConfluence, cc_pair_id: int
) -> dict[str, set[str]]:
    group_member_emails: dict[str, set[str]] = {}
    for user_result in confluence_client.paginated_cql_user_retrieval():
        user = user_result.get("user", {})
        if not user:
            logger.warning(f"user result missing user field: {user_result}")
            continue
        email = user.get("email")
    for user in confluence_client.paginated_cql_user_retrieval():
        logger.debug(f"Processing groups for user: {user}")

        email = user.email
        if not email:
            # This field is only present in Confluence Server
            user_name = user.get("username")
            user_name = user.username
            # If it is present, try to get the email using a Server-specific method
            if user_name:
                email = get_user_email_from_username__server(
                    confluence_client=confluence_client,
                    user_name=user_name,
                )

        if not email:
            # If we still don't have an email, skip this user
            logger.warning(f"user result missing email field: {user_result}")
            msg = f"user result missing email field: {user}"
            if user.type == "app":
                logger.warning(msg)
            else:
                emit_background_error(msg, cc_pair_id=cc_pair_id)
                logger.error(msg)
            continue

        for group in confluence_client.paginated_groups_by_user_retrieval(user):
        all_users_groups: set[str] = set()
        for group in confluence_client.paginated_groups_by_user_retrieval(user.user_id):
            # group name uniqueness is enforced by Confluence, so we can use it as a group ID
            group_id = group["name"]
            group_member_emails.setdefault(group_id, set()).add(email)
            all_users_groups.add(group_id)

        if not all_users_groups:
            msg = f"No groups found for user with email: {email}"
            emit_background_error(msg, cc_pair_id=cc_pair_id)
            logger.error(msg)
        else:
            logger.debug(f"Found groups {all_users_groups} for user with email {email}")

    if not group_member_emails:
        msg = "No groups found for any users."
        emit_background_error(msg, cc_pair_id=cc_pair_id)
        logger.error(msg)

    return group_member_emails

@@ -52,6 +71,7 @@ def confluence_group_sync(

    group_member_email_map = _build_group_member_email_map(
        confluence_client=confluence_client,
        cc_pair_id=cc_pair.id,
    )
    onyx_groups: list[ExternalUserGroup] = []
    all_found_emails = set()
@@ -6,6 +6,7 @@ from onyx.access.models import ExternalAccess
from onyx.connectors.gmail.connector import GmailConnector
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger

logger = setup_logger()

@@ -14,6 +15,7 @@ logger = setup_logger()
def _get_slim_doc_generator(
    cc_pair: ConnectorCredentialPair,
    gmail_connector: GmailConnector,
    callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
    current_time = datetime.now(timezone.utc)
    start_time = (
@@ -23,12 +25,14 @@ def _get_slim_doc_generator(
    )

    return gmail_connector.retrieve_all_slim_documents(
        start=start_time, end=current_time.timestamp()
        start=start_time,
        end=current_time.timestamp(),
        callback=callback,
    )


def gmail_doc_sync(
    cc_pair: ConnectorCredentialPair,
    cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
) -> list[DocExternalAccess]:
    """
    Adds the external permissions to the documents in postgres
@@ -39,11 +43,19 @@ def gmail_doc_sync(
    gmail_connector = GmailConnector(**cc_pair.connector.connector_specific_config)
    gmail_connector.load_credentials(cc_pair.credential.credential_json)

    slim_doc_generator = _get_slim_doc_generator(cc_pair, gmail_connector)
    slim_doc_generator = _get_slim_doc_generator(
        cc_pair, gmail_connector, callback=callback
    )

    document_external_access: list[DocExternalAccess] = []
    for slim_doc_batch in slim_doc_generator:
        for slim_doc in slim_doc_batch:
            if callback:
                if callback.should_stop():
                    raise RuntimeError("gmail_doc_sync: Stop signal detected")

                callback.progress("gmail_doc_sync", 1)

            if slim_doc.perm_sync_data is None:
                logger.warning(f"No permissions found for document {slim_doc.id}")
                continue
@@ -10,6 +10,7 @@ from onyx.connectors.google_utils.resources import get_drive_service
 from onyx.connectors.interfaces import GenerateSlimDocumentOutput
 from onyx.connectors.models import SlimDocument
 from onyx.db.models import ConnectorCredentialPair
+from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
 from onyx.utils.logger import setup_logger

 logger = setup_logger()
@@ -20,6 +21,7 @@ _PERMISSION_ID_PERMISSION_MAP: dict[str, dict[str, Any]] = {}
 def _get_slim_doc_generator(
     cc_pair: ConnectorCredentialPair,
     google_drive_connector: GoogleDriveConnector,
+    callback: IndexingHeartbeatInterface | None = None,
 ) -> GenerateSlimDocumentOutput:
     current_time = datetime.now(timezone.utc)
     start_time = (
@@ -29,7 +31,9 @@ def _get_slim_doc_generator(
     )

     return google_drive_connector.retrieve_all_slim_documents(
-        start=start_time, end=current_time.timestamp()
+        start=start_time,
+        end=current_time.timestamp(),
+        callback=callback,
     )


@@ -42,24 +46,22 @@ def _fetch_permissions_for_permission_ids(
     if not permission_info or not doc_id:
         return []

     # Check cache first for all permission IDs
     permissions = [
         _PERMISSION_ID_PERMISSION_MAP[pid]
         for pid in permission_ids
         if pid in _PERMISSION_ID_PERMISSION_MAP
     ]

     # If we found all permissions in cache, return them
     if len(permissions) == len(permission_ids):
         return permissions

     owner_email = permission_info.get("owner_email")

     drive_service = get_drive_service(
         creds=google_drive_connector.creds,
         user_email=(owner_email or google_drive_connector.primary_admin_email),
     )

     # Otherwise, fetch all permissions and update cache
     fetched_permissions = execute_paginated_retrieval(
         retrieval_function=drive_service.permissions().list,
         list_key="permissions",
@@ -69,7 +71,6 @@ def _fetch_permissions_for_permission_ids(
     )

     permissions_for_doc_id = []
     # Update cache and return all permissions
     for permission in fetched_permissions:
         permissions_for_doc_id.append(permission)
         _PERMISSION_ID_PERMISSION_MAP[permission["id"]] = permission
@@ -120,15 +121,18 @@ def _get_permissions_from_slim_doc(
     elif permission_type == "anyone":
         public = True

+    drive_id = permission_info.get("drive_id")
+    group_ids = group_emails | ({drive_id} if drive_id is not None else set())
+
     return ExternalAccess(
         external_user_emails=user_emails,
-        external_user_group_ids=group_emails,
+        external_user_group_ids=group_ids,
         is_public=public,
     )


 def gdrive_doc_sync(
-    cc_pair: ConnectorCredentialPair,
+    cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
 ) -> list[DocExternalAccess]:
     """
     Adds the external permissions to the documents in postgres
@@ -146,6 +150,12 @@ def gdrive_doc_sync(
     document_external_accesses = []
     for slim_doc_batch in slim_doc_generator:
         for slim_doc in slim_doc_batch:
+            if callback:
+                if callback.should_stop():
+                    raise RuntimeError("gdrive_doc_sync: Stop signal detected")
+
+                callback.progress("gdrive_doc_sync", 1)
+
             ext_access = _get_permissions_from_slim_doc(
                 google_drive_connector=google_drive_connector,
                 slim_doc=slim_doc,
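
_fetch_permissions_for_permission_ids keeps a module-level map from permission id to permission so repeated lookups can skip the Drive API. A reduced sketch of that cache-then-fetch flow; fetch_remote is a hypothetical stand-in for the paginated drive_service.permissions().list call:

_PERMISSION_CACHE: dict[str, dict] = {}


def fetch_remote(doc_id: str) -> list[dict]:
    """Hypothetical stand-in for the paginated Drive permissions fetch."""
    return [{"id": f"{doc_id}-perm-1", "type": "user"}]


def get_permissions(doc_id: str, permission_ids: list[str]) -> list[dict]:
    # Check the cache first for every requested permission id.
    cached = [
        _PERMISSION_CACHE[pid] for pid in permission_ids if pid in _PERMISSION_CACHE
    ]
    if len(cached) == len(permission_ids):
        return cached

    # On any miss, fetch everything for the doc and refresh the cache,
    # mirroring the fetched_permissions loop in the diff.
    fetched = fetch_remote(doc_id)
    for permission in fetched:
        _PERMISSION_CACHE[permission["id"]] = permission
    return fetched
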
@@ -1,16 +1,127 @@
 from ee.onyx.db.external_perm import ExternalUserGroup
 from onyx.connectors.google_drive.connector import GoogleDriveConnector
 from onyx.connectors.google_utils.google_utils import execute_paginated_retrieval
+from onyx.connectors.google_utils.resources import AdminService
 from onyx.connectors.google_utils.resources import get_admin_service
+from onyx.connectors.google_utils.resources import get_drive_service
 from onyx.db.models import ConnectorCredentialPair
 from onyx.utils.logger import setup_logger

 logger = setup_logger()


+def _get_drive_members(
+    google_drive_connector: GoogleDriveConnector,
+) -> dict[str, tuple[set[str], set[str]]]:
+    """
+    This builds a map of drive ids to their members (group and user emails).
+    E.g. {
+        "drive_id_1": ({"group_email_1"}, {"user_email_1", "user_email_2"}),
+        "drive_id_2": ({"group_email_3"}, {"user_email_3"}),
+    }
+    """
+    drive_ids = google_drive_connector.get_all_drive_ids()
+
+    drive_id_to_members_map: dict[str, tuple[set[str], set[str]]] = {}
+    drive_service = get_drive_service(
+        google_drive_connector.creds,
+        google_drive_connector.primary_admin_email,
+    )
+
+    for drive_id in drive_ids:
+        group_emails: set[str] = set()
+        user_emails: set[str] = set()
+        for permission in execute_paginated_retrieval(
+            drive_service.permissions().list,
+            list_key="permissions",
+            fileId=drive_id,
+            fields="permissions(emailAddress, type)",
+            supportsAllDrives=True,
+        ):
+            if permission["type"] == "group":
+                group_emails.add(permission["emailAddress"])
+            elif permission["type"] == "user":
+                user_emails.add(permission["emailAddress"])
+        drive_id_to_members_map[drive_id] = (group_emails, user_emails)
+    return drive_id_to_members_map
+
+
+def _get_all_groups(
+    admin_service: AdminService,
+    google_domain: str,
+) -> set[str]:
+    """
+    This gets all the group emails.
+    """
+    group_emails: set[str] = set()
+    for group in execute_paginated_retrieval(
+        admin_service.groups().list,
+        list_key="groups",
+        domain=google_domain,
+        fields="groups(email)",
+    ):
+        group_emails.add(group["email"])
+    return group_emails
+
+
+def _map_group_email_to_member_emails(
+    admin_service: AdminService,
+    group_emails: set[str],
+) -> dict[str, set[str]]:
+    """
+    This maps group emails to their member emails.
+    """
+    group_to_member_map: dict[str, set[str]] = {}
+    for group_email in group_emails:
+        group_member_emails: set[str] = set()
+        for member in execute_paginated_retrieval(
+            admin_service.members().list,
+            list_key="members",
+            groupKey=group_email,
+            fields="members(email)",
+        ):
+            group_member_emails.add(member["email"])
+
+        group_to_member_map[group_email] = group_member_emails
+    return group_to_member_map
+
+
+def _build_onyx_groups(
+    drive_id_to_members_map: dict[str, tuple[set[str], set[str]]],
+    group_email_to_member_emails_map: dict[str, set[str]],
+) -> list[ExternalUserGroup]:
+    onyx_groups: list[ExternalUserGroup] = []
+
+    # Convert all drive member definitions to onyx groups
+    # This is because having drive level access means you have
+    # irrevocable access to all the files in the drive.
+    for drive_id, (group_emails, user_emails) in drive_id_to_members_map.items():
+        all_member_emails: set[str] = user_emails
+        for group_email in group_emails:
+            all_member_emails.update(group_email_to_member_emails_map[group_email])
+        onyx_groups.append(
+            ExternalUserGroup(
+                id=drive_id,
+                user_emails=list(all_member_emails),
+            )
+        )
+
+    # Convert all group member definitions to onyx groups
+    for group_email, member_emails in group_email_to_member_emails_map.items():
+        onyx_groups.append(
+            ExternalUserGroup(
+                id=group_email,
+                user_emails=list(member_emails),
+            )
+        )
+
+    return onyx_groups
+
+
 def gdrive_group_sync(
     cc_pair: ConnectorCredentialPair,
 ) -> list[ExternalUserGroup]:
     # Initialize connector and build credential/service objects
     google_drive_connector = GoogleDriveConnector(
         **cc_pair.connector.connector_specific_config
     )
@@ -19,34 +130,23 @@ def gdrive_group_sync(
         google_drive_connector.creds, google_drive_connector.primary_admin_email
     )

-    onyx_groups: list[ExternalUserGroup] = []
-    for group in execute_paginated_retrieval(
-        admin_service.groups().list,
-        list_key="groups",
-        domain=google_drive_connector.google_domain,
-        fields="groups(email)",
-    ):
-        # The id is the group email
-        group_email = group["email"]
+    # Get all drive members
+    drive_id_to_members_map = _get_drive_members(google_drive_connector)

-        # Gather group member emails
-        group_member_emails: list[str] = []
-        for member in execute_paginated_retrieval(
-            admin_service.members().list,
-            list_key="members",
-            groupKey=group_email,
-            fields="members(email)",
-        ):
-            group_member_emails.append(member["email"])
+    # Get all group emails
+    all_group_emails = _get_all_groups(
+        admin_service, google_drive_connector.google_domain
+    )

-        if not group_member_emails:
-            continue
+    # Map group emails to their members
+    group_email_to_member_emails_map = _map_group_email_to_member_emails(
+        admin_service, all_group_emails
+    )

-        onyx_groups.append(
-            ExternalUserGroup(
-                id=group_email,
-                user_emails=list(group_member_emails),
-            )
-        )
+    # Convert the maps to onyx groups
+    onyx_groups = _build_onyx_groups(
+        drive_id_to_members_map=drive_id_to_members_map,
+        group_email_to_member_emails_map=group_email_to_member_emails_map,
+    )

     return onyx_groups
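
Since drive-level access means access to every file in the drive, the rewrite models each shared drive itself as an external group whose members are the drive's direct users plus the expanded members of its groups. A toy run of that flattening, with plain tuples standing in for ExternalUserGroup:

drive_map = {"drive_1": ({"eng@corp.com"}, {"alice@corp.com"})}
group_map = {"eng@corp.com": {"bob@corp.com", "carol@corp.com"}}

onyx_groups: list[tuple[str, set[str]]] = []

# One group per drive: direct users plus the members of every group on it.
for drive_id, (group_emails, user_emails) in drive_map.items():
    members = set(user_emails)
    for group_email in group_emails:
        members |= group_map[group_email]
    onyx_groups.append((drive_id, members))

# One group per Google group, kept as-is.
for group_email, member_emails in group_map.items():
    onyx_groups.append((group_email, set(member_emails)))

# onyx_groups now contains ("drive_1", {alice, bob, carol}) and
# ("eng@corp.com", {bob, carol}).
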
@@ -161,7 +161,10 @@ def _get_salesforce_client_for_doc_id(db_session: Session, doc_id: str) -> Sales

     cc_pair_id = _DOC_ID_TO_CC_PAIR_ID_MAP[doc_id]
     if cc_pair_id not in _CC_PAIR_ID_SALESFORCE_CLIENT_MAP:
-        cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
+        cc_pair = get_connector_credential_pair_from_id(
+            db_session=db_session,
+            cc_pair_id=cc_pair_id,
+        )
         if cc_pair is None:
             raise ValueError(f"CC pair {cc_pair_id} not found")
         credential_json = cc_pair.credential.credential_json
@@ -5,8 +5,9 @@ from onyx.access.models import DocExternalAccess
 from onyx.access.models import ExternalAccess
 from onyx.connectors.slack.connector import get_channels
 from onyx.connectors.slack.connector import make_paginated_slack_api_call_w_retries
-from onyx.connectors.slack.connector import SlackPollConnector
+from onyx.connectors.slack.connector import SlackConnector
 from onyx.db.models import ConnectorCredentialPair
+from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
 from onyx.utils.logger import setup_logger


@@ -14,12 +15,12 @@ logger = setup_logger()


 def _get_slack_document_ids_and_channels(
-    cc_pair: ConnectorCredentialPair,
+    cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
 ) -> dict[str, list[str]]:
-    slack_connector = SlackPollConnector(**cc_pair.connector.connector_specific_config)
+    slack_connector = SlackConnector(**cc_pair.connector.connector_specific_config)
     slack_connector.load_credentials(cc_pair.credential.credential_json)

-    slim_doc_generator = slack_connector.retrieve_all_slim_documents()
+    slim_doc_generator = slack_connector.retrieve_all_slim_documents(callback=callback)

     channel_doc_map: dict[str, list[str]] = {}
     for doc_metadata_batch in slim_doc_generator:
@@ -31,6 +32,14 @@ def _get_slack_document_ids_and_channels(
             channel_doc_map[channel_id] = []
         channel_doc_map[channel_id].append(doc_metadata.id)

+        if callback:
+            if callback.should_stop():
+                raise RuntimeError(
+                    "_get_slack_document_ids_and_channels: Stop signal detected"
+                )
+
+            callback.progress("_get_slack_document_ids_and_channels", 1)
+
     return channel_doc_map


@@ -114,7 +123,7 @@ def _fetch_channel_permissions(


 def slack_doc_sync(
-    cc_pair: ConnectorCredentialPair,
+    cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
 ) -> list[DocExternalAccess]:
     """
     Adds the external permissions to the documents in postgres
@@ -127,7 +136,7 @@ def slack_doc_sync(
     )
     user_id_to_email_map = fetch_user_id_to_email_map(slack_client)
     channel_doc_map = _get_slack_document_ids_and_channels(
-        cc_pair=cc_pair,
+        cc_pair=cc_pair, callback=callback
     )
     workspace_permissions = _fetch_workspace_permissions(
         user_id_to_email_map=user_id_to_email_map,
@@ -15,11 +15,13 @@ from ee.onyx.external_permissions.slack.doc_sync import slack_doc_sync
 from onyx.access.models import DocExternalAccess
 from onyx.configs.constants import DocumentSource
 from onyx.db.models import ConnectorCredentialPair
+from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface

 # Defining the input/output types for the sync functions
 DocSyncFuncType = Callable[
     [
         ConnectorCredentialPair,
         IndexingHeartbeatInterface | None,
     ],
     list[DocExternalAccess],
 ]
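
With the callback folded into the signature, every doc-sync function shares one shape, so the module can route by source through a table typed with DocSyncFuncType. A minimal sketch; the enum values and registry below are illustrative, not the module's actual table:

from collections.abc import Callable
from enum import Enum


class Source(Enum):
    SLACK = "slack"
    GMAIL = "gmail"


# Same idea as DocSyncFuncType: (cc_pair, callback) -> access records.
DocSyncFunc = Callable[[dict, object | None], list[str]]


def slack_sync(cc_pair: dict, callback: object | None) -> list[str]:
    return ["slack-doc-1"]


SYNC_REGISTRY: dict[Source, DocSyncFunc] = {Source.SLACK: slack_sync}


def run_sync(source: Source, cc_pair: dict, callback: object | None) -> list[str]:
    return SYNC_REGISTRY[source](cc_pair, callback)
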
@@ -1,7 +1,9 @@
 from fastapi import FastAPI
 from httpx_oauth.clients.google import GoogleOAuth2
+from httpx_oauth.clients.openid import BASE_SCOPES
 from httpx_oauth.clients.openid import OpenID

+from ee.onyx.configs.app_configs import OIDC_SCOPE_OVERRIDE
 from ee.onyx.configs.app_configs import OPENID_CONFIG_URL
 from ee.onyx.server.analytics.api import router as analytics_router
 from ee.onyx.server.auth_check import check_ee_router_auth
@@ -88,7 +90,13 @@ def get_application() -> FastAPI:
         include_auth_router_with_prefix(
             application,
             create_onyx_oauth_router(
-                OpenID(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET, OPENID_CONFIG_URL),
+                OpenID(
+                    OAUTH_CLIENT_ID,
+                    OAUTH_CLIENT_SECRET,
+                    OPENID_CONFIG_URL,
+                    # BASE_SCOPES is the same as not setting this
+                    base_scopes=OIDC_SCOPE_OVERRIDE or BASE_SCOPES,
+                ),
                 auth_backend,
                 USER_AUTH_SECRET,
                 associate_by_email=True,
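
OIDC_SCOPE_OVERRIDE lets a deployment widen or narrow the scopes requested from the identity provider, and falling back to BASE_SCOPES when it is unset matches httpx-oauth's default behavior. A sketch of how such an override might be parsed from the environment; the comma-separated format here is an assumption, not taken from the diff:

import os

BASE_SCOPES = ["openid", "email"]  # httpx-oauth's default OpenID scopes

# Hypothetical parsing: a comma-separated env var, None when absent.
_raw = os.environ.get("OIDC_SCOPE_OVERRIDE")
OIDC_SCOPE_OVERRIDE = (
    [scope.strip() for scope in _raw.split(",") if scope.strip()] if _raw else None
)

# Mirrors the call site: the override wins, otherwise the library default.
effective_scopes = OIDC_SCOPE_OVERRIDE or BASE_SCOPES
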
@@ -80,7 +80,7 @@ def oneoff_standard_answers(
 def _handle_standard_answers(
     message_info: SlackMessageInfo,
     receiver_ids: list[str] | None,
-    slack_channel_config: SlackChannelConfig | None,
+    slack_channel_config: SlackChannelConfig,
     prompt: Prompt | None,
     logger: OnyxLoggingAdapter,
     client: WebClient,
@@ -94,13 +94,10 @@ def _handle_standard_answers(
     Returns True if standard answers are found to match the user's message and therefore,
     we still need to respond to the users.
     """
-    # if no channel config, then no standard answers are configured
-    if not slack_channel_config:
-        return False
-
     slack_thread_id = message_info.thread_to_respond
     configured_standard_answer_categories = (
-        slack_channel_config.standard_answer_categories if slack_channel_config else []
+        slack_channel_config.standard_answer_categories
     )
     configured_standard_answers = set(
         [
@@ -150,9 +147,9 @@ def _handle_standard_answers(
         db_session=db_session,
         description="",
         user_id=None,
-        persona_id=slack_channel_config.persona.id
-        if slack_channel_config.persona
-        else 0,
+        persona_id=(
+            slack_channel_config.persona.id if slack_channel_config.persona else 0
+        ),
         onyxbot_flow=True,
         slack_thread_id=slack_thread_id,
     )
@@ -182,7 +179,7 @@ def _handle_standard_answers(
         formatted_answers.append(formatted_answer)
         answer_message = "\n\n".join(formatted_answers)

-        _ = create_new_chat_message(
+        chat_message = create_new_chat_message(
             chat_session_id=chat_session.id,
             parent_message=new_user_message,
             prompt_id=prompt.id if prompt else None,
@@ -191,8 +188,13 @@ def _handle_standard_answers(
             message_type=MessageType.ASSISTANT,
             error=None,
             db_session=db_session,
-            commit=True,
+            commit=False,
         )
+        # attach the standard answers to the chat message
+        chat_message.standard_answers = [
+            standard_answer for standard_answer, _ in matching_standard_answers
+        ]
+        db_session.commit()

         update_emote_react(
             emoji=DANSWER_REACT_EMOJI,
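
Switching create_new_chat_message to commit=False lets the handler attach the matched standard answers to the new row and persist everything in a single transaction instead of two. The general SQLAlchemy shape, sketched with generic names under the assumption that the helper only adds the object to the session:

from sqlalchemy.orm import Session


def save_message_with_answers(db_session: Session, message, answers: list) -> None:
    # The helper adds the message to the session but does not commit...
    db_session.add(message)

    # ...so relationships can still be populated before anything is flushed.
    message.standard_answers = answers

    # One commit persists the message and its association rows atomically.
    db_session.commit()
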
@@ -10,6 +10,7 @@ from fastapi import Response
 from ee.onyx.auth.users import decode_anonymous_user_jwt_token
 from ee.onyx.configs.app_configs import ANONYMOUS_USER_COOKIE_NAME
 from onyx.auth.api_key import extract_tenant_from_api_key_header
+from onyx.configs.constants import TENANT_ID_COOKIE_NAME
 from onyx.db.engine import is_valid_schema_name
 from onyx.redis.redis_pool import retrieve_auth_token_data_from_redis
 from shared_configs.configs import MULTI_TENANT
@@ -32,7 +33,7 @@ def add_tenant_id_middleware(app: FastAPI, logger: logging.LoggerAdapter) -> Non
             return await call_next(request)

         except Exception as e:
-            logger.error(f"Error in tenant ID middleware: {str(e)}")
+            logger.exception(f"Error in tenant ID middleware: {str(e)}")
             raise


@@ -43,11 +44,12 @@ async def _get_tenant_id_from_request(
     Attempt to extract tenant_id from:
     1) The API key header
     2) The Redis-based token (stored in Cookie: fastapiusersauth)
+    3) Reset token cookie
     Fallback: POSTGRES_DEFAULT_SCHEMA
     """
     # Check for API key
     tenant_id = extract_tenant_from_api_key_header(request)
-    if tenant_id:
+    if tenant_id is not None:
         return tenant_id

     # Check for anonymous user cookie
@@ -62,6 +64,7 @@ async def _get_tenant_id_from_request(

     try:
         # Look up token data in Redis
         token_data = await retrieve_auth_token_data_from_redis(request)

         if not token_data:
@@ -85,8 +88,18 @@ async def _get_tenant_id_from_request(
         if not is_valid_schema_name(tenant_id):
             raise HTTPException(status_code=400, detail="Invalid tenant ID format")

-        return tenant_id
-
     except Exception as e:
         logger.error(f"Unexpected error in _get_tenant_id_from_request: {str(e)}")
         raise HTTPException(status_code=500, detail="Internal server error")

+    finally:
+        if tenant_id:
+            return tenant_id
+
+    # As a final step, check for explicit tenant_id cookie
+    tenant_id_cookie = request.cookies.get(TENANT_ID_COOKIE_NAME)
+    if tenant_id_cookie and is_valid_schema_name(tenant_id_cookie):
+        return tenant_id_cookie
+
+    # If we've reached this point, return the default schema
+    return POSTGRES_DEFAULT_SCHEMA
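
After this change the middleware tries four sources in order: API key header, anonymous-user cookie, the Redis-backed auth token, then an explicit tenant cookie, before falling back to the default schema. The control flow reduced to a sketch (the "public" default below is a placeholder, not the configured value):

POSTGRES_DEFAULT_SCHEMA = "public"


def resolve_tenant(
    api_key_tenant: str | None,
    anon_tenant: str | None,
    token_tenant: str | None,
    cookie_tenant: str | None,
) -> str:
    # 1) API key header wins outright.
    if api_key_tenant is not None:
        return api_key_tenant
    # 2) Anonymous-user cookie.
    if anon_tenant is not None:
        return anon_tenant
    # 3) Redis-backed auth token (the try/finally block in the diff).
    if token_tenant is not None:
        return token_tenant
    # 4) Explicit tenant_id cookie, validated before use in the real code.
    if cookie_tenant is not None:
        return cookie_tenant
    # Fallback: the default Postgres schema.
    return POSTGRES_DEFAULT_SCHEMA
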
@@ -36,12 +36,12 @@ from onyx.connectors.google_utils.shared_constants import (
     GoogleOAuthAuthenticationMethod,
 )
 from onyx.db.credentials import create_credential
-from onyx.db.engine import get_current_tenant_id
 from onyx.db.engine import get_session
 from onyx.db.models import User
 from onyx.redis.redis_pool import get_redis_client
 from onyx.server.documents.models import CredentialBase
 from onyx.utils.logger import setup_logger
+from shared_configs.contextvars import get_current_tenant_id


 logger = setup_logger()
@@ -271,12 +271,12 @@ def prepare_authorization_request(
     connector: DocumentSource,
     redirect_on_success: str | None,
     user: User = Depends(current_user),
-    tenant_id: str | None = Depends(get_current_tenant_id),
 ) -> JSONResponse:
     """Used by the frontend to generate the url for the user's browser during auth request.

     Example: https://www.oauth.com/oauth2-servers/authorization/the-authorization-request/
     """
+    tenant_id = get_current_tenant_id()

     # create random oauth state param for security and to retrieve user data later
     oauth_uuid = uuid.uuid4()
@@ -286,6 +286,7 @@ def prepare_authorization_request(
     oauth_state = (
         base64.urlsafe_b64encode(oauth_uuid.bytes).rstrip(b"=").decode("utf-8")
     )
+    session: str

     if connector == DocumentSource.SLACK:
         oauth_url = SlackOAuth.generate_oauth_url(oauth_state)
@@ -328,7 +329,6 @@ def handle_slack_oauth_callback(
     state: str,
     user: User = Depends(current_user),
     db_session: Session = Depends(get_session),
-    tenant_id: str | None = Depends(get_current_tenant_id),
 ) -> JSONResponse:
     if not SlackOAuth.CLIENT_ID or not SlackOAuth.CLIENT_SECRET:
         raise HTTPException(
@@ -336,7 +336,7 @@ def handle_slack_oauth_callback(
             detail="Slack client ID or client secret is not configured.",
         )

-    r = get_redis_client(tenant_id=tenant_id)
+    r = get_redis_client()

     # recover the state
     padded_state = state + "=" * (
@@ -522,7 +522,6 @@ def handle_google_drive_oauth_callback(
     state: str,
     user: User = Depends(current_user),
     db_session: Session = Depends(get_session),
-    tenant_id: str | None = Depends(get_current_tenant_id),
 ) -> JSONResponse:
     if not GoogleDriveOAuth.CLIENT_ID or not GoogleDriveOAuth.CLIENT_SECRET:
         raise HTTPException(
@@ -530,7 +529,7 @@ def handle_google_drive_oauth_callback(
             detail="Google Drive client ID or client secret is not configured.",
         )

-    r = get_redis_client(tenant_id=tenant_id)
+    r = get_redis_client()

     # recover the state
     padded_state = state + "=" * (
@@ -554,6 +553,7 @@ def handle_google_drive_oauth_callback(
     )

     session_json = session_json_bytes.decode("utf-8")
+    session: GoogleDriveOAuth.OAuthSession
     try:
         session = GoogleDriveOAuth.parse_session(session_json)


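
The OAuth routes stop receiving tenant_id as a FastAPI dependency and read it from a context variable instead, which is also why get_redis_client() no longer needs the argument. A minimal sketch of that contextvar pattern, assuming a simple getter like the one imported from shared_configs.contextvars:

from contextvars import ContextVar

CURRENT_TENANT_ID: ContextVar[str | None] = ContextVar(
    "current_tenant_id", default=None
)


def get_current_tenant_id() -> str | None:
    # Middleware sets the var once per request; handlers just read it.
    return CURRENT_TENANT_ID.get()


token = CURRENT_TENANT_ID.set("tenant_abc")
try:
    assert get_current_tenant_id() == "tenant_abc"
finally:
    CURRENT_TENANT_ID.reset(token)
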
@@ -179,6 +179,7 @@ def handle_simplified_chat_message(
         chunks_below=0,
         full_doc=chat_message_req.full_doc,
         structured_response_format=chat_message_req.structured_response_format,
+        use_agentic_search=chat_message_req.use_agentic_search,
     )

     packets = stream_chat_message_objects(
@@ -301,6 +302,7 @@ def handle_send_message_simple_with_history(
         chunks_below=0,
         full_doc=req.full_doc,
         structured_response_format=req.structured_response_format,
+        use_agentic_search=req.use_agentic_search,
     )

     packets = stream_chat_message_objects(
@@ -57,6 +57,9 @@ class BasicCreateChatMessageRequest(ChunkContext):
     # https://platform.openai.com/docs/guides/structured-outputs/introduction
     structured_response_format: dict | None = None

+    # If True, uses agentic search instead of basic search
+    use_agentic_search: bool = False
+

 class BasicCreateChatMessageWithHistoryRequest(ChunkContext):
     # Last element is the new query. All previous elements are historical context
@@ -71,6 +74,8 @@ class BasicCreateChatMessageWithHistoryRequest(ChunkContext):
     # only works if using an OpenAI model. See the following for more details:
     # https://platform.openai.com/docs/guides/structured-outputs/introduction
     structured_response_format: dict | None = None
+    # If True, uses agentic search instead of basic search
+    use_agentic_search: bool = False


 class SimpleDoc(BaseModel):
@@ -120,9 +125,12 @@ class OneShotQARequest(ChunkContext):
     # will also disable Thread-based Rewording if specified
     query_override: str | None = None

-    # If True, skips generative an AI response to the search query
+    # If True, skips generating an AI response to the search query
     skip_gen_ai_answer_generation: bool = False

+    # If True, uses agentic search instead of basic search
+    use_agentic_search: bool = False
+
     @model_validator(mode="after")
     def check_persona_fields(self) -> "OneShotQARequest":
         if self.persona_override_config is None and self.persona_id is None:
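
use_agentic_search defaults to False on all three request models, so existing API callers are unaffected and opting in is one extra field. A stripped-down sketch (pydantic v2 assumed, with only the fields shown here):

from pydantic import BaseModel


class ChatMessageRequest(BaseModel):
    message: str
    structured_response_format: dict | None = None
    # If True, uses agentic search instead of basic search
    use_agentic_search: bool = False


req = ChatMessageRequest.model_validate(
    {"message": "hi", "use_agentic_search": True}
)
assert req.use_agentic_search is True
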
@@ -83,6 +83,7 @@ def handle_search_request(
         user=user,
         llm=llm,
         fast_llm=fast_llm,
+        skip_query_analysis=False,
         db_session=db_session,
         bypass_acl=False,
     )
@@ -196,6 +197,8 @@ def get_answer_stream(
         retrieval_details=query_request.retrieval_options,
         rerank_settings=query_request.rerank_settings,
         db_session=db_session,
+        use_agentic_search=query_request.use_agentic_search,
+        skip_gen_ai_answer_generation=query_request.skip_gen_ai_answer_generation,
     )

     packets = stream_chat_message_objects(
@@ -28,7 +28,7 @@ from onyx.server.query_and_chat.token_limit import _user_is_rate_limited_by_glob
 from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel


-def _check_token_rate_limits(user: User | None, tenant_id: str | None) -> None:
+def _check_token_rate_limits(user: User | None, tenant_id: str) -> None:
     if user is None:
         # Unauthenticated users are only rate limited by global settings
         _user_is_rate_limited_by_global(tenant_id)
@@ -52,8 +52,8 @@ User rate limits
 """


-def _user_is_rate_limited(user_id: UUID, tenant_id: str | None) -> None:
-    with get_session_with_tenant(tenant_id) as db_session:
+def _user_is_rate_limited(user_id: UUID, tenant_id: str) -> None:
+    with get_session_with_tenant(tenant_id=tenant_id) as db_session:
         user_rate_limits = fetch_all_user_token_rate_limits(
             db_session=db_session, enabled_only=True, ordered=False
         )
@@ -94,7 +94,7 @@ User Group rate limits


 def _user_is_rate_limited_by_group(user_id: UUID, tenant_id: str | None) -> None:
-    with get_session_with_tenant(tenant_id) as db_session:
+    with get_session_with_tenant(tenant_id=tenant_id) as db_session:
         group_rate_limits = _fetch_all_user_group_rate_limits(user_id, db_session)

     if group_rate_limits:
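
Tightening tenant_id from str | None to a required str and passing it as a keyword means every rate-limit check opens its session against an explicit tenant schema. The pattern, sketched with a dummy session factory:

from collections.abc import Iterator
from contextlib import contextmanager


@contextmanager
def get_session_with_tenant(*, tenant_id: str) -> Iterator[str]:
    """Dummy stand-in: yields a label instead of a SQLAlchemy session."""
    yield f"session[{tenant_id}]"


def user_is_rate_limited(user_id: str, tenant_id: str) -> bool:
    # tenant_id is now required, so a missing tenant fails loudly at the
    # call site instead of silently resolving to a default schema.
    with get_session_with_tenant(tenant_id=tenant_id) as db_session:
        # ...fetch the user's limits with db_session here...
        return False
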
@@ -1,19 +1,23 @@
 import csv
 import io
 from datetime import datetime
 from datetime import timedelta
 from datetime import timezone
-from typing import Literal
 from uuid import UUID

 from fastapi import APIRouter
 from fastapi import Depends
 from fastapi import HTTPException
+from fastapi import Query
 from fastapi.responses import StreamingResponse
-from pydantic import BaseModel
 from sqlalchemy.orm import Session

 from ee.onyx.db.query_history import fetch_chat_sessions_eagerly_by_time
+from ee.onyx.db.query_history import get_page_of_chat_sessions
+from ee.onyx.db.query_history import get_total_filtered_chat_sessions_count
+from ee.onyx.server.query_history.models import ChatSessionMinimal
+from ee.onyx.server.query_history.models import ChatSessionSnapshot
+from ee.onyx.server.query_history.models import MessageSnapshot
+from ee.onyx.server.query_history.models import QuestionAnswerPairSnapshot
 from onyx.auth.users import current_admin_user
 from onyx.auth.users import get_display_email
 from onyx.chat.chat_utils import create_chat_chain
@@ -23,257 +27,15 @@ from onyx.configs.constants import SessionType
 from onyx.db.chat import get_chat_session_by_id
 from onyx.db.chat import get_chat_sessions_by_user
 from onyx.db.engine import get_session
-from onyx.db.models import ChatMessage
 from onyx.db.models import ChatSession
 from onyx.db.models import User
+from onyx.server.documents.models import PaginatedReturn
 from onyx.server.query_and_chat.models import ChatSessionDetails
 from onyx.server.query_and_chat.models import ChatSessionsResponse

 router = APIRouter()

-
-class AbridgedSearchDoc(BaseModel):
-    """A subset of the info present in `SearchDoc`"""
-
-    document_id: str
-    semantic_identifier: str
-    link: str | None
-
-
-class MessageSnapshot(BaseModel):
-    message: str
-    message_type: MessageType
-    documents: list[AbridgedSearchDoc]
-    feedback_type: QAFeedbackType | None
-    feedback_text: str | None
-    time_created: datetime
-
-    @classmethod
-    def build(cls, message: ChatMessage) -> "MessageSnapshot":
-        latest_messages_feedback_obj = (
-            message.chat_message_feedbacks[-1]
-            if len(message.chat_message_feedbacks) > 0
-            else None
-        )
-        feedback_type = (
-            (
-                QAFeedbackType.LIKE
-                if latest_messages_feedback_obj.is_positive
-                else QAFeedbackType.DISLIKE
-            )
-            if latest_messages_feedback_obj
-            else None
-        )
-        feedback_text = (
-            latest_messages_feedback_obj.feedback_text
-            if latest_messages_feedback_obj
-            else None
-        )
-        return cls(
-            message=message.message,
-            message_type=message.message_type,
-            documents=[
-                AbridgedSearchDoc(
-                    document_id=document.document_id,
-                    semantic_identifier=document.semantic_id,
-                    link=document.link,
-                )
-                for document in message.search_docs
-            ],
-            feedback_type=feedback_type,
-            feedback_text=feedback_text,
-            time_created=message.time_sent,
-        )
-
-
-class ChatSessionMinimal(BaseModel):
-    id: UUID
-    user_email: str
-    name: str | None
-    first_user_message: str
-    first_ai_message: str
-    assistant_id: int | None
-    assistant_name: str | None
-    time_created: datetime
-    feedback_type: QAFeedbackType | Literal["mixed"] | None
-    flow_type: SessionType
-    conversation_length: int
-
-
-class ChatSessionSnapshot(BaseModel):
-    id: UUID
-    user_email: str
-    name: str | None
-    messages: list[MessageSnapshot]
-    assistant_id: int | None
-    assistant_name: str | None
-    time_created: datetime
-    flow_type: SessionType
-
-
-class QuestionAnswerPairSnapshot(BaseModel):
-    chat_session_id: UUID
-    # 1-indexed message number in the chat_session
-    # e.g. the first message pair in the chat_session is 1, the second is 2, etc.
-    message_pair_num: int
-    user_message: str
-    ai_response: str
-    retrieved_documents: list[AbridgedSearchDoc]
-    feedback_type: QAFeedbackType | None
-    feedback_text: str | None
-    persona_name: str | None
-    user_email: str
-    time_created: datetime
-    flow_type: SessionType
-
-    @classmethod
-    def from_chat_session_snapshot(
-        cls,
-        chat_session_snapshot: ChatSessionSnapshot,
-    ) -> list["QuestionAnswerPairSnapshot"]:
-        message_pairs: list[tuple[MessageSnapshot, MessageSnapshot]] = []
-        for ind in range(1, len(chat_session_snapshot.messages), 2):
-            message_pairs.append(
-                (
-                    chat_session_snapshot.messages[ind - 1],
-                    chat_session_snapshot.messages[ind],
-                )
-            )
-
-        return [
-            cls(
-                chat_session_id=chat_session_snapshot.id,
-                message_pair_num=ind + 1,
-                user_message=user_message.message,
-                ai_response=ai_message.message,
-                retrieved_documents=ai_message.documents,
-                feedback_type=ai_message.feedback_type,
-                feedback_text=ai_message.feedback_text,
-                persona_name=chat_session_snapshot.assistant_name,
-                user_email=get_display_email(chat_session_snapshot.user_email),
-                time_created=user_message.time_created,
-                flow_type=chat_session_snapshot.flow_type,
-            )
-            for ind, (user_message, ai_message) in enumerate(message_pairs)
-        ]
-
-    def to_json(self) -> dict[str, str | None]:
-        return {
-            "chat_session_id": str(self.chat_session_id),
-            "message_pair_num": str(self.message_pair_num),
-            "user_message": self.user_message,
-            "ai_response": self.ai_response,
-            "retrieved_documents": "|".join(
-                [
-                    doc.link or doc.semantic_identifier
-                    for doc in self.retrieved_documents
-                ]
-            ),
-            "feedback_type": self.feedback_type.value if self.feedback_type else "",
-            "feedback_text": self.feedback_text or "",
-            "persona_name": self.persona_name,
-            "user_email": self.user_email,
-            "time_created": str(self.time_created),
-            "flow_type": self.flow_type,
-        }
-
-
-def determine_flow_type(chat_session: ChatSession) -> SessionType:
-    return SessionType.SLACK if chat_session.onyxbot_flow else SessionType.CHAT
-
-
-def fetch_and_process_chat_session_history_minimal(
-    db_session: Session,
-    start: datetime,
-    end: datetime,
-    feedback_filter: QAFeedbackType | None = None,
-    limit: int | None = 500,
-) -> list[ChatSessionMinimal]:
-    chat_sessions = fetch_chat_sessions_eagerly_by_time(
-        start=start, end=end, db_session=db_session, limit=limit
-    )
-
-    minimal_sessions = []
-    for chat_session in chat_sessions:
-        if not chat_session.messages:
-            continue
-
-        first_user_message = next(
-            (
-                message.message
-                for message in chat_session.messages
-                if message.message_type == MessageType.USER
-            ),
-            "",
-        )
-        first_ai_message = next(
-            (
-                message.message
-                for message in chat_session.messages
-                if message.message_type == MessageType.ASSISTANT
-            ),
-            "",
-        )
-
-        has_positive_feedback = any(
-            feedback.is_positive
-            for message in chat_session.messages
-            for feedback in message.chat_message_feedbacks
-        )
-
-        has_negative_feedback = any(
-            not feedback.is_positive
-            for message in chat_session.messages
-            for feedback in message.chat_message_feedbacks
-        )
-
-        feedback_type: QAFeedbackType | Literal["mixed"] | None = (
-            "mixed"
-            if has_positive_feedback and has_negative_feedback
-            else QAFeedbackType.LIKE
-            if has_positive_feedback
-            else QAFeedbackType.DISLIKE
-            if has_negative_feedback
-            else None
-        )
-
-        if feedback_filter:
-            if feedback_filter == QAFeedbackType.LIKE and not has_positive_feedback:
-                continue
-            if feedback_filter == QAFeedbackType.DISLIKE and not has_negative_feedback:
-                continue
-
-        flow_type = determine_flow_type(chat_session)
-
-        minimal_sessions.append(
-            ChatSessionMinimal(
-                id=chat_session.id,
-                user_email=get_display_email(
-                    chat_session.user.email if chat_session.user else None
-                ),
-                name=chat_session.description,
-                first_user_message=first_user_message,
-                first_ai_message=first_ai_message,
-                assistant_id=chat_session.persona_id,
-                assistant_name=(
-                    chat_session.persona.name if chat_session.persona else None
-                ),
-                time_created=chat_session.time_created,
-                feedback_type=feedback_type,
-                flow_type=flow_type,
-                conversation_length=len(
-                    [
-                        m
-                        for m in chat_session.messages
-                        if m.message_type != MessageType.SYSTEM
-                    ]
-                ),
-            )
-        )
-
-    return minimal_sessions
-
-
 def fetch_and_process_chat_session_history(
     db_session: Session,
     start: datetime,
@@ -319,7 +81,7 @@ def snapshot_from_chat_session(
     except RuntimeError:
         return None

-    flow_type = determine_flow_type(chat_session)
+    flow_type = SessionType.SLACK if chat_session.onyxbot_flow else SessionType.CHAT

     return ChatSessionSnapshot(
         id=chat_session.id,
@@ -371,22 +133,38 @@ def get_user_chat_sessions(

 @router.get("/admin/chat-session-history")
 def get_chat_session_history(
+    page_num: int = Query(0, ge=0),
+    page_size: int = Query(10, ge=1),
     feedback_type: QAFeedbackType | None = None,
-    start: datetime | None = None,
-    end: datetime | None = None,
+    start_time: datetime | None = None,
+    end_time: datetime | None = None,
     _: User | None = Depends(current_admin_user),
     db_session: Session = Depends(get_session),
-) -> list[ChatSessionMinimal]:
-    return fetch_and_process_chat_session_history_minimal(
+) -> PaginatedReturn[ChatSessionMinimal]:
+    page_of_chat_sessions = get_page_of_chat_sessions(
+        page_num=page_num,
+        page_size=page_size,
         db_session=db_session,
-        start=start
-        or (
-            datetime.now(tz=timezone.utc) - timedelta(days=30)
-        ),  # default is 30d lookback
-        end=end or datetime.now(tz=timezone.utc),
+        start_time=start_time,
+        end_time=end_time,
         feedback_filter=feedback_type,
     )

+    total_filtered_chat_sessions_count = get_total_filtered_chat_sessions_count(
+        db_session=db_session,
+        start_time=start_time,
+        end_time=end_time,
+        feedback_filter=feedback_type,
+    )
+
+    return PaginatedReturn(
+        items=[
+            ChatSessionMinimal.from_chat_session(chat_session)
+            for chat_session in page_of_chat_sessions
+        ],
+        total_items=total_filtered_chat_sessions_count,
+    )
+

 @router.get("/admin/chat-session-history/{chat_session_id}")
 def get_chat_session_admin(
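
The endpoint now returns a page of items plus the filtered total, so the frontend can render pagination controls without fetching the full history. PaginatedReturn is presumably a small generic model along these lines (a sketch, pydantic v2 assumed, not the actual definition in onyx.server.documents.models):

from typing import Generic, TypeVar

from pydantic import BaseModel

T = TypeVar("T")


class PaginatedReturn(BaseModel, Generic[T]):
    items: list[T]
    total_items: int


page = PaginatedReturn[int](items=[1, 2, 3], total_items=42)
assert page.total_items == 42
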
backend/ee/onyx/server/query_history/models.py (new file, 218 lines)
@@ -0,0 +1,218 @@
+from datetime import datetime
+from uuid import UUID
+
+from pydantic import BaseModel
+
+from onyx.auth.users import get_display_email
+from onyx.configs.constants import MessageType
+from onyx.configs.constants import QAFeedbackType
+from onyx.configs.constants import SessionType
+from onyx.db.models import ChatMessage
+from onyx.db.models import ChatSession
+
+
+class AbridgedSearchDoc(BaseModel):
+    """A subset of the info present in `SearchDoc`"""
+
+    document_id: str
+    semantic_identifier: str
+    link: str | None
+
+
+class MessageSnapshot(BaseModel):
+    id: int
+    message: str
+    message_type: MessageType
+    documents: list[AbridgedSearchDoc]
+    feedback_type: QAFeedbackType | None
+    feedback_text: str | None
+    time_created: datetime
+
+    @classmethod
+    def build(cls, message: ChatMessage) -> "MessageSnapshot":
+        latest_messages_feedback_obj = (
+            message.chat_message_feedbacks[-1]
+            if len(message.chat_message_feedbacks) > 0
+            else None
+        )
+        feedback_type = (
+            (
+                QAFeedbackType.LIKE
+                if latest_messages_feedback_obj.is_positive
+                else QAFeedbackType.DISLIKE
+            )
+            if latest_messages_feedback_obj
+            else None
+        )
+        feedback_text = (
+            latest_messages_feedback_obj.feedback_text
+            if latest_messages_feedback_obj
+            else None
+        )
+        return cls(
+            id=message.id,
+            message=message.message,
+            message_type=message.message_type,
+            documents=[
+                AbridgedSearchDoc(
+                    document_id=document.document_id,
+                    semantic_identifier=document.semantic_id,
+                    link=document.link,
+                )
+                for document in message.search_docs
+            ],
+            feedback_type=feedback_type,
+            feedback_text=feedback_text,
+            time_created=message.time_sent,
+        )
+
+
+class ChatSessionMinimal(BaseModel):
+    id: UUID
+    user_email: str
+    name: str | None
+    first_user_message: str
+    first_ai_message: str
+    assistant_id: int | None
+    assistant_name: str | None
+    time_created: datetime
+    feedback_type: QAFeedbackType | None
+    flow_type: SessionType
+    conversation_length: int
+
+    @classmethod
+    def from_chat_session(cls, chat_session: ChatSession) -> "ChatSessionMinimal":
+        first_user_message = next(
+            (
+                message.message
+                for message in chat_session.messages
+                if message.message_type == MessageType.USER
+            ),
+            "",
+        )
+        first_ai_message = next(
+            (
+                message.message
+                for message in chat_session.messages
+                if message.message_type == MessageType.ASSISTANT
+            ),
+            "",
+        )
+
+        list_of_message_feedbacks = [
+            feedback.is_positive
+            for message in chat_session.messages
+            for feedback in message.chat_message_feedbacks
+        ]
+        session_feedback_type = None
+        if list_of_message_feedbacks:
+            if all(list_of_message_feedbacks):
+                session_feedback_type = QAFeedbackType.LIKE
+            elif not any(list_of_message_feedbacks):
+                session_feedback_type = QAFeedbackType.DISLIKE
+            else:
+                session_feedback_type = QAFeedbackType.MIXED
+
+        return cls(
+            id=chat_session.id,
+            user_email=get_display_email(
+                chat_session.user.email if chat_session.user else None
+            ),
+            name=chat_session.description,
+            first_user_message=first_user_message,
+            first_ai_message=first_ai_message,
+            assistant_id=chat_session.persona_id,
+            assistant_name=(
+                chat_session.persona.name if chat_session.persona else None
+            ),
+            time_created=chat_session.time_created,
+            feedback_type=session_feedback_type,
+            flow_type=SessionType.SLACK
+            if chat_session.onyxbot_flow
+            else SessionType.CHAT,
+            conversation_length=len(
+                [
+                    message
+                    for message in chat_session.messages
+                    if message.message_type != MessageType.SYSTEM
+                ]
+            ),
+        )
+
+
+class ChatSessionSnapshot(BaseModel):
+    id: UUID
+    user_email: str
+    name: str | None
+    messages: list[MessageSnapshot]
+    assistant_id: int | None
+    assistant_name: str | None
+    time_created: datetime
+    flow_type: SessionType
+
+
+class QuestionAnswerPairSnapshot(BaseModel):
+    chat_session_id: UUID
+    # 1-indexed message number in the chat_session
+    # e.g. the first message pair in the chat_session is 1, the second is 2, etc.
+    message_pair_num: int
+    user_message: str
+    ai_response: str
+    retrieved_documents: list[AbridgedSearchDoc]
+    feedback_type: QAFeedbackType | None
+    feedback_text: str | None
+    persona_name: str | None
+    user_email: str
+    time_created: datetime
+    flow_type: SessionType
+
+    @classmethod
+    def from_chat_session_snapshot(
+        cls,
+        chat_session_snapshot: ChatSessionSnapshot,
+    ) -> list["QuestionAnswerPairSnapshot"]:
+        message_pairs: list[tuple[MessageSnapshot, MessageSnapshot]] = []
+        for ind in range(1, len(chat_session_snapshot.messages), 2):
+            message_pairs.append(
+                (
+                    chat_session_snapshot.messages[ind - 1],
+                    chat_session_snapshot.messages[ind],
+                )
+            )
+
+        return [
+            cls(
+                chat_session_id=chat_session_snapshot.id,
+                message_pair_num=ind + 1,
+                user_message=user_message.message,
+                ai_response=ai_message.message,
+                retrieved_documents=ai_message.documents,
+                feedback_type=ai_message.feedback_type,
+                feedback_text=ai_message.feedback_text,
+                persona_name=chat_session_snapshot.assistant_name,
+                user_email=get_display_email(chat_session_snapshot.user_email),
+                time_created=user_message.time_created,
+                flow_type=chat_session_snapshot.flow_type,
+            )
+            for ind, (user_message, ai_message) in enumerate(message_pairs)
+        ]
+
+    def to_json(self) -> dict[str, str | None]:
+        return {
+            "chat_session_id": str(self.chat_session_id),
+            "message_pair_num": str(self.message_pair_num),
+            "user_message": self.user_message,
+            "ai_response": self.ai_response,
+            "retrieved_documents": "|".join(
+                [
+                    doc.link or doc.semantic_identifier
+                    for doc in self.retrieved_documents
+                ]
+            ),
+            "feedback_type": self.feedback_type.value if self.feedback_type else "",
+            "feedback_text": self.feedback_text or "",
+            "persona_name": self.persona_name,
+            "user_email": self.user_email,
+            "time_created": str(self.time_created),
+            "flow_type": self.flow_type,
+        }
@@ -24,7 +24,7 @@ from onyx.db.llm import update_default_provider
 from onyx.db.llm import upsert_llm_provider
 from onyx.db.models import Tool
 from onyx.db.persona import upsert_persona
-from onyx.server.features.persona.models import CreatePersonaRequest
+from onyx.server.features.persona.models import PersonaUpsertRequest
 from onyx.server.manage.llm.models import LLMProviderUpsertRequest
 from onyx.server.settings.models import Settings
 from onyx.server.settings.store import store_settings as store_base_settings
@@ -57,7 +57,7 @@ class SeedConfiguration(BaseModel):
     llms: list[LLMProviderUpsertRequest] | None = None
     admin_user_emails: list[str] | None = None
     seeded_logo_path: str | None = None
-    personas: list[CreatePersonaRequest] | None = None
+    personas: list[PersonaUpsertRequest] | None = None
     settings: Settings | None = None
     enterprise_settings: EnterpriseSettings | None = None

@@ -128,7 +128,7 @@ def _seed_llms(
     )


-def _seed_personas(db_session: Session, personas: list[CreatePersonaRequest]) -> None:
+def _seed_personas(db_session: Session, personas: list[PersonaUpsertRequest]) -> None:
     if personas:
         logger.notice("Seeding Personas")
         for persona in personas:
@@ -18,11 +18,16 @@ from ee.onyx.server.tenants.anonymous_user_path import (
 from ee.onyx.server.tenants.anonymous_user_path import modify_anonymous_user_path
 from ee.onyx.server.tenants.anonymous_user_path import validate_anonymous_user_path
 from ee.onyx.server.tenants.billing import fetch_billing_information
+from ee.onyx.server.tenants.billing import fetch_stripe_checkout_session
 from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information
 from ee.onyx.server.tenants.models import AnonymousUserPath
 from ee.onyx.server.tenants.models import BillingInformation
 from ee.onyx.server.tenants.models import ImpersonateRequest
 from ee.onyx.server.tenants.models import ProductGatingRequest
+from ee.onyx.server.tenants.models import ProductGatingResponse
+from ee.onyx.server.tenants.models import SubscriptionSessionResponse
+from ee.onyx.server.tenants.models import SubscriptionStatusResponse
+from ee.onyx.server.tenants.product_gating import store_product_gating
 from ee.onyx.server.tenants.provisioning import delete_user_from_control_plane
 from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
 from ee.onyx.server.tenants.user_mapping import remove_all_users_from_tenant
@@ -34,18 +39,17 @@ from onyx.auth.users import get_redis_strategy
 from onyx.auth.users import optional_user
 from onyx.auth.users import User
 from onyx.configs.app_configs import WEB_DOMAIN
+from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
 from onyx.db.auth import get_user_count
-from onyx.db.engine import get_current_tenant_id
 from onyx.db.engine import get_session
+from onyx.db.engine import get_session_with_shared_schema
 from onyx.db.engine import get_session_with_tenant
-from onyx.db.notification import create_notification
 from onyx.db.users import delete_user_from_db
 from onyx.db.users import get_user_by_email
 from onyx.server.manage.models import UserByEmail
-from onyx.server.settings.store import load_settings
-from onyx.server.settings.store import store_settings
 from onyx.utils.logger import setup_logger
 from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
+from shared_configs.contextvars import get_current_tenant_id

 stripe.api_key = STRIPE_SECRET_KEY
 logger = setup_logger()
@@ -54,13 +58,14 @@ router = APIRouter(prefix="/tenants")

 @router.get("/anonymous-user-path")
 async def get_anonymous_user_path_api(
-    tenant_id: str | None = Depends(get_current_tenant_id),
     _: User | None = Depends(current_admin_user),
 ) -> AnonymousUserPath:
+    tenant_id = get_current_tenant_id()
+
     if tenant_id is None:
         raise HTTPException(status_code=404, detail="Tenant not found")

-    with get_session_with_tenant(tenant_id=None) as db_session:
+    with get_session_with_shared_schema() as db_session:
         current_path = get_anonymous_user_path(tenant_id, db_session)

     return AnonymousUserPath(anonymous_user_path=current_path)
@@ -69,15 +74,15 @@ async def get_anonymous_user_path_api(
 @router.post("/anonymous-user-path")
 async def set_anonymous_user_path_api(
     anonymous_user_path: str,
-    tenant_id: str = Depends(get_current_tenant_id),
     _: User | None = Depends(current_admin_user),
 ) -> None:
+    tenant_id = get_current_tenant_id()
     try:
         validate_anonymous_user_path(anonymous_user_path)
     except ValueError as e:
         raise HTTPException(status_code=400, detail=str(e))

-    with get_session_with_tenant(tenant_id=None) as db_session:
+    with get_session_with_shared_schema() as db_session:
         try:
             modify_anonymous_user_path(tenant_id, anonymous_user_path, db_session)
         except IntegrityError:
@@ -98,7 +103,7 @@ async def login_as_anonymous_user(
     anonymous_user_path: str,
     _: User | None = Depends(optional_user),
 ) -> Response:
-    with get_session_with_tenant(tenant_id=None) as db_session:
+    with get_session_with_shared_schema() as db_session:
         tenant_id = get_tenant_id_for_anonymous_user_path(
             anonymous_user_path, db_session
         )
@@ -111,6 +116,7 @@ async def login_as_anonymous_user(
     token = generate_anonymous_user_jwt_token(tenant_id)

     response = Response()
+    response.delete_cookie(FASTAPI_USERS_AUTH_COOKIE_NAME)
     response.set_cookie(
         key=ANONYMOUS_USER_COOKIE_NAME,
         value=token,
@@ -124,52 +130,48 @@ async def login_as_anonymous_user(
 @router.post("/product-gating")
 def gate_product(
     product_gating_request: ProductGatingRequest, _: None = Depends(control_plane_dep)
-) -> None:
+) -> ProductGatingResponse:
     """
     Gating the product means that the product is not available to the tenant.
     They will be directed to the billing page.
-    We gate the product when
-    1) User has ended free trial without adding payment method
-    2) User's card has declined
+    We gate the product when their subscription has ended.
     """
-    tenant_id = product_gating_request.tenant_id
-    token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
+    try:
+        store_product_gating(
+            product_gating_request.tenant_id, product_gating_request.application_status
+        )
+        return ProductGatingResponse(updated=True, error=None)

-    settings = load_settings()
-    settings.product_gating = product_gating_request.product_gating
-    store_settings(settings)
-
-    if product_gating_request.notification:
-        with get_session_with_tenant(tenant_id) as db_session:
-            create_notification(None, product_gating_request.notification, db_session)
-
-    if token is not None:
-        CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
+    except Exception as e:
+        logger.exception("Failed to gate product")
+        return ProductGatingResponse(updated=False, error=str(e))


-@router.get("/billing-information", response_model=BillingInformation)
+@router.get("/billing-information")
 async def billing_information(
     _: User = Depends(current_admin_user),
-) -> BillingInformation:
+) -> BillingInformation | SubscriptionStatusResponse:
     logger.info("Fetching billing information")
-    return BillingInformation(
-        **fetch_billing_information(CURRENT_TENANT_ID_CONTEXTVAR.get())
-    )
+    tenant_id = get_current_tenant_id()
+    return fetch_billing_information(tenant_id)


 @router.post("/create-customer-portal-session")
-async def create_customer_portal_session(_: User = Depends(current_admin_user)) -> dict:
+async def create_customer_portal_session(
+    _: User = Depends(current_admin_user),
+) -> dict:
+    tenant_id = get_current_tenant_id()
+
     try:
-        # Fetch tenant_id and current tenant's information
-        tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
         stripe_info = fetch_tenant_stripe_information(tenant_id)
         stripe_customer_id = stripe_info.get("stripe_customer_id")
         if not stripe_customer_id:
             raise HTTPException(status_code=400, detail="Stripe customer ID not found")
         logger.info(stripe_customer_id)

         portal_session = stripe.billing_portal.Session.create(
             customer=stripe_customer_id,
-            return_url=f"{WEB_DOMAIN}/admin/cloud-settings",
+            return_url=f"{WEB_DOMAIN}/admin/billing",
         )
         logger.info(portal_session)
         return {"url": portal_session.url}
@@ -178,6 +180,22 @@ async def create_customer_portal_session(_: User = Depends(current_admin_user))
         raise HTTPException(status_code=500, detail=str(e))


+@router.post("/create-subscription-session")
+async def create_subscription_session(
+    _: User = Depends(current_admin_user),
+) -> SubscriptionSessionResponse:
+    try:
+        tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
+        if not tenant_id:
+            raise HTTPException(status_code=400, detail="Tenant ID not found")
+        session_id = fetch_stripe_checkout_session(tenant_id)
+        return SubscriptionSessionResponse(sessionId=session_id)
+
+    except Exception as e:
+        logger.exception("Failed to create resubscription session")
+        raise HTTPException(status_code=500, detail=str(e))
+
+
 @router.post("/impersonate")
 async def impersonate_user(
     impersonate_request: ImpersonateRequest,
@@ -186,7 +204,7 @@ async def impersonate_user(
     """Allows a cloud superuser to impersonate another user by generating an impersonation JWT token"""
     tenant_id = get_tenant_id_for_email(impersonate_request.email)

-    with get_session_with_tenant(tenant_id) as tenant_session:
+    with get_session_with_tenant(tenant_id=tenant_id) as tenant_session:
         user_to_impersonate = get_user_by_email(
             impersonate_request.email, tenant_session
         )
@@ -210,8 +228,9 @@ async def leave_organization(
     user_email: UserByEmail,
     current_user: User | None = Depends(current_admin_user),
     db_session: Session = Depends(get_session),
-    tenant_id: str = Depends(get_current_tenant_id),
 ) -> None:
+    tenant_id = get_current_tenant_id()
+
     if current_user is None or current_user.email != user_email.user_email:
         raise HTTPException(
             status_code=403, detail="You can only leave the organization as yourself"
@@ -6,6 +6,7 @@ import stripe
from ee.onyx.configs.app_configs import STRIPE_PRICE_ID
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
from ee.onyx.server.tenants.access import generate_data_plane_token
from ee.onyx.server.tenants.models import BillingInformation
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
from onyx.utils.logger import setup_logger

@@ -14,6 +15,19 @@ stripe.api_key = STRIPE_SECRET_KEY
logger = setup_logger()


def fetch_stripe_checkout_session(tenant_id: str) -> str:
    token = generate_data_plane_token()
    headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json",
    }
    url = f"{CONTROL_PLANE_API_BASE_URL}/create-checkout-session"
    params = {"tenant_id": tenant_id}
    response = requests.post(url, headers=headers, params=params)
    response.raise_for_status()
    return response.json()["sessionId"]


def fetch_tenant_stripe_information(tenant_id: str) -> dict:
    token = generate_data_plane_token()
    headers = {
@@ -27,7 +41,7 @@ def fetch_tenant_stripe_information(tenant_id: str) -> dict:
    return response.json()


def fetch_billing_information(tenant_id: str) -> dict:
def fetch_billing_information(tenant_id: str) -> BillingInformation:
    logger.info("Fetching billing information")
    token = generate_data_plane_token()
    headers = {
@@ -38,7 +52,7 @@ def fetch_billing_information(tenant_id: str) -> dict:
    params = {"tenant_id": tenant_id}
    response = requests.get(url, headers=headers, params=params)
    response.raise_for_status()
    billing_info = response.json()
    billing_info = BillingInformation(**response.json())
    return billing_info
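All three helpers in this file share one data-plane-to-control-plane shape: mint a short-lived bearer token, call the control plane, and let non-2xx responses surface as exceptions. A compact sketch of that shared pattern (the endpoint name is illustrative):

import requests

def control_plane_get(endpoint: str, tenant_id: str) -> dict:
    # Same shape as fetch_tenant_stripe_information / fetch_billing_information above.
    token = generate_data_plane_token()
    headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
    url = f"{CONTROL_PLANE_API_BASE_URL}/{endpoint}"
    response = requests.get(url, headers=headers, params={"tenant_id": tenant_id})
    response.raise_for_status()  # raises requests.HTTPError on non-2xx
    return response.json()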
@@ -1,7 +1,8 @@
from datetime import datetime

from pydantic import BaseModel

from onyx.configs.constants import NotificationType
from onyx.server.settings.models import GatingType
from onyx.server.settings.models import ApplicationStatus


class CheckoutSessionCreationRequest(BaseModel):
@@ -15,15 +16,24 @@ class CreateTenantRequest(BaseModel):

class ProductGatingRequest(BaseModel):
    tenant_id: str
    product_gating: GatingType
    notification: NotificationType | None = None
    application_status: ApplicationStatus


class SubscriptionStatusResponse(BaseModel):
    subscribed: bool


class BillingInformation(BaseModel):
    stripe_subscription_id: str
    status: str
    current_period_start: datetime
    current_period_end: datetime
    number_of_seats: int
    cancel_at_period_end: bool
    canceled_at: datetime | None
    trial_start: datetime | None
    trial_end: datetime | None
    seats: int
    subscription_status: str
    billing_start: str
    billing_end: str
    payment_method_enabled: bool


@@ -48,3 +58,12 @@ class TenantDeletionPayload(BaseModel):

class AnonymousUserPath(BaseModel):
    anonymous_user_path: str | None


class ProductGatingResponse(BaseModel):
    updated: bool
    error: str | None


class SubscriptionSessionResponse(BaseModel):
    sessionId: str
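Since fetch_billing_information now returns this model instead of a raw dict, schema drift on the control-plane side fails loudly at the parse boundary. A tiny sketch of that failure mode (the payload is made up):

from pydantic import ValidationError

try:
    BillingInformation(**{"status": "active"})  # most required fields missing
except ValidationError as err:
    # With a plain dict return value this would have been a silent KeyError much later.
    print(f"control plane payload rejected: {err}")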
51
backend/ee/onyx/server/tenants/product_gating.py
Normal file
@@ -0,0 +1,51 @@
from typing import cast

from ee.onyx.configs.app_configs import GATED_TENANTS_KEY
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.server.settings.models import ApplicationStatus
from onyx.server.settings.store import load_settings
from onyx.server.settings.store import store_settings
from onyx.setup import setup_logger
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR

logger = setup_logger()


def update_tenant_gating(tenant_id: str, status: ApplicationStatus) -> None:
    redis_client = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)

    # Store the full status
    status_key = f"tenant:{tenant_id}:status"
    redis_client.set(status_key, status.value)

    # Maintain the GATED_ACCESS set
    if status == ApplicationStatus.GATED_ACCESS:
        redis_client.sadd(GATED_TENANTS_KEY, tenant_id)
    else:
        redis_client.srem(GATED_TENANTS_KEY, tenant_id)


def store_product_gating(tenant_id: str, application_status: ApplicationStatus) -> None:
    try:
        token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)

        settings = load_settings()
        settings.application_status = application_status
        store_settings(settings)

        # Store gated tenant information in Redis
        update_tenant_gating(tenant_id, application_status)

        if token is not None:
            CURRENT_TENANT_ID_CONTEXTVAR.reset(token)

    except Exception:
        logger.exception("Failed to gate product")
        raise


def get_gated_tenants() -> set[str]:
    redis_client = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
    return cast(set[str], redis_client.smembers(GATED_TENANTS_KEY))
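On the read side a gate check reduces to an O(1) set-membership lookup. A minimal sketch of how a request path might consult the set (the function name is illustrative, not from the repo):

def is_tenant_gated(tenant_id: str) -> bool:
    # Reads go through the replica client, mirroring get_gated_tenants above.
    redis_client = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
    return bool(redis_client.sismember(GATED_TENANTS_KEY, tenant_id))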
@@ -24,6 +24,7 @@ from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
from ee.onyx.server.tenants.user_mapping import user_owns_a_tenant
from onyx.auth.users import exceptions
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.constants import MilestoneRecordType
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_sqlalchemy_engine
@@ -85,7 +86,8 @@ async def create_tenant(email: str, referral_source: str | None = None) -> str:
        # Provision tenant on data plane
        await provision_tenant(tenant_id, email)
        # Notify control plane
        await notify_control_plane(tenant_id, email, referral_source)
        if not DEV_MODE:
            await notify_control_plane(tenant_id, email, referral_source)
    except Exception as e:
        logger.error(f"Tenant provisioning failed: {e}")
        await rollback_tenant_provisioning(tenant_id)
@@ -116,7 +118,7 @@ async def provision_tenant(tenant_id: str, email: str) -> None:
    # Await the Alembic migrations
    await asyncio.to_thread(run_alembic_migrations, tenant_id)

    with get_session_with_tenant(tenant_id) as db_session:
    with get_session_with_tenant(tenant_id=tenant_id) as db_session:
        configure_default_api_keys(db_session)

        current_search_settings = (
@@ -132,7 +134,7 @@ async def provision_tenant(tenant_id: str, email: str) -> None:

    add_users_to_tenant([email], tenant_id)

    with get_session_with_tenant(tenant_id) as db_session:
    with get_session_with_tenant(tenant_id=tenant_id) as db_session:
        create_milestone_and_report(
            user=None,
            distinct_id=tenant_id,
@@ -28,7 +28,7 @@ def get_tenant_id_for_email(email: str) -> str:


def user_owns_a_tenant(email: str) -> bool:
    with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
    with get_session_with_tenant(tenant_id=None) as db_session:
        result = (
            db_session.query(UserTenantMapping)
            .filter(UserTenantMapping.email == email)
@@ -38,7 +38,7 @@ def user_owns_a_tenant(email: str) -> bool:


def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
    with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
    with get_session_with_tenant(tenant_id=None) as db_session:
        try:
            for email in emails:
                db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id))
@@ -48,7 +48,7 @@ def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:


def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
    with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
    with get_session_with_tenant(tenant_id=None) as db_session:
        try:
            mappings_to_delete = (
                db_session.query(UserTenantMapping)
@@ -71,7 +71,7 @@ def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:


def remove_all_users_from_tenant(tenant_id: str) -> None:
    with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
    with get_session_with_tenant(tenant_id=None) as db_session:
        db_session.query(UserTenantMapping).filter(
            UserTenantMapping.tenant_id == tenant_id
        ).delete()
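The recurring edit in this file replaces the positional POSTGRES_DEFAULT_SCHEMA argument with an explicit tenant_id=None. A hedged sketch of why a keyword-only signature makes these call sites safer; the helper's real body is not in this diff, so the fallback shown is an assumption:

POSTGRES_DEFAULT_SCHEMA = "public"  # assumed stand-in for the real config value

def get_session_with_tenant_sketch(*, tenant_id: str | None = None) -> None:
    # Hypothetical shape: None falls back to the default schema, so callers
    # no longer pass a schema name positionally where a tenant id is expected.
    schema = tenant_id if tenant_id is not None else POSTGRES_DEFAULT_SCHEMA
    print(f"opening session against schema {schema!r}")

get_session_with_tenant_sketch(tenant_id=None)  # default schema, as in user_mapping
# get_session_with_tenant_sketch("tenant_abc")  -> TypeError: positional args rejected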
@@ -5,7 +5,7 @@ from fastapi import Depends
from sqlalchemy.orm import Session

from ee.onyx.db.token_limit import fetch_all_user_group_token_rate_limits_by_group
from ee.onyx.db.token_limit import fetch_user_group_token_rate_limits
from ee.onyx.db.token_limit import fetch_user_group_token_rate_limits_for_user
from ee.onyx.db.token_limit import insert_user_group_token_rate_limit
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_curator_or_admin_user
@@ -51,8 +51,10 @@ def get_group_token_limit_settings(
) -> list[TokenRateLimitDisplay]:
    return [
        TokenRateLimitDisplay.from_db(token_rate_limit)
        for token_rate_limit in fetch_user_group_token_rate_limits(
            db_session, group_id, user
        for token_rate_limit in fetch_user_group_token_rate_limits_for_user(
            db_session=db_session,
            group_id=group_id,
            user=user,
        )
    ]


@@ -58,6 +58,7 @@ class UserGroup(BaseModel):
            credential=CredentialSnapshot.from_credential_db_model(
                cc_pair_relationship.cc_pair.credential
            ),
            access_type=cc_pair_relationship.cc_pair.access_type,
        )
        for cc_pair_relationship in user_group_model.cc_pair_relationships
        if cc_pair_relationship.is_current
@@ -28,3 +28,9 @@ class EmbeddingModelTextType:
    @staticmethod
    def get_type(provider: EmbeddingProvider, text_type: EmbedTextType) -> str:
        return EmbeddingModelTextType.PROVIDER_TEXT_TYPE_MAP[provider][text_type]


class GPUStatus:
    CUDA = "cuda"
    MAC_MPS = "mps"
    NONE = "none"
@@ -12,6 +12,7 @@ import voyageai  # type: ignore
from cohere import AsyncClient as CohereAsyncClient
from fastapi import APIRouter
from fastapi import HTTPException
from fastapi import Request
from google.oauth2 import service_account  # type: ignore
from litellm import aembedding
from litellm.exceptions import RateLimitError
@@ -320,6 +321,7 @@ async def embed_text(
    prefix: str | None,
    api_url: str | None,
    api_version: str | None,
    gpu_type: str = "UNKNOWN",
) -> list[Embedding]:
    if not all(texts):
        logger.error("Empty strings provided for embedding")
@@ -373,8 +375,11 @@ async def embed_text(

        elapsed = time.monotonic() - start
        logger.info(
            f"Successfully embedded {len(texts)} texts with {total_chars} total characters "
            f"with provider {provider_type} in {elapsed:.2f}"
            f"event=embedding_provider "
            f"texts={len(texts)} "
            f"chars={total_chars} "
            f"provider={provider_type} "
            f"elapsed={elapsed:.2f}"
        )
    elif model_name is not None:
        logger.info(
@@ -403,6 +408,14 @@ async def embed_text(
            f"Successfully embedded {len(texts)} texts with {total_chars} total characters "
            f"with local model {model_name} in {elapsed:.2f}"
        )
        logger.info(
            f"event=embedding_model "
            f"texts={len(texts)} "
            f"chars={total_chars} "
            f"model={model_name} "
            f"gpu={gpu_type} "
            f"elapsed={elapsed:.2f}"
        )
    else:
        logger.error("Neither model name nor provider specified for embedding")
        raise ValueError(
@@ -455,8 +468,15 @@ async def litellm_rerank(


@router.post("/bi-encoder-embed")
async def process_embed_request(
async def route_bi_encoder_embed(
    request: Request,
    embed_request: EmbedRequest,
) -> EmbedResponse:
    return await process_embed_request(embed_request, request.app.state.gpu_type)


async def process_embed_request(
    embed_request: EmbedRequest, gpu_type: str = "UNKNOWN"
) -> EmbedResponse:
    if not embed_request.texts:
        raise HTTPException(status_code=400, detail="No texts to be embedded")
@@ -484,6 +504,7 @@ async def process_embed_request(
            api_url=embed_request.api_url,
            api_version=embed_request.api_version,
            prefix=prefix,
            gpu_type=gpu_type,
        )
        return EmbedResponse(embeddings=embeddings)
    except RateLimitError as e:
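The logging change above moves from free-form sentences to key=value ("logfmt"-style) event lines, which log pipelines can filter and aggregate without brittle regexes. A small sketch of consuming such a line (the parser is illustrative, not part of the repo):

line = "event=embedding_provider texts=8 chars=5120 provider=openai elapsed=0.42"
fields = dict(pair.split("=", 1) for pair in line.split())
assert fields["event"] == "embedding_provider"
print(float(fields["elapsed"]))  # 0.42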
@@ -16,6 +16,7 @@ from model_server.custom_models import router as custom_models_router
from model_server.custom_models import warm_up_intent_model
from model_server.encoders import router as encoders_router
from model_server.management_endpoints import router as management_router
from model_server.utils import get_gpu_type
from onyx import __version__
from onyx.utils.logger import setup_logger
from shared_configs.configs import INDEXING_ONLY
@@ -58,12 +59,10 @@ def _move_files_recursively(source: Path, dest: Path, overwrite: bool = False) -

@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator:
    if torch.cuda.is_available():
        logger.notice("CUDA GPU is available")
    elif torch.backends.mps.is_available():
        logger.notice("Mac MPS is available")
    else:
        logger.notice("GPU is not available, using CPU")
    gpu_type = get_gpu_type()
    logger.notice(f"Torch GPU Detection: gpu_type={gpu_type}")

    app.state.gpu_type = gpu_type

    if TEMP_HF_CACHE_PATH.is_dir():
        logger.notice("Moving contents of temp_huggingface to huggingface cache.")
@@ -1,7 +1,9 @@
import torch
from fastapi import APIRouter
from fastapi import Response

from model_server.constants import GPUStatus
from model_server.utils import get_gpu_type

router = APIRouter(prefix="/api")


@@ -11,10 +13,7 @@ async def healthcheck() -> Response:


@router.get("/gpu-status")
async def gpu_status() -> dict[str, bool | str]:
    if torch.cuda.is_available():
        return {"gpu_available": True, "type": "cuda"}
    elif torch.backends.mps.is_available():
        return {"gpu_available": True, "type": "mps"}
    else:
        return {"gpu_available": False, "type": "none"}
async def route_gpu_status() -> dict[str, bool | str]:
    gpu_type = get_gpu_type()
    gpu_available = gpu_type != GPUStatus.NONE
    return {"gpu_available": gpu_available, "type": gpu_type}
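For reference, hitting the rewritten endpoint yields the same response shape regardless of backend (host and port are assumptions for a local model server):

import requests

status = requests.get("http://localhost:9000/api/gpu-status").json()
print(status)  # e.g. {"gpu_available": True, "type": "cuda"} or {..., "type": "none"}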
@@ -8,6 +8,9 @@ from typing import Any
from typing import cast
from typing import TypeVar

import torch

from model_server.constants import GPUStatus
from onyx.utils.logger import setup_logger

logger = setup_logger()
@@ -58,3 +61,12 @@ def simple_log_function_time(
        return cast(F, wrapped_sync_func)

    return decorator


def get_gpu_type() -> str:
    if torch.cuda.is_available():
        return GPUStatus.CUDA
    if torch.backends.mps.is_available():
        return GPUStatus.MAC_MPS

    return GPUStatus.NONE
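Centralizing detection in get_gpu_type also makes the logic testable without real hardware, e.g. by patching the torch probes. A pytest-style sketch:

def test_get_gpu_type_falls_back_to_none(monkeypatch):
    import torch

    monkeypatch.setattr(torch.cuda, "is_available", lambda: False)
    monkeypatch.setattr(torch.backends.mps, "is_available", lambda: False)
    assert get_gpu_type() == GPUStatus.NONE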
@@ -19,6 +19,9 @@ def prefix_external_group(ext_group_name: str) -> str:
    return f"external_group:{ext_group_name}"


def prefix_group_w_source(ext_group_name: str, source: DocumentSource) -> str:
    """External groups may collide across sources, every source needs its own prefix."""
    return f"{source.value.upper()}_{ext_group_name}"
def build_ext_group_name_for_onyx(ext_group_name: str, source: DocumentSource) -> str:
    """
    External groups may collide across sources, every source needs its own prefix.
    NOTE: the name is lowercased to handle case sensitivity for group names
    """
    return f"{source.value}_{ext_group_name}".lower()
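Concretely, the new helper namespaces by source and also normalizes case, so case variants of one group now map to a single key. Assuming a DocumentSource member whose value is "confluence":

build_ext_group_name_for_onyx("Engineering", DocumentSource.CONFLUENCE)
# -> "confluence_engineering"
build_ext_group_name_for_onyx("ENGINEERING", DocumentSource.CONFLUENCE)
# -> "confluence_engineering"  (case variants collapse to one name)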
97
backend/onyx/agents/agent_search/basic/graph_builder.py
Normal file
@@ -0,0 +1,97 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph

from onyx.agents.agent_search.basic.states import BasicInput
from onyx.agents.agent_search.basic.states import BasicOutput
from onyx.agents.agent_search.basic.states import BasicState
from onyx.agents.agent_search.orchestration.nodes.call_tool import call_tool
from onyx.agents.agent_search.orchestration.nodes.choose_tool import choose_tool
from onyx.agents.agent_search.orchestration.nodes.prepare_tool_input import (
    prepare_tool_input,
)
from onyx.agents.agent_search.orchestration.nodes.use_tool_response import (
    basic_use_tool_response,
)
from onyx.utils.logger import setup_logger

logger = setup_logger()


def basic_graph_builder() -> StateGraph:
    graph = StateGraph(
        state_schema=BasicState,
        input=BasicInput,
        output=BasicOutput,
    )

    ### Add nodes ###

    graph.add_node(
        node="prepare_tool_input",
        action=prepare_tool_input,
    )

    graph.add_node(
        node="choose_tool",
        action=choose_tool,
    )

    graph.add_node(
        node="call_tool",
        action=call_tool,
    )

    graph.add_node(
        node="basic_use_tool_response",
        action=basic_use_tool_response,
    )

    ### Add edges ###

    graph.add_edge(start_key=START, end_key="prepare_tool_input")

    graph.add_edge(start_key="prepare_tool_input", end_key="choose_tool")

    graph.add_conditional_edges("choose_tool", should_continue, ["call_tool", END])

    graph.add_edge(
        start_key="call_tool",
        end_key="basic_use_tool_response",
    )

    graph.add_edge(
        start_key="basic_use_tool_response",
        end_key=END,
    )

    return graph


def should_continue(state: BasicState) -> str:
    return (
        # If there are no tool calls, basic graph already streamed the answer
        END
        if state.tool_choice is None
        else "call_tool"
    )


if __name__ == "__main__":
    from onyx.db.engine import get_session_context_manager
    from onyx.context.search.models import SearchRequest
    from onyx.llm.factory import get_default_llms
    from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config

    graph = basic_graph_builder()
    compiled_graph = graph.compile()
    input = BasicInput(unused=True)
    primary_llm, fast_llm = get_default_llms()
    with get_session_context_manager() as db_session:
        config, _ = get_test_config(
            db_session=db_session,
            primary_llm=primary_llm,
            fast_llm=fast_llm,
            search_request=SearchRequest(query="How does onyx use FastAPI?"),
        )
        compiled_graph.invoke(input, config={"metadata": {"config": config}})
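Read together, the edges give the compiled graph this control flow (a comment-style sketch, not additional code in the repo):

# START -> prepare_tool_input -> choose_tool
#   choose_tool --(should_continue == "call_tool")--> call_tool -> basic_use_tool_response -> END
#   choose_tool --(should_continue == END, no tool chosen)--> END  # answer already streamed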
35
backend/onyx/agents/agent_search/basic/states.py
Normal file
@@ -0,0 +1,35 @@
from typing import TypedDict

from langchain_core.messages import AIMessageChunk
from pydantic import BaseModel

from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate

# States contain values that change over the course of graph execution,
# Config is for values that are set at the start and never change.
# If you are using a value from the config and realize it needs to change,
# you should add it to the state and use/update the version in the state.


## Graph Input State
class BasicInput(BaseModel):
    # Langgraph needs a nonempty input, but we pass in all static
    # data through a RunnableConfig.
    unused: bool = True


## Graph Output State
class BasicOutput(TypedDict):
    tool_call_chunk: AIMessageChunk


## Graph State
class BasicState(
    BasicInput,
    ToolChoiceInput,
    ToolCallUpdate,
    ToolChoiceUpdate,
):
    pass
64
backend/onyx/agents/agent_search/basic/utils.py
Normal file
@@ -0,0 +1,64 @@
from collections.abc import Iterator
from typing import cast

from langchain_core.messages import AIMessageChunk
from langchain_core.messages import BaseMessage
from langgraph.types import StreamWriter

from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxContext
from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
from onyx.chat.stream_processing.answer_response_handler import CitationResponseHandler
from onyx.chat.stream_processing.answer_response_handler import (
    PassThroughAnswerResponseHandler,
)
from onyx.chat.stream_processing.utils import map_document_id_order
from onyx.utils.logger import setup_logger

logger = setup_logger()


def process_llm_stream(
    messages: Iterator[BaseMessage],
    should_stream_answer: bool,
    writer: StreamWriter,
    final_search_results: list[LlmDoc] | None = None,
    displayed_search_results: list[OnyxContext] | list[LlmDoc] | None = None,
) -> AIMessageChunk:
    tool_call_chunk = AIMessageChunk(content="")

    if final_search_results and displayed_search_results:
        answer_handler: AnswerResponseHandler = CitationResponseHandler(
            context_docs=final_search_results,
            final_doc_id_to_rank_map=map_document_id_order(final_search_results),
            display_doc_id_to_rank_map=map_document_id_order(displayed_search_results),
        )
    else:
        answer_handler = PassThroughAnswerResponseHandler()

    full_answer = ""
    # This stream will be the llm answer if no tool is chosen. When a tool is chosen,
    # the stream will contain AIMessageChunks with tool call information.
    for message in messages:
        answer_piece = message.content
        if not isinstance(answer_piece, str):
            # this is only used for logging, so fine to
            # just add the string representation
            answer_piece = str(answer_piece)
        full_answer += answer_piece

        if isinstance(message, AIMessageChunk) and (
            message.tool_call_chunks or message.tool_calls
        ):
            tool_call_chunk += message  # type: ignore
        elif should_stream_answer:
            for response_part in answer_handler.handle_response_part(message, []):
                write_custom_event(
                    "basic_response",
                    response_part,
                    writer,
                )

    logger.debug(f"Full answer: {full_answer}")
    return cast(AIMessageChunk, tool_call_chunk)
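The tool_call_chunk += message line relies on LangChain message chunks supporting addition: concatenating chunks merges their content and partial tool-call fragments into one chunk. A quick standalone illustration:

from langchain_core.messages import AIMessageChunk

merged = AIMessageChunk(content="Hel") + AIMessageChunk(content="lo")
print(merged.content)  # "Hello" -- tool_call_chunks are merged the same way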
20
backend/onyx/agents/agent_search/core_state.py
Normal file
@@ -0,0 +1,20 @@
from operator import add
from typing import Annotated

from pydantic import BaseModel


class CoreState(BaseModel):
    """
    This is the core state that is shared across all subgraphs.
    """

    log_messages: Annotated[list[str], add] = []


class SubgraphCoreState(BaseModel):
    """
    This is the core state that is shared across all subgraphs.
    """

    log_messages: Annotated[list[str], add] = []
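The Annotated[list[str], add] pattern is how LangGraph learns to merge this field across parallel branches: the second argument is a reducer, and operator.add on lists is plain concatenation. In effect:

from operator import add

# What LangGraph does when two parallel branches both return log_messages:
merged = add(["edge A fired"], ["edge B fired"])
print(merged)  # ["edge A fired", "edge B fired"]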
@@ -0,0 +1,31 @@
from collections.abc import Hashable
from datetime import datetime

from langgraph.types import Send

from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
    SubQuestionAnsweringInput,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
    ExpandedRetrievalInput,
)
from onyx.utils.logger import setup_logger

logger = setup_logger()


def send_to_expanded_retrieval(state: SubQuestionAnsweringInput) -> Send | Hashable:
    """
    LangGraph edge to send a sub-question to the expanded retrieval.
    """
    edge_start_time = datetime.now()

    return Send(
        "initial_sub_question_expanded_retrieval",
        ExpandedRetrievalInput(
            question=state.question,
            base_search=False,
            sub_question_id=state.question_id,
            log_messages=[f"{edge_start_time} -- Sending to expanded retrieval"],
        ),
    )
@@ -0,0 +1,137 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph

from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.edges import (
    send_to_expanded_retrieval,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.nodes.check_sub_answer import (
    check_sub_answer,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.nodes.format_sub_answer import (
    format_sub_answer,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.nodes.generate_sub_answer import (
    generate_sub_answer,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.nodes.ingest_retrieved_documents import (
    ingest_retrieved_documents,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
    AnswerQuestionOutput,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
    AnswerQuestionState,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
    SubQuestionAnsweringInput,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.graph_builder import (
    expanded_retrieval_graph_builder,
)
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
from onyx.utils.logger import setup_logger

logger = setup_logger()


def answer_query_graph_builder() -> StateGraph:
    """
    LangGraph sub-graph builder for the initial individual sub-answer generation.
    """
    graph = StateGraph(
        state_schema=AnswerQuestionState,
        input=SubQuestionAnsweringInput,
        output=AnswerQuestionOutput,
    )

    ### Add nodes ###

    # The sub-graph that executes the expanded retrieval process for a sub-question
    expanded_retrieval = expanded_retrieval_graph_builder().compile()
    graph.add_node(
        node="initial_sub_question_expanded_retrieval",
        action=expanded_retrieval,
    )

    # The node that ingests the retrieved documents and puts them into the proper
    # state keys.
    graph.add_node(
        node="ingest_retrieval",
        action=ingest_retrieved_documents,
    )

    # The node that generates the sub-answer
    graph.add_node(
        node="generate_sub_answer",
        action=generate_sub_answer,
    )

    # The node that checks the sub-answer
    graph.add_node(
        node="answer_check",
        action=check_sub_answer,
    )

    # The node that formats the sub-answer for the following initial answer generation
    graph.add_node(
        node="format_answer",
        action=format_sub_answer,
    )

    ### Add edges ###

    graph.add_conditional_edges(
        source=START,
        path=send_to_expanded_retrieval,
        path_map=["initial_sub_question_expanded_retrieval"],
    )
    graph.add_edge(
        start_key="initial_sub_question_expanded_retrieval",
        end_key="ingest_retrieval",
    )
    graph.add_edge(
        start_key="ingest_retrieval",
        end_key="generate_sub_answer",
    )
    graph.add_edge(
        start_key="generate_sub_answer",
        end_key="answer_check",
    )
    graph.add_edge(
        start_key="answer_check",
        end_key="format_answer",
    )
    graph.add_edge(
        start_key="format_answer",
        end_key=END,
    )

    return graph


if __name__ == "__main__":
    from onyx.db.engine import get_session_context_manager
    from onyx.llm.factory import get_default_llms
    from onyx.context.search.models import SearchRequest

    graph = answer_query_graph_builder()
    compiled_graph = graph.compile()
    primary_llm, fast_llm = get_default_llms()
    search_request = SearchRequest(
        query="what can you do with onyx or danswer?",
    )
    with get_session_context_manager() as db_session:
        graph_config, search_tool = get_test_config(
            db_session, primary_llm, fast_llm, search_request
        )
        inputs = SubQuestionAnsweringInput(
            question="what can you do with onyx?",
            question_id="0_0",
            log_messages=[],
        )
        for thing in compiled_graph.stream(
            input=inputs,
            config={"configurable": {"config": graph_config}},
        ):
            logger.debug(thing)
@@ -0,0 +1,134 @@
from datetime import datetime
from typing import cast

from langchain_core.messages import BaseMessage
from langchain_core.messages import HumanMessage
from langchain_core.runnables.config import RunnableConfig

from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
    AnswerQuestionState,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
    SubQuestionAnswerCheckUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
    binary_string_test,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
    AGENT_LLM_RATELIMIT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
    AGENT_LLM_TIMEOUT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
    AGENT_POSITIVE_VALUE_STR,
)
from onyx.agents.agent_search.shared_graph_utils.constants import AgentLLMErrorType
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLog
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.utils import (
    get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_SUBANSWER_CHECK
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import SUB_ANSWER_CHECK_PROMPT
from onyx.prompts.agent_search import UNKNOWN_ANSWER
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.utils.timing import log_function_time

logger = setup_logger()

_llm_node_error_strings = LLMNodeErrorStrings(
    timeout="LLM Timeout Error. The sub-answer will be treated as 'relevant'",
    rate_limit="LLM Rate Limit Error. The sub-answer will be treated as 'relevant'",
    general_error="General LLM Error. The sub-answer will be treated as 'relevant'",
)


@log_function_time(print_only=True)
def check_sub_answer(
    state: AnswerQuestionState, config: RunnableConfig
) -> SubQuestionAnswerCheckUpdate:
    """
    LangGraph node to check the quality of the sub-answer. The answer
    is represented as a boolean value.
    """
    node_start_time = datetime.now()

    level, question_num = parse_question_id(state.question_id)
    if state.answer == UNKNOWN_ANSWER:
        return SubQuestionAnswerCheckUpdate(
            answer_quality=False,
            log_messages=[
                get_langgraph_node_log_string(
                    graph_component="initial - generate individual sub answer",
                    node_name="check sub answer",
                    node_start_time=node_start_time,
                    result="unknown answer",
                )
            ],
        )
    msg = [
        HumanMessage(
            content=SUB_ANSWER_CHECK_PROMPT.format(
                question=state.question,
                base_answer=state.answer,
            )
        )
    ]

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    fast_llm = graph_config.tooling.fast_llm
    agent_error: AgentErrorLog | None = None
    response: BaseMessage | None = None
    try:
        response = run_with_timeout(
            AGENT_TIMEOUT_LLM_SUBANSWER_CHECK,
            fast_llm.invoke,
            prompt=msg,
            timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK,
        )

        quality_str: str = cast(str, response.content)
        answer_quality = binary_string_test(
            text=quality_str, positive_value=AGENT_POSITIVE_VALUE_STR
        )
        log_result = f"Answer quality: {quality_str}"

    except (LLMTimeoutError, TimeoutError):
        agent_error = AgentErrorLog(
            error_type=AgentLLMErrorType.TIMEOUT,
            error_message=AGENT_LLM_TIMEOUT_MESSAGE,
            error_result=_llm_node_error_strings.timeout,
        )
        answer_quality = True
        log_result = agent_error.error_result
        logger.error("LLM Timeout Error - check sub answer")

    except LLMRateLimitError:
        agent_error = AgentErrorLog(
            error_type=AgentLLMErrorType.RATE_LIMIT,
            error_message=AGENT_LLM_RATELIMIT_MESSAGE,
            error_result=_llm_node_error_strings.rate_limit,
        )

        answer_quality = True
        log_result = agent_error.error_result
        logger.error("LLM Rate Limit Error - check sub answer")

    return SubQuestionAnswerCheckUpdate(
        answer_quality=answer_quality,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="initial - generate individual sub answer",
                node_name="check sub answer",
                node_start_time=node_start_time,
                result=log_result,
            )
        ],
    )
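Two details worth noting: on timeout or rate limit the node fails open (answer_quality = True), and the pass/fail decision hinges on binary_string_test, which is imported rather than defined here. From its call site it reads like a containment check against the positive token; a hedged sketch of that assumed behavior (not the repo's implementation):

def binary_string_test_sketch(text: str, positive_value: str) -> bool:
    # Assumed behavior: pass if the LLM's (possibly noisy) reply contains
    # the positive token, case-insensitively.
    return positive_value.lower() in text.lower()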
@@ -0,0 +1,30 @@
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
    AnswerQuestionOutput,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
    AnswerQuestionState,
)
from onyx.agents.agent_search.shared_graph_utils.models import (
    SubQuestionAnswerResults,
)


def format_sub_answer(state: AnswerQuestionState) -> AnswerQuestionOutput:
    """
    LangGraph node to generate the sub-answer format.
    """
    return AnswerQuestionOutput(
        answer_results=[
            SubQuestionAnswerResults(
                question=state.question,
                question_id=state.question_id,
                verified_high_quality=state.answer_quality,
                answer=state.answer,
                sub_query_retrieval_results=state.expanded_retrieval_results,
                verified_reranked_documents=state.verified_reranked_documents,
                context_documents=state.context_documents,
                cited_documents=state.cited_documents,
                sub_question_retrieval_stats=state.sub_question_retrieval_stats,
            )
        ],
    )
@@ -0,0 +1,203 @@
from datetime import datetime
from typing import cast

from langchain_core.messages import merge_message_runs
from langchain_core.runnables.config import RunnableConfig
from langgraph.types import StreamWriter

from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
    AnswerQuestionState,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
    SubQuestionAnswerGenerationUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
    build_sub_question_answer_prompt,
)
from onyx.agents.agent_search.shared_graph_utils.calculations import (
    dedup_sort_inference_section_list,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
    AGENT_LLM_RATELIMIT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
    AGENT_LLM_TIMEOUT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
    AgentLLMErrorType,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
    LLM_ANSWER_ERROR_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLog
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.utils import get_answer_citation_ids
from onyx.agents.agent_search.shared_graph_utils.utils import (
    get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import (
    get_persona_agent_prompt_expressions,
)
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import NO_RECOVERED_DOCS
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.utils.timing import log_function_time

logger = setup_logger()

_llm_node_error_strings = LLMNodeErrorStrings(
    timeout="LLM Timeout Error. A sub-answer could not be constructed and the sub-question will be ignored.",
    rate_limit="LLM Rate Limit Error. A sub-answer could not be constructed and the sub-question will be ignored.",
    general_error="General LLM Error. A sub-answer could not be constructed and the sub-question will be ignored.",
)


@log_function_time(print_only=True)
def generate_sub_answer(
    state: AnswerQuestionState,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> SubQuestionAnswerGenerationUpdate:
    """
    LangGraph node to generate a sub-answer.
    """
    node_start_time = datetime.now()

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    question = state.question
    state.verified_reranked_documents
    level, question_num = parse_question_id(state.question_id)
    context_docs = state.context_documents[:AGENT_MAX_ANSWER_CONTEXT_DOCS]

    context_docs = dedup_sort_inference_section_list(context_docs)

    persona_contextualized_prompt = get_persona_agent_prompt_expressions(
        graph_config.inputs.search_request.persona
    ).contextualized_prompt

    if len(context_docs) == 0:
        answer_str = NO_RECOVERED_DOCS
        cited_documents: list = []
        log_results = "No documents retrieved"
        write_custom_event(
            "sub_answers",
            AgentAnswerPiece(
                answer_piece=answer_str,
                level=level,
                level_question_num=question_num,
                answer_type="agent_sub_answer",
            ),
            writer,
        )
    else:
        fast_llm = graph_config.tooling.fast_llm
        msg = build_sub_question_answer_prompt(
            question=question,
            original_question=graph_config.inputs.search_request.query,
            docs=context_docs,
            persona_specification=persona_contextualized_prompt,
            config=fast_llm.config,
        )

        dispatch_timings: list[float] = []
        agent_error: AgentErrorLog | None = None
        response: list[str] = []

        def stream_sub_answer() -> list[str]:
            for message in fast_llm.stream(
                prompt=msg,
                timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION,
            ):
                # TODO: in principle, the answer here COULD contain images, but we don't support that yet
                content = message.content
                if not isinstance(content, str):
                    raise ValueError(
                        f"Expected content to be a string, but got {type(content)}"
                    )
                start_stream_token = datetime.now()
                write_custom_event(
                    "sub_answers",
                    AgentAnswerPiece(
                        answer_piece=content,
                        level=level,
                        level_question_num=question_num,
                        answer_type="agent_sub_answer",
                    ),
                    writer,
                )
                end_stream_token = datetime.now()
                dispatch_timings.append(
                    (end_stream_token - start_stream_token).microseconds
                )
                response.append(content)
            return response

        try:
            response = run_with_timeout(
                AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION,
                stream_sub_answer,
            )

        except (LLMTimeoutError, TimeoutError):
            agent_error = AgentErrorLog(
                error_type=AgentLLMErrorType.TIMEOUT,
                error_message=AGENT_LLM_TIMEOUT_MESSAGE,
                error_result=_llm_node_error_strings.timeout,
            )
            logger.error("LLM Timeout Error - generate sub answer")
        except LLMRateLimitError:
            agent_error = AgentErrorLog(
                error_type=AgentLLMErrorType.RATE_LIMIT,
                error_message=AGENT_LLM_RATELIMIT_MESSAGE,
                error_result=_llm_node_error_strings.rate_limit,
            )
            logger.error("LLM Rate Limit Error - generate sub answer")

        if agent_error:
            answer_str = LLM_ANSWER_ERROR_MESSAGE
            cited_documents = []
            log_results = (
                agent_error.error_result
                or "Sub-answer generation failed due to LLM error"
            )

        else:
            answer_str = merge_message_runs(response, chunk_separator="")[0].content
            answer_citation_ids = get_answer_citation_ids(answer_str)
            cited_documents = [
                context_docs[id] for id in answer_citation_ids if id < len(context_docs)
            ]
            log_results = None

    stop_event = StreamStopInfo(
        stop_reason=StreamStopReason.FINISHED,
        stream_type=StreamType.SUB_ANSWER,
        level=level,
        level_question_num=question_num,
    )
    write_custom_event("stream_finished", stop_event, writer)

    return SubQuestionAnswerGenerationUpdate(
        answer=answer_str,
        cited_documents=cited_documents,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="initial - generate individual sub answer",
                node_name="generate sub answer",
                node_start_time=node_start_time,
                result=log_results or "",
            )
        ],
    )
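Both this node and check_sub_answer wrap their LLM call in run_with_timeout, so a hung provider surfaces as a TimeoutError the except clauses can catch. A rough sketch of the calling convention as inferred from these call sites (the real helper lives in onyx.utils.threadpool_concurrency):

# Inferred usage, matching the call sites above:
#   run_with_timeout(timeout_seconds, fn, *args, **kwargs) -> fn's return value
result = run_with_timeout(
    30,               # hard cap in seconds (the configs above supply real values)
    fast_llm.invoke,  # callable executed under the timeout
    prompt=msg,       # kwargs are forwarded to the callable
)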
@@ -0,0 +1,25 @@
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
    SubQuestionRetrievalIngestionUpdate,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
    ExpandedRetrievalOutput,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkRetrievalStats


def ingest_retrieved_documents(
    state: ExpandedRetrievalOutput,
) -> SubQuestionRetrievalIngestionUpdate:
    """
    LangGraph node to ingest the retrieved documents to format it for the sub-answer.
    """
    sub_question_retrieval_stats = state.expanded_retrieval_result.retrieval_stats
    if sub_question_retrieval_stats is None:
        sub_question_retrieval_stats = [AgentChunkRetrievalStats()]

    return SubQuestionRetrievalIngestionUpdate(
        expanded_retrieval_results=state.expanded_retrieval_result.expanded_query_results,
        verified_reranked_documents=state.expanded_retrieval_result.verified_reranked_documents,
        context_documents=state.expanded_retrieval_result.context_documents,
        sub_question_retrieval_stats=sub_question_retrieval_stats,
    )
@@ -0,0 +1,73 @@
from operator import add
from typing import Annotated

from pydantic import BaseModel

from onyx.agents.agent_search.core_state import SubgraphCoreState
from onyx.agents.agent_search.deep_search.main.states import LoggerUpdate
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkRetrievalStats
from onyx.agents.agent_search.shared_graph_utils.models import QueryRetrievalResult
from onyx.agents.agent_search.shared_graph_utils.models import (
    SubQuestionAnswerResults,
)
from onyx.agents.agent_search.shared_graph_utils.operators import (
    dedup_inference_sections,
)
from onyx.context.search.models import InferenceSection


## Update States
class SubQuestionAnswerCheckUpdate(LoggerUpdate, BaseModel):
    answer_quality: bool = False
    log_messages: list[str] = []


class SubQuestionAnswerGenerationUpdate(LoggerUpdate, BaseModel):
    answer: str = ""
    log_messages: list[str] = []
    cited_documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
    # answer_stat: AnswerStats


class SubQuestionRetrievalIngestionUpdate(LoggerUpdate, BaseModel):
    expanded_retrieval_results: list[QueryRetrievalResult] = []
    verified_reranked_documents: Annotated[
        list[InferenceSection], dedup_inference_sections
    ] = []
    context_documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
    sub_question_retrieval_stats: AgentChunkRetrievalStats = AgentChunkRetrievalStats()


## Graph Input State


class SubQuestionAnsweringInput(SubgraphCoreState):
    question: str
    question_id: str
    # level 0 is original question and first decomposition, level 1 is follow up, etc
    # question_num is a unique number per original question per level.


## Graph State


class AnswerQuestionState(
    SubQuestionAnsweringInput,
    SubQuestionAnswerGenerationUpdate,
    SubQuestionAnswerCheckUpdate,
    SubQuestionRetrievalIngestionUpdate,
):
    pass


## Graph Output State


class AnswerQuestionOutput(LoggerUpdate, BaseModel):
    """
    This is a list of results even though each call of this subgraph only returns one result.
    This is because if we parallelize the answer query subgraph, there will be multiple
    results in a list so the add operator is used to add them together.
    """

    answer_results: Annotated[list[SubQuestionAnswerResults], add] = []
@@ -0,0 +1,50 @@
from collections.abc import Hashable
from datetime import datetime

from langgraph.types import Send

from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
    AnswerQuestionOutput,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
    SubQuestionAnsweringInput,
)
from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.states import (
    SubQuestionRetrievalState,
)
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id


def parallelize_initial_sub_question_answering(
    state: SubQuestionRetrievalState,
) -> list[Send | Hashable]:
    """
    LangGraph edge to parallelize the initial sub-question answering. If there are no sub-questions,
    we send empty answers to the initial answer generation, and that answer would be generated
    solely based on the documents retrieved for the original question.
    """
    edge_start_time = datetime.now()
    if len(state.initial_sub_questions) > 0:
        return [
            Send(
                "answer_query_subgraph",
                SubQuestionAnsweringInput(
                    question=question,
                    question_id=make_question_id(0, question_num + 1),
                    log_messages=[
                        f"{edge_start_time} -- Main Edge - Parallelize Initial Sub-question Answering"
                    ],
                ),
            )
            for question_num, question in enumerate(state.initial_sub_questions)
        ]

    else:
        return [
            Send(
                "ingest_answers",
                AnswerQuestionOutput(
                    answer_results=[],
                ),
            )
        ]
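Returning a list of Send objects from a conditional edge is LangGraph's fan-out mechanism: each Send spawns one run of the target subgraph with its own private input, so N sub-questions are answered in parallel. Schematically:

# Three sub-questions -> three parallel "answer_query_subgraph" runs:
# [Send("answer_query_subgraph", input_for_q1),
#  Send("answer_query_subgraph", input_for_q2),
#  Send("answer_query_subgraph", input_for_q3)]
# Their AnswerQuestionOutput.answer_results lists are then merged by the
# Annotated[..., add] reducer shown in states.py above.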
@@ -0,0 +1,96 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph

from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.nodes.generate_initial_answer import (
    generate_initial_answer,
)
from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.nodes.validate_initial_answer import (
    validate_initial_answer,
)
from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.states import (
    SubQuestionRetrievalInput,
)
from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.states import (
    SubQuestionRetrievalState,
)
from onyx.agents.agent_search.deep_search.initial.generate_sub_answers.graph_builder import (
    generate_sub_answers_graph_builder,
)
from onyx.agents.agent_search.deep_search.initial.retrieve_orig_question_docs.graph_builder import (
    retrieve_orig_question_docs_graph_builder,
)
from onyx.utils.logger import setup_logger

logger = setup_logger()


def generate_initial_answer_graph_builder(test_mode: bool = False) -> StateGraph:
    """
    LangGraph graph builder for the initial answer generation.
    """
    graph = StateGraph(
        state_schema=SubQuestionRetrievalState,
        input=SubQuestionRetrievalInput,
    )

    # The sub-graph that generates the initial sub-answers
    generate_sub_answers = generate_sub_answers_graph_builder().compile()
    graph.add_node(
        node="generate_sub_answers_subgraph",
        action=generate_sub_answers,
    )

    # The sub-graph that retrieves the original question documents. This is run
    # in parallel with the sub-answer generation process
    retrieve_orig_question_docs = retrieve_orig_question_docs_graph_builder().compile()
    graph.add_node(
        node="retrieve_orig_question_docs_subgraph_wrapper",
        action=retrieve_orig_question_docs,
    )

    # Node that generates the initial answer using the results of the previous
    # two sub-graphs
    graph.add_node(
        node="generate_initial_answer",
        action=generate_initial_answer,
    )

    # Node that validates the initial answer
    graph.add_node(
        node="validate_initial_answer",
        action=validate_initial_answer,
    )

    ### Add edges ###

    graph.add_edge(
        start_key=START,
        end_key="retrieve_orig_question_docs_subgraph_wrapper",
    )

    graph.add_edge(
        start_key=START,
        end_key="generate_sub_answers_subgraph",
    )

    # Wait for both, the original question docs and the sub-answers to be generated before proceeding
    graph.add_edge(
        start_key=[
            "retrieve_orig_question_docs_subgraph_wrapper",
            "generate_sub_answers_subgraph",
        ],
        end_key="generate_initial_answer",
    )

    graph.add_edge(
        start_key="generate_initial_answer",
        end_key="validate_initial_answer",
    )

    graph.add_edge(
        start_key="validate_initial_answer",
        end_key=END,
    )

    return graph
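Note the fork/join shape: two edges out of START run both subgraphs concurrently, and an add_edge whose start_key is a list acts as a barrier, so generate_initial_answer fires only after both parents finish. In miniature:

# START -> retrieve_orig_question_docs_subgraph_wrapper --+
# START -> generate_sub_answers_subgraph -----------------+ (join)
#          +-> generate_initial_answer -> validate_initial_answer -> END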
@@ -0,0 +1,419 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import merge_content
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.states import (
|
||||
SubQuestionRetrievalState,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.models import AgentBaseMetrics
|
||||
from onyx.agents.agent_search.deep_search.main.operations import (
    calculate_initial_agent_stats,
)
from onyx.agents.agent_search.deep_search.main.operations import get_query_info
from onyx.agents.agent_search.deep_search.main.operations import logger
from onyx.agents.agent_search.deep_search.main.states import (
    InitialAnswerUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
    get_prompt_enrichment_components,
)
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
    trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.calculations import (
    get_answer_generation_documents,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
    AGENT_LLM_RATELIMIT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
    AGENT_LLM_TIMEOUT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
    AgentLLMErrorType,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLog
from onyx.agents.agent_search.shared_graph_utils.models import InitialAgentResultStats
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.operators import (
    dedup_inference_section_list,
)
from onyx.agents.agent_search.shared_graph_utils.utils import (
    dispatch_main_answer_stop_info,
)
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
from onyx.agents.agent_search.shared_graph_utils.utils import (
    get_deduplicated_structured_subquestion_documents,
)
from onyx.agents.agent_search.shared_graph_utils.utils import (
    get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import relevance_from_docs
from onyx.agents.agent_search.shared_graph_utils.utils import remove_document_citations
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import ExtendedToolResponse
from onyx.chat.models import StreamingError
from onyx.configs.agent_configs import AGENT_ANSWER_GENERATION_BY_FAST_LLM
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
from onyx.configs.agent_configs import AGENT_MAX_STREAMED_DOCS_FOR_INITIAL_ANSWER
from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
from onyx.configs.agent_configs import (
    AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
)
from onyx.configs.agent_configs import (
    AGENT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION,
)
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import INITIAL_ANSWER_PROMPT_W_SUB_QUESTIONS
from onyx.prompts.agent_search import (
    INITIAL_ANSWER_PROMPT_WO_SUB_QUESTIONS,
)
from onyx.prompts.agent_search import (
    SUB_QUESTION_ANSWER_TEMPLATE,
)
from onyx.prompts.agent_search import UNKNOWN_ANSWER
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.utils.timing import log_function_time

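# Canned user-facing messages for this node's LLM failure modes.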
_llm_node_error_strings = LLMNodeErrorStrings(
    timeout="LLM Timeout Error. The initial answer could not be generated.",
    rate_limit="LLM Rate Limit Error. The initial answer could not be generated.",
    general_error="General LLM Error. The initial answer could not be generated.",
)


@log_function_time(print_only=True)
def generate_initial_answer(
    state: SubQuestionRetrievalState,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> InitialAnswerUpdate:
    """
    LangGraph node to generate the initial answer, using the initial
    sub-questions/sub-answers and the documents retrieved for the original question.
    """
    node_start_time = datetime.now()

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    question = graph_config.inputs.search_request.query
    prompt_enrichment_components = get_prompt_enrichment_components(graph_config)

    # Get all documents cited in sub-questions
    structured_subquestion_docs = get_deduplicated_structured_subquestion_documents(
        state.sub_question_results
    )

    orig_question_retrieval_documents = state.orig_question_retrieved_documents

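    # Top up the cited sub-question documents with documents retrieved for the
    # original question, respecting the configured min/max context sizes.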
    consolidated_context_docs = structured_subquestion_docs.cited_documents
    counter = 0
    for original_doc in orig_question_retrieval_documents:
        # Skip documents that a sub-question already cited. (The membership
        # test must compare the document itself, not its enumeration index.)
        if original_doc not in structured_subquestion_docs.cited_documents:
            if (
                counter <= AGENT_MIN_ORIG_QUESTION_DOCS
                or len(consolidated_context_docs) < AGENT_MAX_ANSWER_CONTEXT_DOCS
            ):
                consolidated_context_docs.append(original_doc)
                counter += 1

    # Dedup the consolidated docs. Note that their scores stem from different
    # questions, so they are not directly comparable.
    relevant_docs = dedup_inference_section_list(consolidated_context_docs)

    sub_questions: list[str] = []

    # Create the list of documents to stream out. Start with the ones that
    # will be in the context (or, if there are none, fall back to the docs
    # that were retrieved for the original question).
    answer_generation_documents = get_answer_generation_documents(
        relevant_docs=relevant_docs,
        context_documents=structured_subquestion_docs.context_documents,
        original_question_docs=orig_question_retrieval_documents,
        max_docs=AGENT_MAX_STREAMED_DOCS_FOR_INITIAL_ANSWER,
    )

    # Use the query info from the base document retrieval
    query_info = get_query_info(state.orig_question_sub_query_retrieval_results)

    assert (
        graph_config.tooling.search_tool
    ), "search_tool must be provided for agentic search"

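    # Re-emit the retrieved documents as search-tool responses so consumers of
    # the event stream (e.g. the UI) can render them alongside the answer.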
    relevance_list = relevance_from_docs(
        answer_generation_documents.streaming_documents
    )
    for tool_response in yield_search_responses(
        query=question,
        reranked_sections=answer_generation_documents.streaming_documents,
        final_context_sections=answer_generation_documents.context_documents,
        search_query_info=query_info,
        get_section_relevance=lambda: relevance_list,
        search_tool=graph_config.tooling.search_tool,
    ):
        write_custom_event(
            "tool_response",
            ExtendedToolResponse(
                id=tool_response.id,
                response=tool_response.response,
                level=0,
                level_question_num=0,  # 0, 0 is the base question
            ),
            writer,
        )

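    # Without any usable context documents, skip the LLM call entirely and
    # stream the canned "unknown" answer instead.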
    if len(answer_generation_documents.context_documents) == 0:
        write_custom_event(
            "initial_agent_answer",
            AgentAnswerPiece(
                answer_piece=UNKNOWN_ANSWER,
                level=0,
                level_question_num=0,
                answer_type="agent_level_answer",
            ),
            writer,
        )
        dispatch_main_answer_stop_info(0, writer)

        answer = UNKNOWN_ANSWER
        initial_agent_stats = InitialAgentResultStats(
            sub_questions={},
            original_question={},
            agent_effectiveness={},
        )

    else:
        sub_question_answer_results = state.sub_question_results

        # Collect the sub-questions and sub-answers and construct an
        # appropriate prompt string. (Consider replacing this block with a
        # helper function.)
        answered_sub_questions: list[str] = []
        all_sub_questions: list[str] = []  # separate list for tracking all questions

        for idx, sub_question_answer_result in enumerate(
            sub_question_answer_results, start=1
        ):
            all_sub_questions.append(sub_question_answer_result.question)

            is_valid_answer = (
                sub_question_answer_result.verified_high_quality
                and sub_question_answer_result.answer
                and sub_question_answer_result.answer != UNKNOWN_ANSWER
            )

            if is_valid_answer:
                answered_sub_questions.append(
                    SUB_QUESTION_ANSWER_TEMPLATE.format(
                        sub_question=sub_question_answer_result.question,
                        sub_answer=sub_question_answer_result.answer,
                        sub_question_num=idx,
                    )
                )

        sub_question_answer_str = (
            "\n\n------\n\n".join(answered_sub_questions)
            if answered_sub_questions
            else ""
        )

        # Use the appropriate prompt depending on whether any sub-questions
        # were actually answered.
        base_prompt = (
            INITIAL_ANSWER_PROMPT_W_SUB_QUESTIONS
            if answered_sub_questions
            else INITIAL_ANSWER_PROMPT_WO_SUB_QUESTIONS
        )

        sub_questions = all_sub_questions  # overwrite the empty default from above

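        # A config flag decides whether the fast or the primary LLM generates
        # the initial answer.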
        model = (
            graph_config.tooling.fast_llm
            if AGENT_ANSWER_GENERATION_BY_FAST_LLM
            else graph_config.tooling.primary_llm
        )

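        # Trim the formatted document context so that, together with the other
        # prompt pieces reserved below, the prompt fits the model's limits.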
        doc_context = format_docs(answer_generation_documents.context_documents)
        doc_context = trim_prompt_piece(
            config=model.config,
            prompt_piece=doc_context,
            reserved_str=(
                base_prompt
                + sub_question_answer_str
                + prompt_enrichment_components.persona_prompts.contextualized_prompt
                + prompt_enrichment_components.history
                + prompt_enrichment_components.date_str
            ),
        )

        msg = [
            HumanMessage(
                content=base_prompt.format(
                    question=question,
                    answered_sub_questions=remove_document_citations(
                        sub_question_answer_str
                    ),
                    relevant_docs=doc_context,
                    persona_specification=prompt_enrichment_components.persona_prompts.contextualized_prompt,
                    history=prompt_enrichment_components.history,
                    date_prompt=prompt_enrichment_components.date_str,
                )
            )
        ]

        streamed_tokens: list[str] = [""]
        dispatch_timings: list[float] = []

        agent_error: AgentErrorLog | None = None

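        # Streaming happens inside a helper so the entire generation can be
        # bounded by run_with_timeout below; per-token dispatch latency is
        # recorded for the debug log further down.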
        def stream_initial_answer() -> list[str]:
            response: list[str] = []
            for message in model.stream(
                msg,
                timeout_override=AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
            ):
                # TODO: in principle, the answer here COULD contain images,
                # but we don't support that yet
                content = message.content
                if not isinstance(content, str):
                    raise ValueError(
                        f"Expected content to be a string, but got {type(content)}"
                    )
                start_stream_token = datetime.now()

                write_custom_event(
                    "initial_agent_answer",
                    AgentAnswerPiece(
                        answer_piece=content,
                        level=0,
                        level_question_num=0,
                        answer_type="agent_level_answer",
                    ),
                    writer,
                )
                end_stream_token = datetime.now()
                dispatch_timings.append(
                    (end_stream_token - start_stream_token).microseconds
                )
                response.append(content)
            return response

        try:
            streamed_tokens = run_with_timeout(
                AGENT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION,
                stream_initial_answer,
            )

        except (LLMTimeoutError, TimeoutError):
            agent_error = AgentErrorLog(
                error_type=AgentLLMErrorType.TIMEOUT,
                error_message=AGENT_LLM_TIMEOUT_MESSAGE,
                error_result=_llm_node_error_strings.timeout,
            )
            logger.error("LLM Timeout Error - generate initial answer")

        except LLMRateLimitError:
            agent_error = AgentErrorLog(
                error_type=AgentLLMErrorType.RATE_LIMIT,
                error_message=AGENT_LLM_RATELIMIT_MESSAGE,
                error_result=_llm_node_error_strings.rate_limit,
            )
            logger.error("LLM Rate Limit Error - generate initial answer")

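        # On LLM failure, surface a streaming error to the client and return
        # an error update instead of an answer.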
        if agent_error:
            write_custom_event(
                "initial_agent_answer",
                StreamingError(
                    # report the error that actually occurred, not
                    # unconditionally the timeout message
                    error=agent_error.error_message or AGENT_LLM_TIMEOUT_MESSAGE,
                ),
                writer,
            )
            return InitialAnswerUpdate(
                initial_answer=None,
                answer_error=AgentErrorLog(
                    error_message=agent_error.error_message or "An LLM error occurred",
                    error_type=agent_error.error_type,
                    error_result=agent_error.error_result,
                ),
                initial_agent_stats=None,
                generated_sub_questions=sub_questions,
                agent_base_end_time=None,
                agent_base_metrics=None,
                log_messages=[
                    get_langgraph_node_log_string(
                        graph_component="initial - generate initial answer",
                        node_name="generate initial answer",
                        node_start_time=node_start_time,
                        result=agent_error.error_result or "An LLM error occurred",
                    )
                ],
            )

        logger.debug(
            "Average dispatch time for initial answer: "
            # guard against an empty stream to avoid division by zero
            f"{sum(dispatch_timings) / max(len(dispatch_timings), 1)}"
        )

        dispatch_main_answer_stop_info(0, writer)
        response = merge_content(*streamed_tokens)
        answer = cast(str, response)

        initial_agent_stats = calculate_initial_agent_stats(
            state.sub_question_results, state.orig_question_retrieval_stats
        )

        logger.debug(f"Sub-questions:\n\n{sub_question_answer_str}\n\nStats:")

    if initial_agent_stats:
        logger.debug(initial_agent_stats.original_question)
        logger.debug(initial_agent_stats.sub_questions)
        logger.debug(initial_agent_stats.agent_effectiveness)

    agent_base_end_time = datetime.now()

    if agent_base_end_time and state.agent_start_time:
        duration_s = (agent_base_end_time - state.agent_start_time).total_seconds()
    else:
        duration_s = None

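    # Telemetry for the base answer path: verified-document counts plus the
    # effectiveness ratios computed above.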
    agent_base_metrics = AgentBaseMetrics(
        num_verified_documents_total=len(relevant_docs),
        num_verified_documents_core=state.orig_question_retrieval_stats.verified_count,
        verified_avg_score_core=state.orig_question_retrieval_stats.verified_avg_scores,
        num_verified_documents_base=initial_agent_stats.sub_questions.get(
            "num_verified_documents"
        ),
        verified_avg_score_base=initial_agent_stats.sub_questions.get(
            "verified_avg_score"
        ),
        base_doc_boost_factor=initial_agent_stats.agent_effectiveness.get(
            "utilized_chunk_ratio"
        ),
        support_boost_factor=initial_agent_stats.agent_effectiveness.get(
            "support_ratio"
        ),
        duration_s=duration_s,
    )

    return InitialAnswerUpdate(
        initial_answer=answer,
        initial_agent_stats=initial_agent_stats,
        generated_sub_questions=sub_questions,
        agent_base_end_time=agent_base_end_time,
        agent_base_metrics=agent_base_metrics,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="initial - generate initial answer",
                node_name="generate initial answer",
                node_start_time=node_start_time,
                result="",
            )
        ],
    )

@@ -0,0 +1,42 @@
from datetime import datetime

from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.states import (
    SubQuestionRetrievalState,
)
from onyx.agents.agent_search.deep_search.main.operations import logger
from onyx.agents.agent_search.deep_search.main.states import (
    InitialAnswerQualityUpdate,
)
from onyx.agents.agent_search.shared_graph_utils.utils import (
    get_langgraph_node_log_string,
)
from onyx.utils.timing import log_function_time


@log_function_time(print_only=True)
def validate_initial_answer(
    state: SubQuestionRetrievalState,
) -> InitialAnswerQualityUpdate:
    """
    Check whether the initial answer sufficiently addresses the original user question.
    """

    node_start_time = datetime.now()

    logger.debug(
        f"--------{node_start_time}--------Validating the base answer "
        "(currently hard-coded to True rather than evaluated)"
    )

    # No real validation is needed here: the answer has already been streamed
    # out, and the refinement step performs a similar check.
    verdict = True

    return InitialAnswerQualityUpdate(
        initial_answer_quality_eval=verdict,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="initial - generate initial answer",
                node_name="validate initial answer",
                node_start_time=node_start_time,
                result="",
            )
        ],
    )
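
For orientation, here is a minimal sketch of how the two nodes above could be wired together as a LangGraph subgraph. The builder function, node names, and edge layout are illustrative assumptions, not the graph definition from this diff; imports of the two node functions are elided because their exact module paths are not shown here.

# Sketch only: hypothetical wiring of the two nodes defined above.
from langgraph.graph import END, START, StateGraph

from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.states import (
    SubQuestionRetrievalState,
)


def build_initial_answer_subgraph():
    # Generate the initial answer first, then run the (currently
    # pass-through) validation node.
    graph = StateGraph(SubQuestionRetrievalState)
    graph.add_node("generate_initial_answer", generate_initial_answer)
    graph.add_node("validate_initial_answer", validate_initial_answer)
    graph.add_edge(START, "generate_initial_answer")
    graph.add_edge("generate_initial_answer", "validate_initial_answer")
    graph.add_edge("validate_initial_answer", END)
    return graph.compile()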
Some files were not shown because too many files have changed in this diff.