forked from github/onyx
Compare commits
453 Commits
v0.24.0-cl
...
nightly-la
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c3702b76b6 | ||
|
|
bb239d574c | ||
|
|
172e5f0e24 | ||
|
|
26b026fb88 | ||
|
|
870629e8a9 | ||
|
|
a547112321 | ||
|
|
da5a94815e | ||
|
|
e024472b74 | ||
|
|
e74855e633 | ||
|
|
e4c26a933d | ||
|
|
36c96f2d98 | ||
|
|
1ea94dcd8d | ||
|
|
2b1c5a0755 | ||
|
|
82b5f806ab | ||
|
|
6340c517d1 | ||
|
|
3baae2d4f0 | ||
|
|
d7c223ddd4 | ||
|
|
df4917243b | ||
|
|
a79ab713ce | ||
|
|
d1f7cee959 | ||
|
|
a3f41e20da | ||
|
|
458ed93da0 | ||
|
|
273d073bd7 | ||
|
|
9455c8e5ae | ||
|
|
d45d4389a0 | ||
|
|
bd901c0da1 | ||
|
|
2192605c95 | ||
|
|
d248d2f4e9 | ||
|
|
331c53871a | ||
|
|
f62d0d9144 | ||
|
|
427945e757 | ||
|
|
e55cdc6250 | ||
|
|
6a01db9ff2 | ||
|
|
82e9df5c22 | ||
|
|
16c2ef2852 | ||
|
|
224a70eea9 | ||
|
|
c457982120 | ||
|
|
0649748da2 | ||
|
|
ddceddaa28 | ||
|
|
c6733a5026 | ||
|
|
7db744a5de | ||
|
|
cd2a8b0def | ||
|
|
f15bc26cd6 | ||
|
|
65f35f0293 | ||
|
|
4e3e608249 | ||
|
|
719a092a12 | ||
|
|
6a8fde7eb1 | ||
|
|
4fdd0812a0 | ||
|
|
4913dc1e85 | ||
|
|
4a43a9642e | ||
|
|
cc48a0c38e | ||
|
|
01ccfd2df7 | ||
|
|
36d75786ee | ||
|
|
f9bc38ba65 | ||
|
|
3da283221d | ||
|
|
90568d3bbb | ||
|
|
7955ca938c | ||
|
|
f5d357eb28 | ||
|
|
d83f616214 | ||
|
|
275c1bec3d | ||
|
|
7d1ef912e8 | ||
|
|
2fe1d4c373 | ||
|
|
2396ad309e | ||
|
|
0b13ef963a | ||
|
|
83073f3ded | ||
|
|
439a27a775 | ||
|
|
91773a4789 | ||
|
|
185beca648 | ||
|
|
2dc564c8df | ||
|
|
b259f53972 | ||
|
|
f8beb08e2f | ||
|
|
83c88c7cf6 | ||
|
|
2372dd40e0 | ||
|
|
5cb6bafe81 | ||
|
|
a0309b31c7 | ||
|
|
0fd268dba7 | ||
|
|
f345da7487 | ||
|
|
f2dacf03f1 | ||
|
|
e0fef50cf0 | ||
|
|
6ba3eeefa5 | ||
|
|
aa158abaa9 | ||
|
|
255c2af1d6 | ||
|
|
9ece3b0310 | ||
|
|
9e3aca03a7 | ||
|
|
dbd5d4d8f1 | ||
|
|
cdb97c3ce4 | ||
|
|
f30ced31a9 | ||
|
|
6cc6c43234 | ||
|
|
224d934cf4 | ||
|
|
8ecdc61ad3 | ||
|
|
08161db7ea | ||
|
|
b139764631 | ||
|
|
2b23dbde8d | ||
|
|
2dec009d63 | ||
|
|
91eadae353 | ||
|
|
8bff616e27 | ||
|
|
2c049e170f | ||
|
|
23e6d7ef3c | ||
|
|
ed81e75edd | ||
|
|
de22fc3a58 | ||
|
|
009b7f60f1 | ||
|
|
9d997e20df | ||
|
|
e6423c4541 | ||
|
|
cb969ad06a | ||
|
|
c4076d16b6 | ||
|
|
04a607a718 | ||
|
|
c1e1aa9dfd | ||
|
|
1ed7abae6e | ||
|
|
cf4855822b | ||
|
|
e242b1319c | ||
|
|
eba4b6620e | ||
|
|
3534515e11 | ||
|
|
5602ff8666 | ||
|
|
2fc70781b4 | ||
|
|
f76b4dec4c | ||
|
|
a5a516fa8a | ||
|
|
811a198134 | ||
|
|
5867ab1d7d | ||
|
|
dd6653eb1f | ||
|
|
db457ef432 | ||
|
|
de7fe939b2 | ||
|
|
38114d9542 | ||
|
|
32f20f2e2e | ||
|
|
3dd27099f7 | ||
|
|
91c4d43a80 | ||
|
|
a63ba1bb03 | ||
|
|
7b6189e74c | ||
|
|
ba423e5773 | ||
|
|
fe029eccae | ||
|
|
ea72af7698 | ||
|
|
17abf85533 | ||
|
|
3bd162acb9 | ||
|
|
664ce441eb | ||
|
|
6863fbee54 | ||
|
|
bb98088b80 | ||
|
|
ce8cb1112a | ||
|
|
a605bd4ca4 | ||
|
|
0e8b5af619 | ||
|
|
46f3af4f68 | ||
|
|
2af64ebf4c | ||
|
|
0eb1824158 | ||
|
|
e0a9a6fb66 | ||
|
|
fe194076c2 | ||
|
|
55dc24fd27 | ||
|
|
da02962a67 | ||
|
|
9bc62cc803 | ||
|
|
bf6705a9a5 | ||
|
|
df2fef3383 | ||
|
|
8cec3448d7 | ||
|
|
b81687995e | ||
|
|
87c2253451 | ||
|
|
297c2957b4 | ||
|
|
bacee0d09d | ||
|
|
297720c132 | ||
|
|
bd4bd00cef | ||
|
|
07c482f727 | ||
|
|
cf193dee29 | ||
|
|
1b47fa2700 | ||
|
|
e1a305d18a | ||
|
|
e2233d22c9 | ||
|
|
20d1175312 | ||
|
|
7117774287 | ||
|
|
77f2660bb2 | ||
|
|
1b2f4f3b87 | ||
|
|
d85b55a9d2 | ||
|
|
e2bae5a2d9 | ||
|
|
cc9c76c4fb | ||
|
|
258e08abcd | ||
|
|
67047e42a7 | ||
|
|
146628e734 | ||
|
|
c1d4b08132 | ||
|
|
f3f47d0709 | ||
|
|
fe26a1bfcc | ||
|
|
554cd0f891 | ||
|
|
f87d3e9849 | ||
|
|
72cdada893 | ||
|
|
c442ebaff6 | ||
|
|
56f16d107e | ||
|
|
0157ae099a | ||
|
|
565fb42457 | ||
|
|
a50a8b4a12 | ||
|
|
4baf4e7d96 | ||
|
|
8b7ab2eb66 | ||
|
|
1f75f3633e | ||
|
|
650884d76a | ||
|
|
8722bdb414 | ||
|
|
71037678c3 | ||
|
|
68de1015e1 | ||
|
|
e2b3a6e144 | ||
|
|
4f04b09efa | ||
|
|
5c4f44d258 | ||
|
|
19652ad60e | ||
|
|
70c96b6ab3 | ||
|
|
65076b916f | ||
|
|
06bc0e51db | ||
|
|
508b456b40 | ||
|
|
bf1e2a2661 | ||
|
|
991d5e4203 | ||
|
|
d21f012b04 | ||
|
|
86b7beab01 | ||
|
|
b4eaa81d8b | ||
|
|
ff2a4c8723 | ||
|
|
51027fd259 | ||
|
|
7e3fd2b12a | ||
|
|
d2fef6f0b7 | ||
|
|
bd06147d26 | ||
|
|
1f3cc9ed6e | ||
|
|
6086d9e51a | ||
|
|
e0de24f64e | ||
|
|
08b6b1f8b3 | ||
|
|
afed1a4b37 | ||
|
|
bca18cacdf | ||
|
|
335db91803 | ||
|
|
67c488ff1f | ||
|
|
deb7f13962 | ||
|
|
e2d3d65c60 | ||
|
|
b78a6834f5 | ||
|
|
4abe90aa2c | ||
|
|
de9568844b | ||
|
|
34268f9806 | ||
|
|
ed75678837 | ||
|
|
3bb58a3dd3 | ||
|
|
4b02feef31 | ||
|
|
6a4d49f02e | ||
|
|
d1736187d3 | ||
|
|
0e79b96091 | ||
|
|
ae302d473d | ||
|
|
feca4fda78 | ||
|
|
f7ed7cd3cd | ||
|
|
8377ab3ef2 | ||
|
|
95c23bf870 | ||
|
|
e49fb8f56d | ||
|
|
adf48de652 | ||
|
|
bca2500438 | ||
|
|
89f925662f | ||
|
|
b64c6d5d40 | ||
|
|
36c63950a6 | ||
|
|
3f31340e6f | ||
|
|
6ac2258c2e | ||
|
|
b4d3b43e8a | ||
|
|
ca281b71e3 | ||
|
|
9bd5a1de7a | ||
|
|
d3c5a4fba0 | ||
|
|
f50006ee63 | ||
|
|
e0092024af | ||
|
|
675ef524b0 | ||
|
|
240367c775 | ||
|
|
f0ed063860 | ||
|
|
bcf0ef0c87 | ||
|
|
0c7a245a46 | ||
|
|
583d82433a | ||
|
|
391e710b6e | ||
|
|
004e56a91b | ||
|
|
103300798f | ||
|
|
8349d6f0ea | ||
|
|
cd63bf6da9 | ||
|
|
5f03e85195 | ||
|
|
cbdbfcab5e | ||
|
|
6918611287 | ||
|
|
b0639add8f | ||
|
|
7af10308d7 | ||
|
|
5e14f23507 | ||
|
|
0bf3a5c609 | ||
|
|
82724826ce | ||
|
|
f9e061926a | ||
|
|
8afd07ff7a | ||
|
|
6523a38255 | ||
|
|
264878a1c9 | ||
|
|
e480946f8a | ||
|
|
be25b1efbd | ||
|
|
204493439b | ||
|
|
106c685afb | ||
|
|
809122fec3 | ||
|
|
c8741d8e9c | ||
|
|
885f01e6a7 | ||
|
|
3180a13cf1 | ||
|
|
630ac31355 | ||
|
|
80de62f47d | ||
|
|
c75d42aa99 | ||
|
|
e1766bca55 | ||
|
|
211102f5f0 | ||
|
|
c46cc4666f | ||
|
|
0b2536b82b | ||
|
|
600a86f11d | ||
|
|
4d97a03935 | ||
|
|
5d7169f244 | ||
|
|
df9329009c | ||
|
|
e74a0398dc | ||
|
|
94c5822cb7 | ||
|
|
dedac55098 | ||
|
|
2bbab5cefe | ||
|
|
4bef718fad | ||
|
|
e7376e9dc2 | ||
|
|
8d5136fe8b | ||
|
|
3272050975 | ||
|
|
1960714042 | ||
|
|
5bddb2632e | ||
|
|
5cd055dab8 | ||
|
|
fa32b7f21e | ||
|
|
37f7227000 | ||
|
|
c1f9a9d122 | ||
|
|
045b7cc7e2 | ||
|
|
970e07a93b | ||
|
|
d463a3f213 | ||
|
|
4ba44c5e48 | ||
|
|
6f8176092e | ||
|
|
198ec417ba | ||
|
|
fbdf7798cf | ||
|
|
7bd9c856aa | ||
|
|
948c719d73 | ||
|
|
42572479cb | ||
|
|
accd363d3f | ||
|
|
8cf754a8b6 | ||
|
|
bf79220ac0 | ||
|
|
4c9dc14e65 | ||
|
|
f8621f7ea9 | ||
|
|
e0e08427b9 | ||
|
|
169df994da | ||
|
|
d83eaf2efb | ||
|
|
4e1e30f751 | ||
|
|
561f8c9c53 | ||
|
|
f625a4d0a7 | ||
|
|
746d4b6d3c | ||
|
|
fdd48c6588 | ||
|
|
23a04f7b9c | ||
|
|
b7b0dde7aa | ||
|
|
c40b78c7e9 | ||
|
|
33c0133cc7 | ||
|
|
cca5bc13dc | ||
|
|
d5ecaea8e7 | ||
|
|
b6d3c38ca9 | ||
|
|
b5fc1b4323 | ||
|
|
a1a9c42b0b | ||
|
|
e689e143e5 | ||
|
|
a7a168d934 | ||
|
|
69f47fc3e3 | ||
|
|
8a87140b4f | ||
|
|
53db0ddc4d | ||
|
|
087085403f | ||
|
|
c040b1cb47 | ||
|
|
1f4d0716b9 | ||
|
|
aa4993873f | ||
|
|
ce031c4394 | ||
|
|
d4dadb0dda | ||
|
|
0ded5813cd | ||
|
|
83137a19fb | ||
|
|
be66f8dbeb | ||
|
|
70baecb402 | ||
|
|
c27ba6bad4 | ||
|
|
61fda6ec58 | ||
|
|
2c93841eaa | ||
|
|
879db3391b | ||
|
|
6dff3d41fa | ||
|
|
e399eeb014 | ||
|
|
b133e8fcf0 | ||
|
|
ba9b24a477 | ||
|
|
cc7fb625a6 | ||
|
|
2b812b7d7d | ||
|
|
c5adbe4180 | ||
|
|
21dc3a2456 | ||
|
|
9631f373f0 | ||
|
|
cbd4d46fa5 | ||
|
|
dc4b9bc003 | ||
|
|
affb9e6941 | ||
|
|
dc542fd7fa | ||
|
|
85eeb21b77 | ||
|
|
4bb3ee03a0 | ||
|
|
1bb23d6837 | ||
|
|
f447359815 | ||
|
|
851e0b05f2 | ||
|
|
094cc940a4 | ||
|
|
51be9000bb | ||
|
|
80ecdb711d | ||
|
|
a599176bbf | ||
|
|
e0341b4c8a | ||
|
|
4c93fd448f | ||
|
|
84d916e210 | ||
|
|
f57ed2a8dd | ||
|
|
713889babf | ||
|
|
58c641d8ec | ||
|
|
94985e24c6 | ||
|
|
4c71a5f5ff | ||
|
|
b19e3a500b | ||
|
|
267fe027f5 | ||
|
|
0d4d8c0d64 | ||
|
|
6f9d8c0cff | ||
|
|
5031096a2b | ||
|
|
797e113000 | ||
|
|
edc2892785 | ||
|
|
ef4d5dcec3 | ||
|
|
0b5e3e5ee4 | ||
|
|
f5afb3621e | ||
|
|
9f72826143 | ||
|
|
ab7a4184df | ||
|
|
16a14bac89 | ||
|
|
baaf31513c | ||
|
|
0b01d7f848 | ||
|
|
23ff3476bc | ||
|
|
0c7ba8e2ac | ||
|
|
dad99cbec7 | ||
|
|
3e78c2f087 | ||
|
|
e822afdcfa | ||
|
|
b824951c89 | ||
|
|
ca20e527fc | ||
|
|
c8e65cce1e | ||
|
|
6c349687da | ||
|
|
3b64793d4b | ||
|
|
9dbe12cea8 | ||
|
|
e78637d632 | ||
|
|
cac03c07f7 | ||
|
|
95dabfaa18 | ||
|
|
e92c418e0f | ||
|
|
0593d045bf | ||
|
|
fff701b0bb | ||
|
|
0087a32d8b | ||
|
|
06312e485c | ||
|
|
e0f5b95cfc | ||
|
|
10bc072b4b | ||
|
|
b60884d3af | ||
|
|
95ae6d300c | ||
|
|
b76e4754bf | ||
|
|
b1085039ca | ||
|
|
d64f479c9f | ||
|
|
fd735c9a3f | ||
|
|
2282f6a42e | ||
|
|
0262002883 | ||
|
|
01ca9dc85d | ||
|
|
0735a98284 | ||
|
|
8d2e170fc4 | ||
|
|
f3e2795e69 | ||
|
|
30d9ce1310 | ||
|
|
2af2b7f130 | ||
|
|
9d41820363 | ||
|
|
a44f289aed | ||
|
|
9c078b3acf | ||
|
|
349f2c6ed6 | ||
|
|
0dc851a1cf | ||
|
|
f27fe068e8 | ||
|
|
f836cff935 | ||
|
|
312e3b92bc | ||
|
|
0cc0964231 | ||
|
|
b82278e685 | ||
|
|
daa1746b4a | ||
|
|
d8068f0a68 | ||
|
|
d91f776c2d | ||
|
|
a01135581f | ||
|
|
392b87fb4f | ||
|
|
551a05aef0 | ||
|
|
6b9d0b5af9 | ||
|
|
b8f3ad3e5d | ||
|
|
b19515e25d | ||
|
|
913f7cc7d4 | ||
|
|
84566debab |
2
.github/CODEOWNERS
vendored
2
.github/CODEOWNERS
vendored
@@ -1 +1,3 @@
|
||||
* @onyx-dot-app/onyx-core-team
|
||||
# Helm charts Owners
|
||||
/helm/ @justin-tahara
|
||||
|
||||
19
.github/actions/custom-build-and-push/action.yml
vendored
19
.github/actions/custom-build-and-push/action.yml
vendored
@@ -35,6 +35,16 @@ inputs:
|
||||
cache-to:
|
||||
description: 'Cache destinations'
|
||||
required: false
|
||||
outputs:
|
||||
description: 'Output destinations'
|
||||
required: false
|
||||
provenance:
|
||||
description: 'Generate provenance attestation'
|
||||
required: false
|
||||
default: 'false'
|
||||
build-args:
|
||||
description: 'Build arguments'
|
||||
required: false
|
||||
retry-wait-time:
|
||||
description: 'Time to wait before attempt 2 in seconds'
|
||||
required: false
|
||||
@@ -62,6 +72,9 @@ runs:
|
||||
no-cache: ${{ inputs.no-cache }}
|
||||
cache-from: ${{ inputs.cache-from }}
|
||||
cache-to: ${{ inputs.cache-to }}
|
||||
outputs: ${{ inputs.outputs }}
|
||||
provenance: ${{ inputs.provenance }}
|
||||
build-args: ${{ inputs.build-args }}
|
||||
|
||||
- name: Wait before attempt 2
|
||||
if: steps.buildx1.outcome != 'success'
|
||||
@@ -85,6 +98,9 @@ runs:
|
||||
no-cache: ${{ inputs.no-cache }}
|
||||
cache-from: ${{ inputs.cache-from }}
|
||||
cache-to: ${{ inputs.cache-to }}
|
||||
outputs: ${{ inputs.outputs }}
|
||||
provenance: ${{ inputs.provenance }}
|
||||
build-args: ${{ inputs.build-args }}
|
||||
|
||||
- name: Wait before attempt 3
|
||||
if: steps.buildx1.outcome != 'success' && steps.buildx2.outcome != 'success'
|
||||
@@ -108,6 +124,9 @@ runs:
|
||||
no-cache: ${{ inputs.no-cache }}
|
||||
cache-from: ${{ inputs.cache-from }}
|
||||
cache-to: ${{ inputs.cache-to }}
|
||||
outputs: ${{ inputs.outputs }}
|
||||
provenance: ${{ inputs.provenance }}
|
||||
build-args: ${{ inputs.build-args }}
|
||||
|
||||
- name: Report failure
|
||||
if: steps.buildx1.outcome != 'success' && steps.buildx2.outcome != 'success' && steps.buildx3.outcome != 'success'
|
||||
|
||||
@@ -7,8 +7,10 @@ on:
|
||||
|
||||
env:
|
||||
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-backend-cloud' || 'onyxdotapp/onyx-backend' }}
|
||||
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
|
||||
DEPLOYMENT: ${{ contains(github.ref_name, 'cloud') && 'cloud' || 'standalone' }}
|
||||
|
||||
# don't tag cloud images with "latest"
|
||||
LATEST_TAG: ${{ contains(github.ref_name, 'latest') && !contains(github.ref_name, 'cloud') }}
|
||||
|
||||
jobs:
|
||||
build-and-push:
|
||||
@@ -40,9 +42,11 @@ jobs:
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
latest=false
|
||||
tags: |
|
||||
type=raw,value=${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
type=raw,value=${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
|
||||
type=raw,value=${{ github.ref_name }}
|
||||
type=raw,value=${{ env.LATEST_TAG == 'true' && 'latest' || '' }}
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
@@ -60,7 +64,7 @@ jobs:
|
||||
|
||||
- name: Backend Image Docker Build and Push
|
||||
id: build
|
||||
uses: docker/build-push-action@v5
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
@@ -111,6 +115,11 @@ jobs:
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
latest=false
|
||||
tags: |
|
||||
type=raw,value=${{ github.ref_name }}
|
||||
type=raw,value=${{ env.LATEST_TAG == 'true' && 'latest' || '' }}
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
@@ -133,13 +142,25 @@ jobs:
|
||||
# can re-enable when they figure it out
|
||||
# https://github.com/aquasecurity/trivy/discussions/7538
|
||||
# https://github.com/aquasecurity/trivy-action/issues/389
|
||||
# Security: Using pinned digest (0.65.0@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436)
|
||||
# Security: No Docker socket mount needed for remote registry scanning
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@master
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
|
||||
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
|
||||
uses: nick-fields/retry@v3
|
||||
with:
|
||||
# To run locally: trivy image --severity HIGH,CRITICAL onyxdotapp/onyx-backend
|
||||
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
severity: "CRITICAL,HIGH"
|
||||
trivyignores: ./backend/.trivyignore
|
||||
timeout_minutes: 30
|
||||
max_attempts: 3
|
||||
retry_wait_seconds: 10
|
||||
command: |
|
||||
docker run --rm -v $HOME/.cache/trivy:/root/.cache/trivy \
|
||||
-v ${{ github.workspace }}/backend/.trivyignore:/tmp/.trivyignore:ro \
|
||||
-e TRIVY_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-db:2" \
|
||||
-e TRIVY_JAVA_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-java-db:1" \
|
||||
-e TRIVY_USERNAME="${{ secrets.DOCKER_USERNAME }}" \
|
||||
-e TRIVY_PASSWORD="${{ secrets.DOCKER_TOKEN }}" \
|
||||
aquasec/trivy@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436 \
|
||||
image \
|
||||
--skip-version-check \
|
||||
--timeout 20m \
|
||||
--severity CRITICAL,HIGH \
|
||||
--ignorefile /tmp/.trivyignore \
|
||||
docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
|
||||
@@ -4,11 +4,10 @@ name: Build and Push Cloud Web Image on Tag
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- "*"
|
||||
- "*cloud*"
|
||||
|
||||
env:
|
||||
REGISTRY_IMAGE: onyxdotapp/onyx-web-server-cloud
|
||||
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
|
||||
DEPLOYMENT: cloud
|
||||
|
||||
jobs:
|
||||
@@ -39,9 +38,10 @@ jobs:
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
latest=false
|
||||
tags: |
|
||||
type=raw,value=${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
type=raw,value=${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
|
||||
type=raw,value=${{ github.ref_name }}
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
@@ -54,7 +54,7 @@ jobs:
|
||||
|
||||
- name: Build and push by digest
|
||||
id: build
|
||||
uses: docker/build-push-action@v5
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: ./web
|
||||
file: ./web/Dockerfile
|
||||
@@ -112,6 +112,10 @@ jobs:
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
latest=false
|
||||
tags: |
|
||||
type=raw,value=${{ github.ref_name }}
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
@@ -135,10 +139,20 @@ jobs:
|
||||
# https://github.com/aquasecurity/trivy/discussions/7538
|
||||
# https://github.com/aquasecurity/trivy-action/issues/389
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@master
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
|
||||
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
|
||||
uses: nick-fields/retry@v3
|
||||
with:
|
||||
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
severity: "CRITICAL,HIGH"
|
||||
timeout_minutes: 30
|
||||
max_attempts: 3
|
||||
retry_wait_seconds: 10
|
||||
command: |
|
||||
docker run --rm -v $HOME/.cache/trivy:/root/.cache/trivy \
|
||||
-e TRIVY_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-db:2" \
|
||||
-e TRIVY_JAVA_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-java-db:1" \
|
||||
-e TRIVY_USERNAME="${{ secrets.DOCKER_USERNAME }}" \
|
||||
-e TRIVY_PASSWORD="${{ secrets.DOCKER_TOKEN }}" \
|
||||
aquasec/trivy@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436 \
|
||||
image \
|
||||
--skip-version-check \
|
||||
--timeout 20m \
|
||||
--severity CRITICAL,HIGH \
|
||||
docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
|
||||
@@ -7,10 +7,12 @@ on:
|
||||
|
||||
env:
|
||||
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-model-server-cloud' || 'onyxdotapp/onyx-model-server' }}
|
||||
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
|
||||
DOCKER_BUILDKIT: 1
|
||||
BUILDKIT_PROGRESS: plain
|
||||
DEPLOYMENT: ${{ contains(github.ref_name, 'cloud') && 'cloud' || 'standalone' }}
|
||||
|
||||
# don't tag cloud images with "latest"
|
||||
LATEST_TAG: ${{ contains(github.ref_name, 'latest') && !contains(github.ref_name, 'cloud') }}
|
||||
|
||||
jobs:
|
||||
|
||||
@@ -78,7 +80,7 @@ jobs:
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and Push AMD64
|
||||
uses: docker/build-push-action@v5
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile.model_server
|
||||
@@ -97,7 +99,7 @@ jobs:
|
||||
needs: [check_model_server_changes]
|
||||
if: needs.check_model_server_changes.outputs.changed == 'true'
|
||||
runs-on:
|
||||
[runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}-arm64"]
|
||||
[runs-on, runner=8cpu-linux-arm64, "run-id=${{ github.run_id }}-arm64"]
|
||||
env:
|
||||
PLATFORM_PAIR: linux-arm64
|
||||
steps:
|
||||
@@ -124,7 +126,7 @@ jobs:
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and Push ARM64
|
||||
uses: docker/build-push-action@v5
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile.model_server
|
||||
@@ -162,11 +164,20 @@ jobs:
|
||||
fi
|
||||
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@master
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
|
||||
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
|
||||
uses: nick-fields/retry@v3
|
||||
with:
|
||||
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
severity: "CRITICAL,HIGH"
|
||||
timeout: "10m"
|
||||
timeout_minutes: 30
|
||||
max_attempts: 3
|
||||
retry_wait_seconds: 10
|
||||
command: |
|
||||
docker run --rm -v $HOME/.cache/trivy:/root/.cache/trivy \
|
||||
-e TRIVY_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-db:2" \
|
||||
-e TRIVY_JAVA_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-java-db:1" \
|
||||
-e TRIVY_USERNAME="${{ secrets.DOCKER_USERNAME }}" \
|
||||
-e TRIVY_PASSWORD="${{ secrets.DOCKER_TOKEN }}" \
|
||||
aquasec/trivy@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436 \
|
||||
image \
|
||||
--skip-version-check \
|
||||
--timeout 20m \
|
||||
--severity CRITICAL,HIGH \
|
||||
docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
|
||||
@@ -9,9 +9,24 @@ env:
|
||||
REGISTRY_IMAGE: onyxdotapp/onyx-web-server
|
||||
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
|
||||
DEPLOYMENT: standalone
|
||||
|
||||
|
||||
jobs:
|
||||
precheck:
|
||||
runs-on: [runs-on, runner=2cpu-linux-x64, "run-id=${{ github.run_id }}"]
|
||||
outputs:
|
||||
should-run: ${{ steps.set-output.outputs.should-run }}
|
||||
steps:
|
||||
- name: Check if tag contains "cloud"
|
||||
id: set-output
|
||||
run: |
|
||||
if [[ "${{ github.ref_name }}" == *cloud* ]]; then
|
||||
echo "should-run=false" >> "$GITHUB_OUTPUT"
|
||||
else
|
||||
echo "should-run=true" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
build:
|
||||
needs: precheck
|
||||
if: needs.precheck.outputs.should-run == 'true'
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=${{ matrix.platform == 'linux/amd64' && '8cpu-linux-x64' || '8cpu-linux-arm64' }}
|
||||
@@ -38,9 +53,11 @@ jobs:
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
latest=false
|
||||
tags: |
|
||||
type=raw,value=${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
type=raw,value=${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
|
||||
type=raw,value=${{ github.ref_name }}
|
||||
type=raw,value=${{ env.LATEST_TAG == 'true' && 'latest' || '' }}
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
@@ -53,7 +70,7 @@ jobs:
|
||||
|
||||
- name: Build and push by digest
|
||||
id: build
|
||||
uses: docker/build-push-action@v5
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: ./web
|
||||
file: ./web/Dockerfile
|
||||
@@ -85,9 +102,10 @@ jobs:
|
||||
retention-days: 1
|
||||
|
||||
merge:
|
||||
runs-on: ubuntu-latest
|
||||
needs:
|
||||
- build
|
||||
if: needs.precheck.outputs.should-run == 'true'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Download digests
|
||||
uses: actions/download-artifact@v4
|
||||
@@ -104,6 +122,11 @@ jobs:
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
latest=false
|
||||
tags: |
|
||||
type=raw,value=${{ github.ref_name }}
|
||||
type=raw,value=${{ env.LATEST_TAG == 'true' && 'latest' || '' }}
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
@@ -127,10 +150,20 @@ jobs:
|
||||
# https://github.com/aquasecurity/trivy/discussions/7538
|
||||
# https://github.com/aquasecurity/trivy-action/issues/389
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@master
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
|
||||
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
|
||||
uses: nick-fields/retry@v3
|
||||
with:
|
||||
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
severity: "CRITICAL,HIGH"
|
||||
timeout_minutes: 30
|
||||
max_attempts: 3
|
||||
retry_wait_seconds: 10
|
||||
command: |
|
||||
docker run --rm -v $HOME/.cache/trivy:/root/.cache/trivy \
|
||||
-e TRIVY_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-db:2" \
|
||||
-e TRIVY_JAVA_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-java-db:1" \
|
||||
-e TRIVY_USERNAME="${{ secrets.DOCKER_USERNAME }}" \
|
||||
-e TRIVY_PASSWORD="${{ secrets.DOCKER_TOKEN }}" \
|
||||
aquasec/trivy@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436 \
|
||||
image \
|
||||
--skip-version-check \
|
||||
--timeout 20m \
|
||||
--severity CRITICAL,HIGH \
|
||||
docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
|
||||
49
.github/workflows/helm-chart-releases.yml
vendored
Normal file
49
.github/workflows/helm-chart-releases.yml
vendored
Normal file
@@ -0,0 +1,49 @@
|
||||
name: Release Onyx Helm Charts
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
|
||||
permissions: write-all
|
||||
|
||||
jobs:
|
||||
release:
|
||||
permissions:
|
||||
contents: write
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Install Helm CLI
|
||||
uses: azure/setup-helm@v4
|
||||
with:
|
||||
version: v3.12.1
|
||||
|
||||
- name: Add required Helm repositories
|
||||
run: |
|
||||
helm repo add bitnami https://charts.bitnami.com/bitnami
|
||||
helm repo add onyx-vespa https://onyx-dot-app.github.io/vespa-helm-charts
|
||||
helm repo update
|
||||
|
||||
- name: Build chart dependencies
|
||||
run: |
|
||||
set -euo pipefail
|
||||
for chart_dir in deployment/helm/charts/*; do
|
||||
if [ -f "$chart_dir/Chart.yaml" ]; then
|
||||
echo "Building dependencies for $chart_dir"
|
||||
helm dependency build "$chart_dir"
|
||||
fi
|
||||
done
|
||||
|
||||
- name: Publish Helm charts to gh-pages
|
||||
uses: stefanprodan/helm-gh-pages@v1.7.0
|
||||
with:
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
charts_dir: deployment/helm/charts
|
||||
branch: gh-pages
|
||||
commit_username: ${{ github.actor }}
|
||||
commit_email: ${{ github.actor }}@users.noreply.github.com
|
||||
97
.github/workflows/pr-external-dependency-unit-tests.yml
vendored
Normal file
97
.github/workflows/pr-external-dependency-unit-tests.yml
vendored
Normal file
@@ -0,0 +1,97 @@
|
||||
name: External Dependency Unit Tests
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches: [main]
|
||||
|
||||
env:
|
||||
# AWS
|
||||
S3_AWS_ACCESS_KEY_ID: ${{ secrets.S3_AWS_ACCESS_KEY_ID }}
|
||||
S3_AWS_SECRET_ACCESS_KEY: ${{ secrets.S3_AWS_SECRET_ACCESS_KEY }}
|
||||
|
||||
# MinIO
|
||||
S3_ENDPOINT_URL: "http://localhost:9004"
|
||||
|
||||
# Confluence
|
||||
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
|
||||
CONFLUENCE_TEST_SPACE: ${{ secrets.CONFLUENCE_TEST_SPACE }}
|
||||
CONFLUENCE_TEST_PAGE_ID: ${{ secrets.CONFLUENCE_TEST_PAGE_ID }}
|
||||
CONFLUENCE_IS_CLOUD: ${{ secrets.CONFLUENCE_IS_CLOUD }}
|
||||
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
|
||||
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
|
||||
|
||||
# LLMs
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
|
||||
jobs:
|
||||
discover-test-dirs:
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
test-dirs: ${{ steps.set-matrix.outputs.test-dirs }}
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Discover test directories
|
||||
id: set-matrix
|
||||
run: |
|
||||
# Find all subdirectories in backend/tests/external_dependency_unit
|
||||
dirs=$(find backend/tests/external_dependency_unit -mindepth 1 -maxdepth 1 -type d -exec basename {} \; | sort | jq -R -s -c 'split("\n")[:-1]')
|
||||
echo "test-dirs=$dirs" >> $GITHUB_OUTPUT
|
||||
|
||||
external-dependency-unit-tests:
|
||||
needs: discover-test-dirs
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}"]
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
test-dir: ${{ fromJson(needs.discover-test-dirs.outputs.test-dirs) }}
|
||||
|
||||
env:
|
||||
PYTHONPATH: ./backend
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
cache: "pip"
|
||||
cache-dependency-path: |
|
||||
backend/requirements/default.txt
|
||||
backend/requirements/dev.txt
|
||||
|
||||
- name: Install Dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
|
||||
playwright install chromium
|
||||
playwright install-deps chromium
|
||||
|
||||
- name: Set up Standard Dependencies
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p onyx-stack up -d minio relational_db cache index
|
||||
|
||||
- name: Run migrations
|
||||
run: |
|
||||
cd backend
|
||||
alembic upgrade head
|
||||
|
||||
- name: Run Tests for ${{ matrix.test-dir }}
|
||||
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
|
||||
run: |
|
||||
py.test \
|
||||
-n 8 \
|
||||
--dist loadfile \
|
||||
--durations=8 \
|
||||
-o junit_family=xunit2 \
|
||||
-xv \
|
||||
--ff \
|
||||
backend/tests/external_dependency_unit/${{ matrix.test-dir }}
|
||||
149
.github/workflows/pr-helm-chart-testing.yml
vendored
149
.github/workflows/pr-helm-chart-testing.yml
vendored
@@ -53,9 +53,154 @@ jobs:
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
uses: helm/kind-action@v1.12.0
|
||||
|
||||
- name: Run chart-testing (install)
|
||||
- name: Pre-install cluster status check
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: ct install --all --helm-extra-set-args="--set=nginx.enabled=false" --debug --config ct.yaml
|
||||
run: |
|
||||
echo "=== Pre-install Cluster Status ==="
|
||||
kubectl get nodes -o wide
|
||||
kubectl get pods --all-namespaces
|
||||
kubectl get storageclass
|
||||
|
||||
- name: Add Helm repositories and update
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Adding Helm repositories ==="
|
||||
helm repo add bitnami https://charts.bitnami.com/bitnami
|
||||
helm repo add vespa https://onyx-dot-app.github.io/vespa-helm-charts
|
||||
helm repo update
|
||||
|
||||
- name: Pre-pull critical images
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Pre-pulling critical images to avoid timeout ==="
|
||||
# Get kind cluster name
|
||||
KIND_CLUSTER=$(kubectl config current-context | sed 's/kind-//')
|
||||
echo "Kind cluster: $KIND_CLUSTER"
|
||||
|
||||
# Pre-pull images that are likely to be used
|
||||
echo "Pre-pulling PostgreSQL image..."
|
||||
docker pull postgres:15-alpine || echo "Failed to pull postgres:15-alpine"
|
||||
kind load docker-image postgres:15-alpine --name $KIND_CLUSTER || echo "Failed to load postgres image"
|
||||
|
||||
echo "Pre-pulling Redis image..."
|
||||
docker pull redis:7-alpine || echo "Failed to pull redis:7-alpine"
|
||||
kind load docker-image redis:7-alpine --name $KIND_CLUSTER || echo "Failed to load redis image"
|
||||
|
||||
echo "Pre-pulling Onyx images..."
|
||||
docker pull docker.io/onyxdotapp/onyx-web-server:latest || echo "Failed to pull onyx web server"
|
||||
docker pull docker.io/onyxdotapp/onyx-backend:latest || echo "Failed to pull onyx backend"
|
||||
kind load docker-image docker.io/onyxdotapp/onyx-web-server:latest --name $KIND_CLUSTER || echo "Failed to load onyx web server"
|
||||
kind load docker-image docker.io/onyxdotapp/onyx-backend:latest --name $KIND_CLUSTER || echo "Failed to load onyx backend"
|
||||
|
||||
echo "=== Images loaded into Kind cluster ==="
|
||||
docker exec $KIND_CLUSTER-control-plane crictl images | grep -E "(postgres|redis|onyx)" || echo "Some images may still be loading..."
|
||||
|
||||
- name: Validate chart dependencies
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Validating chart dependencies ==="
|
||||
cd deployment/helm/charts/onyx
|
||||
helm dependency update
|
||||
helm lint .
|
||||
|
||||
- name: Run chart-testing (install) with enhanced monitoring
|
||||
timeout-minutes: 25
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Starting chart installation with monitoring ==="
|
||||
|
||||
# Function to monitor cluster state
|
||||
monitor_cluster() {
|
||||
while true; do
|
||||
echo "=== Cluster Status Check at $(date) ==="
|
||||
# Only show non-running pods to reduce noise
|
||||
NON_RUNNING_PODS=$(kubectl get pods --all-namespaces --field-selector=status.phase!=Running,status.phase!=Succeeded --no-headers 2>/dev/null | wc -l)
|
||||
if [ "$NON_RUNNING_PODS" -gt 0 ]; then
|
||||
echo "Non-running pods:"
|
||||
kubectl get pods --all-namespaces --field-selector=status.phase!=Running,status.phase!=Succeeded
|
||||
else
|
||||
echo "All pods running successfully"
|
||||
fi
|
||||
# Only show recent events if there are issues
|
||||
RECENT_EVENTS=$(kubectl get events --sort-by=.lastTimestamp --all-namespaces --field-selector=type!=Normal 2>/dev/null | tail -5)
|
||||
if [ -n "$RECENT_EVENTS" ]; then
|
||||
echo "Recent warnings/errors:"
|
||||
echo "$RECENT_EVENTS"
|
||||
fi
|
||||
sleep 60
|
||||
done
|
||||
}
|
||||
|
||||
# Start monitoring in background
|
||||
monitor_cluster &
|
||||
MONITOR_PID=$!
|
||||
|
||||
# Set up cleanup
|
||||
cleanup() {
|
||||
echo "=== Cleaning up monitoring process ==="
|
||||
kill $MONITOR_PID 2>/dev/null || true
|
||||
echo "=== Final cluster state ==="
|
||||
kubectl get pods --all-namespaces
|
||||
kubectl get events --all-namespaces --sort-by=.lastTimestamp | tail -20
|
||||
}
|
||||
|
||||
# Trap cleanup on exit
|
||||
trap cleanup EXIT
|
||||
|
||||
# Run the actual installation with detailed logging
|
||||
echo "=== Starting ct install ==="
|
||||
ct install --all \
|
||||
--helm-extra-set-args="\
|
||||
--set=nginx.enabled=false \
|
||||
--set=minio.enabled=false \
|
||||
--set=vespa.enabled=false \
|
||||
--set=slackbot.enabled=false \
|
||||
--set=postgresql.enabled=true \
|
||||
--set=postgresql.primary.persistence.enabled=false \
|
||||
--set=redis.enabled=true \
|
||||
--set=webserver.replicaCount=1 \
|
||||
--set=api.replicaCount=0 \
|
||||
--set=inferenceCapability.replicaCount=0 \
|
||||
--set=indexCapability.replicaCount=0 \
|
||||
--set=celery_beat.replicaCount=0 \
|
||||
--set=celery_worker_heavy.replicaCount=0 \
|
||||
--set=celery_worker_docfetching.replicaCount=0 \
|
||||
--set=celery_worker_docprocessing.replicaCount=0 \
|
||||
--set=celery_worker_light.replicaCount=0 \
|
||||
--set=celery_worker_monitoring.replicaCount=0 \
|
||||
--set=celery_worker_primary.replicaCount=0 \
|
||||
--set=celery_worker_user_files_indexing.replicaCount=0" \
|
||||
--helm-extra-args="--timeout 900s --debug" \
|
||||
--debug --config ct.yaml
|
||||
|
||||
echo "=== Installation completed successfully ==="
|
||||
kubectl get pods --all-namespaces
|
||||
|
||||
- name: Post-install verification
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Post-install verification ==="
|
||||
kubectl get pods --all-namespaces
|
||||
kubectl get services --all-namespaces
|
||||
# Only show issues if they exist
|
||||
kubectl describe pods --all-namespaces | grep -A 5 -B 2 "Failed\|Error\|Warning" || echo "No pod issues found"
|
||||
|
||||
- name: Cleanup on failure
|
||||
if: failure() && steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Cleanup on failure ==="
|
||||
echo "=== Final cluster state ==="
|
||||
kubectl get pods --all-namespaces
|
||||
kubectl get events --all-namespaces --sort-by=.lastTimestamp | tail -10
|
||||
|
||||
echo "=== Pod descriptions for debugging ==="
|
||||
kubectl describe pods --all-namespaces | grep -A 10 -B 3 "Failed\|Error\|Warning\|Pending" || echo "No problematic pods found"
|
||||
|
||||
echo "=== Recent logs for debugging ==="
|
||||
kubectl logs --all-namespaces --tail=50 | grep -i "error\|timeout\|failed\|pull" || echo "No error logs found"
|
||||
|
||||
echo "=== Helm releases ==="
|
||||
helm list --all-namespaces
|
||||
# the following would install only changed charts, but we only have one chart so
|
||||
# don't worry about that for now
|
||||
# run: ct install --target-branch ${{ github.event.repository.default_branch }}
|
||||
|
||||
567
.github/workflows/pr-integration-tests.yml
vendored
567
.github/workflows/pr-integration-tests.yml
vendored
@@ -11,143 +11,240 @@ on:
|
||||
- "release/**"
|
||||
|
||||
env:
|
||||
# Private Registry Configuration
|
||||
PRIVATE_REGISTRY: experimental-registry.blacksmith.sh:5000
|
||||
PRIVATE_REGISTRY_USERNAME: ${{ secrets.PRIVATE_REGISTRY_USERNAME }}
|
||||
PRIVATE_REGISTRY_PASSWORD: ${{ secrets.PRIVATE_REGISTRY_PASSWORD }}
|
||||
|
||||
# Test Environment Variables
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
|
||||
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
|
||||
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
|
||||
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
|
||||
PLATFORM_PAIR: linux-amd64
|
||||
JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }}
|
||||
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
|
||||
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
|
||||
PERM_SYNC_SHAREPOINT_CLIENT_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_CLIENT_ID }}
|
||||
PERM_SYNC_SHAREPOINT_PRIVATE_KEY: ${{ secrets.PERM_SYNC_SHAREPOINT_PRIVATE_KEY }}
|
||||
PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD: ${{ secrets.PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD }}
|
||||
PERM_SYNC_SHAREPOINT_DIRECTORY_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_DIRECTORY_ID }}
|
||||
|
||||
jobs:
|
||||
integration-tests:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on, runner=32cpu-linux-x64, "run-id=${{ github.run_id }}"]
|
||||
discover-test-dirs:
|
||||
runs-on: blacksmith-2vcpu-ubuntu-2404-arm
|
||||
outputs:
|
||||
test-dirs: ${{ steps.set-matrix.outputs.test-dirs }}
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
- name: Discover test directories
|
||||
id: set-matrix
|
||||
run: |
|
||||
# Find all leaf-level directories in both test directories
|
||||
tests_dirs=$(find backend/tests/integration/tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" -exec basename {} \; | sort)
|
||||
connector_dirs=$(find backend/tests/integration/connector_job_tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" -exec basename {} \; | sort)
|
||||
|
||||
# Create JSON array with directory info
|
||||
all_dirs=""
|
||||
for dir in $tests_dirs; do
|
||||
all_dirs="$all_dirs{\"path\":\"tests/$dir\",\"name\":\"tests-$dir\"},"
|
||||
done
|
||||
for dir in $connector_dirs; do
|
||||
all_dirs="$all_dirs{\"path\":\"connector_job_tests/$dir\",\"name\":\"connector-$dir\"},"
|
||||
done
|
||||
|
||||
# Remove trailing comma and wrap in array
|
||||
all_dirs="[${all_dirs%,}]"
|
||||
echo "test-dirs=$all_dirs" >> $GITHUB_OUTPUT
|
||||
|
||||
prepare-build:
|
||||
runs-on: blacksmith-2vcpu-ubuntu-2404-arm
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
cache: "pip"
|
||||
cache-dependency-path: |
|
||||
backend/requirements/default.txt
|
||||
backend/requirements/dev.txt
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
|
||||
|
||||
- name: Generate OpenAPI schema
|
||||
working-directory: ./backend
|
||||
env:
|
||||
PYTHONPATH: "."
|
||||
run: |
|
||||
python scripts/onyx_openapi_schema.py --filename generated/openapi.json
|
||||
|
||||
- name: Generate OpenAPI Python client
|
||||
working-directory: ./backend
|
||||
run: |
|
||||
docker run --rm \
|
||||
-v "${{ github.workspace }}/backend/generated:/local" \
|
||||
openapitools/openapi-generator-cli generate \
|
||||
-i /local/openapi.json \
|
||||
-g python \
|
||||
-o /local/onyx_openapi_client \
|
||||
--package-name onyx_openapi_client \
|
||||
--skip-validate-spec \
|
||||
--openapi-normalizer "SIMPLIFY_ONEOF_ANYOF=true,SET_OAS3_NULLABLE=true"
|
||||
|
||||
- name: Upload OpenAPI artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: openapi-artifacts
|
||||
path: backend/generated/
|
||||
|
||||
build-backend-image:
|
||||
runs-on: blacksmith-16vcpu-ubuntu-2404-arm
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Login to Private Registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ${{ env.PRIVATE_REGISTRY }}
|
||||
username: ${{ env.PRIVATE_REGISTRY_USERNAME }}
|
||||
password: ${{ env.PRIVATE_REGISTRY_PASSWORD }}
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: useblacksmith/setup-docker-builder@v1
|
||||
|
||||
- name: Build and push Backend Docker image
|
||||
uses: useblacksmith/build-push-action@v2
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
platforms: linux/arm64
|
||||
tags: ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-backend:test-${{ github.run_id }}
|
||||
push: true
|
||||
|
||||
build-model-server-image:
|
||||
runs-on: blacksmith-16vcpu-ubuntu-2404-arm
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Login to Private Registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ${{ env.PRIVATE_REGISTRY }}
|
||||
username: ${{ env.PRIVATE_REGISTRY_USERNAME }}
|
||||
password: ${{ env.PRIVATE_REGISTRY_PASSWORD }}
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: useblacksmith/setup-docker-builder@v1
|
||||
|
||||
- name: Build and push Model Server Docker image
|
||||
uses: useblacksmith/build-push-action@v2
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile.model_server
|
||||
platforms: linux/arm64
|
||||
tags: ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-model-server:test-${{ github.run_id }}
|
||||
push: true
|
||||
outputs: type=registry
|
||||
provenance: false
|
||||
|
||||
build-integration-image:
|
||||
needs: prepare-build
|
||||
runs-on: blacksmith-16vcpu-ubuntu-2404-arm
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Login to Private Registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ${{ env.PRIVATE_REGISTRY }}
|
||||
username: ${{ env.PRIVATE_REGISTRY_USERNAME }}
|
||||
password: ${{ env.PRIVATE_REGISTRY_PASSWORD }}
|
||||
|
||||
- name: Download OpenAPI artifacts
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: openapi-artifacts
|
||||
path: backend/generated/
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: useblacksmith/setup-docker-builder@v1
|
||||
|
||||
- name: Build and push integration test Docker image
|
||||
uses: useblacksmith/build-push-action@v2
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/tests/integration/Dockerfile
|
||||
platforms: linux/arm64
|
||||
tags: ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-integration:test-${{ github.run_id }}
|
||||
push: true
|
||||
|
||||
integration-tests:
|
||||
needs:
|
||||
[
|
||||
discover-test-dirs,
|
||||
build-backend-image,
|
||||
build-model-server-image,
|
||||
build-integration-image,
|
||||
]
|
||||
runs-on: blacksmith-8vcpu-ubuntu-2404-arm
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
test-dir: ${{ fromJson(needs.discover-test-dirs.outputs.test-dirs) }}
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Login to Private Registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ${{ env.PRIVATE_REGISTRY }}
|
||||
username: ${{ env.PRIVATE_REGISTRY_USERNAME }}
|
||||
password: ${{ env.PRIVATE_REGISTRY_PASSWORD }}
|
||||
|
||||
# needed for pulling Vespa, Redis, Postgres, and Minio images
|
||||
# otherwise, we hit the "Unauthenticated users" limit
|
||||
# https://docs.docker.com/docker-hub/usage/
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
# tag every docker image with "test" so that we can spin up the correct set
|
||||
# of images during testing
|
||||
|
||||
# We don't need to build the Web Docker image since it's not yet used
|
||||
# in the integration tests. We have a separate action to verify that it builds
|
||||
# successfully.
|
||||
- name: Pull Web Docker image
|
||||
- name: Pull Docker images
|
||||
run: |
|
||||
docker pull onyxdotapp/onyx-web-server:latest
|
||||
docker tag onyxdotapp/onyx-web-server:latest onyxdotapp/onyx-web-server:test
|
||||
# Pull all images from registry in parallel
|
||||
echo "Pulling Docker images in parallel..."
|
||||
# Pull images from private registry
|
||||
(docker pull ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-backend:test-${{ github.run_id }}) &
|
||||
(docker pull ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-model-server:test-${{ github.run_id }}) &
|
||||
(docker pull ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-integration:test-${{ github.run_id }}) &
|
||||
|
||||
# we use the runs-on cache for docker builds
|
||||
# in conjunction with runs-on runners, it has better speed and unlimited caching
|
||||
# https://runs-on.com/caching/s3-cache-for-github-actions/
|
||||
# https://runs-on.com/caching/docker/
|
||||
# https://github.com/moby/buildkit#s3-cache-experimental
|
||||
# Wait for all background jobs to complete
|
||||
wait
|
||||
echo "All Docker images pulled successfully"
|
||||
|
||||
# images are built and run locally for testing purposes. Not pushed.
|
||||
- name: Build Backend Docker image
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
platforms: linux/amd64
|
||||
tags: onyxdotapp/onyx-backend:test
|
||||
push: false
|
||||
load: true
|
||||
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
|
||||
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
|
||||
|
||||
- name: Build Model Server Docker image
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile.model_server
|
||||
platforms: linux/amd64
|
||||
tags: onyxdotapp/onyx-model-server:test
|
||||
push: false
|
||||
load: true
|
||||
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
|
||||
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
|
||||
|
||||
- name: Build integration test Docker image
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/tests/integration/Dockerfile
|
||||
platforms: linux/amd64
|
||||
tags: onyxdotapp/onyx-integration:test
|
||||
push: false
|
||||
load: true
|
||||
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
|
||||
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
|
||||
|
||||
# Start containers for multi-tenant tests
|
||||
- name: Start Docker containers for multi-tenant tests
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
|
||||
MULTI_TENANT=true \
|
||||
AUTH_TYPE=cloud \
|
||||
REQUIRE_EMAIL_VERIFICATION=false \
|
||||
DISABLE_TELEMETRY=true \
|
||||
IMAGE_TAG=test \
|
||||
DEV_MODE=true \
|
||||
docker compose -f docker-compose.multitenant-dev.yml -p onyx-stack up -d
|
||||
id: start_docker_multi_tenant
|
||||
|
||||
# In practice, `cloud` Auth type would require OAUTH credentials to be set.
|
||||
- name: Run Multi-Tenant Integration Tests
|
||||
run: |
|
||||
echo "Waiting for 3 minutes to ensure API server is ready..."
|
||||
sleep 180
|
||||
echo "Running integration tests..."
|
||||
docker run --rm --network onyx-stack_default \
|
||||
--name test-runner \
|
||||
-e POSTGRES_HOST=relational_db \
|
||||
-e POSTGRES_USER=postgres \
|
||||
-e POSTGRES_PASSWORD=password \
|
||||
-e POSTGRES_DB=postgres \
|
||||
-e POSTGRES_USE_NULL_POOL=true \
|
||||
-e VESPA_HOST=index \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
-e AUTH_TYPE=cloud \
|
||||
-e MULTI_TENANT=true \
|
||||
-e REQUIRE_EMAIL_VERIFICATION=false \
|
||||
-e DISABLE_TELEMETRY=true \
|
||||
-e IMAGE_TAG=test \
|
||||
-e DEV_MODE=true \
|
||||
onyxdotapp/onyx-integration:test \
|
||||
/app/tests/integration/multitenant_tests
|
||||
continue-on-error: true
|
||||
id: run_multitenant_tests
|
||||
|
||||
- name: Check multi-tenant test results
|
||||
run: |
|
||||
if [ ${{ steps.run_multitenant_tests.outcome }} == 'failure' ]; then
|
||||
echo "Multi-tenant integration tests failed. Exiting with error."
|
||||
exit 1
|
||||
else
|
||||
echo "All multi-tenant integration tests passed successfully."
|
||||
fi
|
||||
|
||||
- name: Stop multi-tenant Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.multitenant-dev.yml -p onyx-stack down -v
|
||||
# Re-tag to remove registry prefix for docker-compose
|
||||
docker tag ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-backend:test-${{ github.run_id }} onyxdotapp/onyx-backend:test
|
||||
docker tag ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-model-server:test-${{ github.run_id }} onyxdotapp/onyx-model-server:test
|
||||
docker tag ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-integration:test-${{ github.run_id }} onyxdotapp/onyx-integration:test
|
||||
|
||||
# NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
|
||||
# NOTE: don't need web server for integration tests
|
||||
- name: Start Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
@@ -160,7 +257,16 @@ jobs:
|
||||
IMAGE_TAG=test \
|
||||
INTEGRATION_TESTS_MODE=true \
|
||||
CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS=0.001 \
|
||||
docker compose -f docker-compose.dev.yml -p onyx-stack up -d
|
||||
docker compose -f docker-compose.dev.yml -p onyx-stack up \
|
||||
relational_db \
|
||||
index \
|
||||
cache \
|
||||
minio \
|
||||
api_server \
|
||||
inference_model_server \
|
||||
indexing_model_server \
|
||||
background \
|
||||
-d
|
||||
id: start_docker
|
||||
|
||||
- name: Wait for service to be ready
|
||||
@@ -203,43 +309,44 @@ jobs:
|
||||
docker compose -f docker-compose.mock-it-services.yml \
|
||||
-p mock-it-services-stack up -d
|
||||
|
||||
# NOTE: Use pre-ping/null to reduce flakiness due to dropped connections
|
||||
- name: Run Standard Integration Tests
|
||||
run: |
|
||||
echo "Running integration tests..."
|
||||
docker run --rm --network onyx-stack_default \
|
||||
--name test-runner \
|
||||
-e POSTGRES_HOST=relational_db \
|
||||
-e POSTGRES_USER=postgres \
|
||||
-e POSTGRES_PASSWORD=password \
|
||||
-e POSTGRES_DB=postgres \
|
||||
-e POSTGRES_POOL_PRE_PING=true \
|
||||
-e POSTGRES_USE_NULL_POOL=true \
|
||||
-e VESPA_HOST=index \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
|
||||
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
|
||||
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
|
||||
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
|
||||
-e MOCK_CONNECTOR_SERVER_PORT=8001 \
|
||||
onyxdotapp/onyx-integration:test \
|
||||
/app/tests/integration/tests \
|
||||
/app/tests/integration/connector_job_tests
|
||||
continue-on-error: true
|
||||
id: run_tests
|
||||
|
||||
- name: Check test results
|
||||
run: |
|
||||
if [ ${{ steps.run_tests.outcome }} == 'failure' ]; then
|
||||
echo "Integration tests failed. Exiting with error."
|
||||
exit 1
|
||||
else
|
||||
echo "All integration tests passed successfully."
|
||||
fi
|
||||
- name: Run Integration Tests for ${{ matrix.test-dir.name }}
|
||||
uses: nick-fields/retry@v3
|
||||
with:
|
||||
timeout_minutes: 20
|
||||
max_attempts: 3
|
||||
retry_wait_seconds: 10
|
||||
command: |
|
||||
echo "Running integration tests for ${{ matrix.test-dir.path }}..."
|
||||
docker run --rm --network onyx-stack_default \
|
||||
--name test-runner \
|
||||
-e POSTGRES_HOST=relational_db \
|
||||
-e POSTGRES_USER=postgres \
|
||||
-e POSTGRES_PASSWORD=password \
|
||||
-e POSTGRES_DB=postgres \
|
||||
-e DB_READONLY_USER=db_readonly_user \
|
||||
-e DB_READONLY_PASSWORD=password \
|
||||
-e POSTGRES_POOL_PRE_PING=true \
|
||||
-e POSTGRES_USE_NULL_POOL=true \
|
||||
-e VESPA_HOST=index \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
|
||||
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
|
||||
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
|
||||
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
|
||||
-e JIRA_BASE_URL=${JIRA_BASE_URL} \
|
||||
-e JIRA_USER_EMAIL=${JIRA_USER_EMAIL} \
|
||||
-e JIRA_API_TOKEN=${JIRA_API_TOKEN} \
|
||||
-e PERM_SYNC_SHAREPOINT_CLIENT_ID=${PERM_SYNC_SHAREPOINT_CLIENT_ID} \
|
||||
-e PERM_SYNC_SHAREPOINT_PRIVATE_KEY="${PERM_SYNC_SHAREPOINT_PRIVATE_KEY}" \
|
||||
-e PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD=${PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD} \
|
||||
-e PERM_SYNC_SHAREPOINT_DIRECTORY_ID=${PERM_SYNC_SHAREPOINT_DIRECTORY_ID} \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
|
||||
-e MOCK_CONNECTOR_SERVER_PORT=8001 \
|
||||
onyxdotapp/onyx-integration:test \
|
||||
/app/tests/integration/${{ matrix.test-dir.path }}
|
||||
|
||||
# ------------------------------------------------------------
|
||||
# Always gather logs BEFORE "down":
|
||||
@@ -259,7 +366,7 @@ jobs:
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: docker-all-logs
|
||||
name: docker-all-logs-${{ matrix.test-dir.name }}
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
# ------------------------------------------------------------
|
||||
|
||||
@@ -268,3 +375,157 @@ jobs:
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p onyx-stack down -v
|
||||
|
||||
|
||||
multitenant-tests:
|
||||
needs:
|
||||
[
|
||||
build-backend-image,
|
||||
build-model-server-image,
|
||||
build-integration-image,
|
||||
]
|
||||
runs-on: blacksmith-8vcpu-ubuntu-2404-arm
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Login to Private Registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ${{ env.PRIVATE_REGISTRY }}
|
||||
username: ${{ env.PRIVATE_REGISTRY_USERNAME }}
|
||||
password: ${{ env.PRIVATE_REGISTRY_PASSWORD }}
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Pull Docker images
|
||||
run: |
|
||||
(docker pull ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-backend:test-${{ github.run_id }}) &
|
||||
(docker pull ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-model-server:test-${{ github.run_id }}) &
|
||||
(docker pull ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-integration:test-${{ github.run_id }}) &
|
||||
wait
|
||||
docker tag ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-backend:test-${{ github.run_id }} onyxdotapp/onyx-backend:test
|
||||
docker tag ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-model-server:test-${{ github.run_id }} onyxdotapp/onyx-model-server:test
|
||||
docker tag ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-integration:test-${{ github.run_id }} onyxdotapp/onyx-integration:test
|
||||
|
||||
- name: Start Docker containers for multi-tenant tests
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
|
||||
MULTI_TENANT=true \
|
||||
AUTH_TYPE=cloud \
|
||||
REQUIRE_EMAIL_VERIFICATION=false \
|
||||
DISABLE_TELEMETRY=true \
|
||||
IMAGE_TAG=test \
|
||||
DEV_MODE=true \
|
||||
docker compose -f docker-compose.multitenant-dev.yml -p onyx-stack up \
|
||||
relational_db \
|
||||
index \
|
||||
cache \
|
||||
minio \
|
||||
api_server \
|
||||
inference_model_server \
|
||||
indexing_model_server \
|
||||
background \
|
||||
-d
|
||||
id: start_docker_multi_tenant
|
||||
|
||||
- name: Wait for service to be ready (multi-tenant)
|
||||
run: |
|
||||
echo "Starting wait-for-service script for multi-tenant..."
|
||||
docker logs -f onyx-stack-api_server-1 &
|
||||
start_time=$(date +%s)
|
||||
timeout=300
|
||||
while true; do
|
||||
current_time=$(date +%s)
|
||||
elapsed_time=$((current_time - start_time))
|
||||
if [ $elapsed_time -ge $timeout ]; then
|
||||
echo "Timeout reached. Service did not become ready in 5 minutes."
|
||||
exit 1
|
||||
fi
|
||||
response=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:8080/health || echo "curl_error")
|
||||
if [ "$response" = "200" ]; then
|
||||
echo "Service is ready!"
|
||||
break
|
||||
elif [ "$response" = "curl_error" ]; then
|
||||
echo "Curl encountered an error; retrying..."
|
||||
else
|
||||
echo "Service not ready yet (HTTP $response). Retrying in 5 seconds..."
|
||||
fi
|
||||
sleep 5
|
||||
done
|
||||
echo "Finished waiting for service."
|
||||
|
||||
- name: Run Multi-Tenant Integration Tests
|
||||
run: |
|
||||
echo "Running multi-tenant integration tests..."
|
||||
docker run --rm --network onyx-stack_default \
|
||||
--name test-runner \
|
||||
-e POSTGRES_HOST=relational_db \
|
||||
-e POSTGRES_USER=postgres \
|
||||
-e POSTGRES_PASSWORD=password \
|
||||
-e DB_READONLY_USER=db_readonly_user \
|
||||
-e DB_READONLY_PASSWORD=password \
|
||||
-e POSTGRES_DB=postgres \
|
||||
-e POSTGRES_USE_NULL_POOL=true \
|
||||
-e VESPA_HOST=index \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
-e AUTH_TYPE=cloud \
|
||||
-e MULTI_TENANT=true \
|
||||
-e SKIP_RESET=true \
|
||||
-e REQUIRE_EMAIL_VERIFICATION=false \
|
||||
-e DISABLE_TELEMETRY=true \
|
||||
-e IMAGE_TAG=test \
|
||||
-e DEV_MODE=true \
|
||||
onyxdotapp/onyx-integration:test \
|
||||
/app/tests/integration/multitenant_tests
|
||||
|
||||
- name: Dump API server logs (multi-tenant)
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.multitenant-dev.yml -p onyx-stack logs --no-color api_server > $GITHUB_WORKSPACE/api_server_multitenant.log || true
|
||||
|
||||
- name: Dump all-container logs (multi-tenant)
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.multitenant-dev.yml -p onyx-stack logs --no-color > $GITHUB_WORKSPACE/docker-compose-multitenant.log || true
|
||||
|
||||
- name: Upload logs (multi-tenant)
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: docker-all-logs-multitenant
|
||||
path: ${{ github.workspace }}/docker-compose-multitenant.log
|
||||
|
||||
- name: Stop multi-tenant Docker containers
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.multitenant-dev.yml -p onyx-stack down -v
|
||||
|
||||
required:
|
||||
runs-on: blacksmith-2vcpu-ubuntu-2404-arm
|
||||
needs: [integration-tests, multitenant-tests]
|
||||
if: ${{ always() }}
|
||||
steps:
|
||||
- uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
const needs = ${{ toJSON(needs) }};
|
||||
const failed = Object.values(needs).some(n => n.result !== 'success');
|
||||
if (failed) {
|
||||
core.setFailed('One or more upstream jobs failed or were cancelled.');
|
||||
} else {
|
||||
core.notice('All required jobs succeeded.');
|
||||
}
|
||||
|
||||
38
.github/workflows/pr-labeler.yml
vendored
Normal file
38
.github/workflows/pr-labeler.yml
vendored
Normal file
@@ -0,0 +1,38 @@
|
||||
name: PR Labeler
|
||||
|
||||
on:
|
||||
pull_request_target:
|
||||
branches:
|
||||
- main
|
||||
types:
|
||||
- opened
|
||||
- reopened
|
||||
- synchronize
|
||||
- edited
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
|
||||
jobs:
|
||||
validate_pr_title:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Check PR title for Conventional Commits
|
||||
env:
|
||||
PR_TITLE: ${{ github.event.pull_request.title }}
|
||||
run: |
|
||||
echo "PR Title: $PR_TITLE"
|
||||
if [[ ! "$PR_TITLE" =~ ^(feat|fix|docs|test|ci|refactor|perf|chore|revert|build)(\(.+\))?:\ .+ ]]; then
|
||||
echo "::error::❌ Your PR title does not follow the Conventional Commits format.
|
||||
This check ensures that all pull requests use clear, consistent titles that help automate changelogs and improve project history.
|
||||
|
||||
Please update your PR title to follow the Conventional Commits style.
|
||||
Here is a link to a blog explaining the reason why we've included the Conventional Commits style into our PR titles: https://xfuture-blog.com/working-with-conventional-commits
|
||||
|
||||
**Here are some examples of valid PR titles:**
|
||||
- feat: add user authentication
|
||||
- fix(login): handle null password error
|
||||
- docs(readme): update installation instructions"
|
||||
exit 1
|
||||
fi
|
||||
378
.github/workflows/pr-mit-integration-tests.yml
vendored
378
.github/workflows/pr-mit-integration-tests.yml
vendored
@@ -5,90 +5,244 @@ concurrency:
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
- "release/**"
|
||||
types: [checks_requested]
|
||||
|
||||
env:
|
||||
# Private Registry Configuration
|
||||
PRIVATE_REGISTRY: experimental-registry.blacksmith.sh:5000
|
||||
PRIVATE_REGISTRY_USERNAME: ${{ secrets.PRIVATE_REGISTRY_USERNAME }}
|
||||
PRIVATE_REGISTRY_PASSWORD: ${{ secrets.PRIVATE_REGISTRY_PASSWORD }}
|
||||
|
||||
# Test Environment Variables
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
|
||||
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
|
||||
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
|
||||
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
|
||||
PLATFORM_PAIR: linux-amd64
|
||||
JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }}
|
||||
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
|
||||
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
|
||||
PERM_SYNC_SHAREPOINT_CLIENT_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_CLIENT_ID }}
|
||||
PERM_SYNC_SHAREPOINT_PRIVATE_KEY: ${{ secrets.PERM_SYNC_SHAREPOINT_PRIVATE_KEY }}
|
||||
PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD: ${{ secrets.PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD }}
|
||||
PERM_SYNC_SHAREPOINT_DIRECTORY_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_DIRECTORY_ID }}
|
||||
|
||||
jobs:
|
||||
integration-tests-mit:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on, runner=32cpu-linux-x64, "run-id=${{ github.run_id }}"]
|
||||
discover-test-dirs:
|
||||
runs-on: blacksmith-2vcpu-ubuntu-2404-arm
|
||||
outputs:
|
||||
test-dirs: ${{ steps.set-matrix.outputs.test-dirs }}
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
- name: Discover test directories
|
||||
id: set-matrix
|
||||
run: |
|
||||
# Find all leaf-level directories in both test directories
|
||||
tests_dirs=$(find backend/tests/integration/tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" -exec basename {} \; | sort)
|
||||
connector_dirs=$(find backend/tests/integration/connector_job_tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" -exec basename {} \; | sort)
|
||||
|
||||
# Create JSON array with directory info
|
||||
all_dirs=""
|
||||
for dir in $tests_dirs; do
|
||||
all_dirs="$all_dirs{\"path\":\"tests/$dir\",\"name\":\"tests-$dir\"},"
|
||||
done
|
||||
for dir in $connector_dirs; do
|
||||
all_dirs="$all_dirs{\"path\":\"connector_job_tests/$dir\",\"name\":\"connector-$dir\"},"
|
||||
done
|
||||
|
||||
# Remove trailing comma and wrap in array
|
||||
all_dirs="[${all_dirs%,}]"
|
||||
echo "test-dirs=$all_dirs" >> $GITHUB_OUTPUT
|
||||
|
||||
prepare-build:
|
||||
runs-on: blacksmith-2vcpu-ubuntu-2404-arm
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
cache: "pip"
|
||||
cache-dependency-path: |
|
||||
backend/requirements/default.txt
|
||||
backend/requirements/dev.txt
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
|
||||
|
||||
- name: Generate OpenAPI schema
|
||||
working-directory: ./backend
|
||||
env:
|
||||
PYTHONPATH: "."
|
||||
run: |
|
||||
python scripts/onyx_openapi_schema.py --filename generated/openapi.json
|
||||
|
||||
- name: Generate OpenAPI Python client
|
||||
working-directory: ./backend
|
||||
run: |
|
||||
docker run --rm \
|
||||
-v "${{ github.workspace }}/backend/generated:/local" \
|
||||
openapitools/openapi-generator-cli generate \
|
||||
-i /local/openapi.json \
|
||||
-g python \
|
||||
-o /local/onyx_openapi_client \
|
||||
--package-name onyx_openapi_client \
|
||||
--skip-validate-spec \
|
||||
--openapi-normalizer "SIMPLIFY_ONEOF_ANYOF=true,SET_OAS3_NULLABLE=true"
|
||||
|
||||
- name: Upload OpenAPI artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: openapi-artifacts
|
||||
path: backend/generated/
|
||||
|
||||
build-backend-image:
|
||||
runs-on: blacksmith-16vcpu-ubuntu-2404-arm
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Login to Private Registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ${{ env.PRIVATE_REGISTRY }}
|
||||
username: ${{ env.PRIVATE_REGISTRY_USERNAME }}
|
||||
password: ${{ env.PRIVATE_REGISTRY_PASSWORD }}
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: useblacksmith/setup-docker-builder@v1
|
||||
|
||||
- name: Build and push Backend Docker image
|
||||
uses: useblacksmith/build-push-action@v2
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
platforms: linux/arm64
|
||||
tags: ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-backend:test-${{ github.run_id }}
|
||||
push: true
|
||||
|
||||
build-model-server-image:
|
||||
runs-on: blacksmith-16vcpu-ubuntu-2404-arm
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Login to Private Registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ${{ env.PRIVATE_REGISTRY }}
|
||||
username: ${{ env.PRIVATE_REGISTRY_USERNAME }}
|
||||
password: ${{ env.PRIVATE_REGISTRY_PASSWORD }}
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: useblacksmith/setup-docker-builder@v1
|
||||
|
||||
- name: Build and push Model Server Docker image
|
||||
uses: useblacksmith/build-push-action@v2
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile.model_server
|
||||
platforms: linux/arm64
|
||||
tags: ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-model-server:test-${{ github.run_id }}
|
||||
push: true
|
||||
outputs: type=registry
|
||||
provenance: false
|
||||
|
||||
build-integration-image:
|
||||
needs: prepare-build
|
||||
runs-on: blacksmith-16vcpu-ubuntu-2404-arm
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Login to Private Registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ${{ env.PRIVATE_REGISTRY }}
|
||||
username: ${{ env.PRIVATE_REGISTRY_USERNAME }}
|
||||
password: ${{ env.PRIVATE_REGISTRY_PASSWORD }}
|
||||
|
||||
- name: Download OpenAPI artifacts
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: openapi-artifacts
|
||||
path: backend/generated/
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: useblacksmith/setup-docker-builder@v1
|
||||
|
||||
- name: Build and push integration test Docker image
|
||||
uses: useblacksmith/build-push-action@v2
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/tests/integration/Dockerfile
|
||||
platforms: linux/arm64
|
||||
tags: ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-integration:test-${{ github.run_id }}
|
||||
push: true
|
||||
|
||||
integration-tests-mit:
|
||||
needs:
|
||||
[
|
||||
discover-test-dirs,
|
||||
build-backend-image,
|
||||
build-model-server-image,
|
||||
build-integration-image,
|
||||
]
|
||||
# See https://docs.blacksmith.sh/blacksmith-runners/overview
|
||||
runs-on: blacksmith-8vcpu-ubuntu-2404-arm
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
test-dir: ${{ fromJson(needs.discover-test-dirs.outputs.test-dirs) }}
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Login to Private Registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ${{ env.PRIVATE_REGISTRY }}
|
||||
username: ${{ env.PRIVATE_REGISTRY_USERNAME }}
|
||||
password: ${{ env.PRIVATE_REGISTRY_PASSWORD }}
|
||||
|
||||
# needed for pulling Vespa, Redis, Postgres, and Minio images
|
||||
# otherwise, we hit the "Unauthenticated users" limit
|
||||
# https://docs.docker.com/docker-hub/usage/
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
# tag every docker image with "test" so that we can spin up the correct set
|
||||
# of images during testing
|
||||
|
||||
# We don't need to build the Web Docker image since it's not yet used
|
||||
# in the integration tests. We have a separate action to verify that it builds
|
||||
# successfully.
|
||||
- name: Pull Web Docker image
|
||||
- name: Pull Docker images
|
||||
run: |
|
||||
docker pull onyxdotapp/onyx-web-server:latest
|
||||
docker tag onyxdotapp/onyx-web-server:latest onyxdotapp/onyx-web-server:test
|
||||
# Pull all images from registry in parallel
|
||||
echo "Pulling Docker images in parallel..."
|
||||
# Pull images from private registry
|
||||
(docker pull ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-backend:test-${{ github.run_id }}) &
|
||||
(docker pull ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-model-server:test-${{ github.run_id }}) &
|
||||
(docker pull ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-integration:test-${{ github.run_id }}) &
|
||||
|
||||
# we use the runs-on cache for docker builds
|
||||
# in conjunction with runs-on runners, it has better speed and unlimited caching
|
||||
# https://runs-on.com/caching/s3-cache-for-github-actions/
|
||||
# https://runs-on.com/caching/docker/
|
||||
# https://github.com/moby/buildkit#s3-cache-experimental
|
||||
# Wait for all background jobs to complete
|
||||
wait
|
||||
echo "All Docker images pulled successfully"
|
||||
|
||||
# images are built and run locally for testing purposes. Not pushed.
|
||||
- name: Build Backend Docker image
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
platforms: linux/amd64
|
||||
tags: onyxdotapp/onyx-backend:test
|
||||
push: false
|
||||
load: true
|
||||
cache-from: type=s3,prefix=cache/${{ github.repository }}/mit-integration-tests/backend-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
|
||||
cache-to: type=s3,prefix=cache/${{ github.repository }}/mit-integration-tests/backend-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
|
||||
|
||||
- name: Build Model Server Docker image
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile.model_server
|
||||
platforms: linux/amd64
|
||||
tags: onyxdotapp/onyx-model-server:test
|
||||
push: false
|
||||
load: true
|
||||
cache-from: type=s3,prefix=cache/${{ github.repository }}/mit-integration-tests/model-server-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
|
||||
cache-to: type=s3,prefix=cache/${{ github.repository }}/mit-integration-tests/model-server-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
|
||||
|
||||
- name: Build integration test Docker image
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/tests/integration/Dockerfile
|
||||
platforms: linux/amd64
|
||||
tags: onyxdotapp/onyx-integration:test
|
||||
push: false
|
||||
load: true
|
||||
cache-from: type=s3,prefix=cache/${{ github.repository }}/mit-integration-tests/integration-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
|
||||
cache-to: type=s3,prefix=cache/${{ github.repository }}/mit-integration-tests/integration-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
|
||||
# Re-tag to remove registry prefix for docker-compose
|
||||
docker tag ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-backend:test-${{ github.run_id }} onyxdotapp/onyx-backend:test
|
||||
docker tag ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-model-server:test-${{ github.run_id }} onyxdotapp/onyx-model-server:test
|
||||
docker tag ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-integration:test-${{ github.run_id }} onyxdotapp/onyx-integration:test
|
||||
|
||||
# NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
|
||||
# NOTE: don't need web server for integration tests
|
||||
- name: Start Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
@@ -99,7 +253,16 @@ jobs:
|
||||
DISABLE_TELEMETRY=true \
|
||||
IMAGE_TAG=test \
|
||||
INTEGRATION_TESTS_MODE=true \
|
||||
docker compose -f docker-compose.dev.yml -p onyx-stack up -d
|
||||
docker compose -f docker-compose.dev.yml -p onyx-stack up \
|
||||
relational_db \
|
||||
index \
|
||||
cache \
|
||||
minio \
|
||||
api_server \
|
||||
inference_model_server \
|
||||
indexing_model_server \
|
||||
background \
|
||||
-d
|
||||
id: start_docker
|
||||
|
||||
- name: Wait for service to be ready
|
||||
@@ -143,42 +306,44 @@ jobs:
|
||||
-p mock-it-services-stack up -d
|
||||
|
||||
# NOTE: Use pre-ping/null to reduce flakiness due to dropped connections
|
||||
- name: Run Standard Integration Tests
|
||||
run: |
|
||||
echo "Running integration tests..."
|
||||
docker run --rm --network onyx-stack_default \
|
||||
--name test-runner \
|
||||
-e POSTGRES_HOST=relational_db \
|
||||
-e POSTGRES_USER=postgres \
|
||||
-e POSTGRES_PASSWORD=password \
|
||||
-e POSTGRES_DB=postgres \
|
||||
-e POSTGRES_POOL_PRE_PING=true \
|
||||
-e POSTGRES_USE_NULL_POOL=true \
|
||||
-e VESPA_HOST=index \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
|
||||
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
|
||||
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
|
||||
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
|
||||
-e MOCK_CONNECTOR_SERVER_PORT=8001 \
|
||||
onyxdotapp/onyx-integration:test \
|
||||
/app/tests/integration/tests \
|
||||
/app/tests/integration/connector_job_tests
|
||||
continue-on-error: true
|
||||
id: run_tests
|
||||
|
||||
- name: Check test results
|
||||
run: |
|
||||
if [ ${{ steps.run_tests.outcome }} == 'failure' ]; then
|
||||
echo "Integration tests failed. Exiting with error."
|
||||
exit 1
|
||||
else
|
||||
echo "All integration tests passed successfully."
|
||||
fi
|
||||
- name: Run Integration Tests for ${{ matrix.test-dir.name }}
|
||||
uses: nick-fields/retry@v3
|
||||
with:
|
||||
timeout_minutes: 20
|
||||
max_attempts: 3
|
||||
retry_wait_seconds: 10
|
||||
command: |
|
||||
echo "Running integration tests for ${{ matrix.test-dir.path }}..."
|
||||
docker run --rm --network onyx-stack_default \
|
||||
--name test-runner \
|
||||
-e POSTGRES_HOST=relational_db \
|
||||
-e POSTGRES_USER=postgres \
|
||||
-e POSTGRES_PASSWORD=password \
|
||||
-e POSTGRES_DB=postgres \
|
||||
-e DB_READONLY_USER=db_readonly_user \
|
||||
-e DB_READONLY_PASSWORD=password \
|
||||
-e POSTGRES_POOL_PRE_PING=true \
|
||||
-e POSTGRES_USE_NULL_POOL=true \
|
||||
-e VESPA_HOST=index \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
|
||||
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
|
||||
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
|
||||
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
|
||||
-e JIRA_BASE_URL=${JIRA_BASE_URL} \
|
||||
-e JIRA_USER_EMAIL=${JIRA_USER_EMAIL} \
|
||||
-e JIRA_API_TOKEN=${JIRA_API_TOKEN} \
|
||||
-e PERM_SYNC_SHAREPOINT_CLIENT_ID=${PERM_SYNC_SHAREPOINT_CLIENT_ID} \
|
||||
-e PERM_SYNC_SHAREPOINT_PRIVATE_KEY="${PERM_SYNC_SHAREPOINT_PRIVATE_KEY}" \
|
||||
-e PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD=${PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD} \
|
||||
-e PERM_SYNC_SHAREPOINT_DIRECTORY_ID=${PERM_SYNC_SHAREPOINT_DIRECTORY_ID} \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
|
||||
-e MOCK_CONNECTOR_SERVER_PORT=8001 \
|
||||
onyxdotapp/onyx-integration:test \
|
||||
/app/tests/integration/${{ matrix.test-dir.path }}
|
||||
|
||||
# ------------------------------------------------------------
|
||||
# Always gather logs BEFORE "down":
|
||||
@@ -198,7 +363,7 @@ jobs:
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: docker-all-logs
|
||||
name: docker-all-logs-${{ matrix.test-dir.name }}
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
# ------------------------------------------------------------
|
||||
|
||||
@@ -207,3 +372,20 @@ jobs:
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p onyx-stack down -v
|
||||
|
||||
|
||||
required:
|
||||
runs-on: blacksmith-2vcpu-ubuntu-2404-arm
|
||||
needs: [integration-tests-mit]
|
||||
if: ${{ always() }}
|
||||
steps:
|
||||
- uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
const needs = ${{ toJSON(needs) }};
|
||||
const failed = Object.values(needs).some(n => n.result !== 'success');
|
||||
if (failed) {
|
||||
core.setFailed('One or more upstream jobs failed or were cancelled.');
|
||||
} else {
|
||||
core.notice('All required jobs succeeded.');
|
||||
}
|
||||
|
||||
234
.github/workflows/pr-playwright-tests.yml
vendored
234
.github/workflows/pr-playwright-tests.yml
vendored
@@ -6,44 +6,165 @@ concurrency:
|
||||
on: push
|
||||
|
||||
env:
|
||||
# AWS ECR Configuration
|
||||
AWS_REGION: ${{ secrets.AWS_REGION || 'us-west-2' }}
|
||||
ECR_REGISTRY: ${{ secrets.ECR_REGISTRY }}
|
||||
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID_ECR }}
|
||||
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY_ECR }}
|
||||
BUILDX_NO_DEFAULT_ATTESTATIONS: 1
|
||||
|
||||
# Test Environment Variables
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
|
||||
GEN_AI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
EXA_API_KEY: ${{ secrets.EXA_API_KEY }}
|
||||
|
||||
# for federated slack tests
|
||||
SLACK_CLIENT_ID: ${{ secrets.SLACK_CLIENT_ID }}
|
||||
SLACK_CLIENT_SECRET: ${{ secrets.SLACK_CLIENT_SECRET }}
|
||||
|
||||
MOCK_LLM_RESPONSE: true
|
||||
PYTEST_PLAYWRIGHT_SKIP_INITIAL_RESET: true
|
||||
|
||||
jobs:
|
||||
playwright-tests:
|
||||
name: Playwright Tests
|
||||
build-web-image:
|
||||
runs-on: blacksmith-8vcpu-ubuntu-2404-arm
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=32cpu-linux-x64,
|
||||
disk=large,
|
||||
"run-id=${{ github.run_id }}",
|
||||
]
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@v4
|
||||
with:
|
||||
aws-access-key-id: ${{ env.AWS_ACCESS_KEY_ID }}
|
||||
aws-secret-access-key: ${{ env.AWS_SECRET_ACCESS_KEY }}
|
||||
aws-region: ${{ env.AWS_REGION }}
|
||||
|
||||
- name: Login to Amazon ECR
|
||||
id: login-ecr
|
||||
uses: aws-actions/amazon-ecr-login@v2
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: useblacksmith/setup-docker-builder@v1
|
||||
|
||||
- name: Build and push Web Docker image
|
||||
uses: useblacksmith/build-push-action@v2
|
||||
with:
|
||||
context: ./web
|
||||
file: ./web/Dockerfile
|
||||
platforms: linux/arm64
|
||||
tags: ${{ env.ECR_REGISTRY }}/integration-test-onyx-web-server:playwright-test-${{ github.run_id }}
|
||||
provenance: false
|
||||
sbom: false
|
||||
push: true
|
||||
|
||||
build-backend-image:
|
||||
runs-on: blacksmith-8vcpu-ubuntu-2404-arm
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@v4
|
||||
with:
|
||||
aws-access-key-id: ${{ env.AWS_ACCESS_KEY_ID }}
|
||||
aws-secret-access-key: ${{ env.AWS_SECRET_ACCESS_KEY }}
|
||||
aws-region: ${{ env.AWS_REGION }}
|
||||
|
||||
- name: Login to Amazon ECR
|
||||
id: login-ecr
|
||||
uses: aws-actions/amazon-ecr-login@v2
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: useblacksmith/setup-docker-builder@v1
|
||||
|
||||
- name: Build and push Backend Docker image
|
||||
uses: useblacksmith/build-push-action@v2
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
platforms: linux/arm64
|
||||
tags: ${{ env.ECR_REGISTRY }}/integration-test-onyx-backend:playwright-test-${{ github.run_id }}
|
||||
provenance: false
|
||||
sbom: false
|
||||
push: true
|
||||
|
||||
build-model-server-image:
|
||||
runs-on: blacksmith-8vcpu-ubuntu-2404-arm
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@v4
|
||||
with:
|
||||
aws-access-key-id: ${{ env.AWS_ACCESS_KEY_ID }}
|
||||
aws-secret-access-key: ${{ env.AWS_SECRET_ACCESS_KEY }}
|
||||
aws-region: ${{ env.AWS_REGION }}
|
||||
|
||||
- name: Login to Amazon ECR
|
||||
id: login-ecr
|
||||
uses: aws-actions/amazon-ecr-login@v2
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: useblacksmith/setup-docker-builder@v1
|
||||
|
||||
- name: Build and push Model Server Docker image
|
||||
uses: useblacksmith/build-push-action@v2
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile.model_server
|
||||
platforms: linux/arm64
|
||||
tags: ${{ env.ECR_REGISTRY }}/integration-test-onyx-model-server:playwright-test-${{ github.run_id }}
|
||||
provenance: false
|
||||
sbom: false
|
||||
push: true
|
||||
|
||||
playwright-tests:
|
||||
needs: [build-web-image, build-backend-image, build-model-server-image]
|
||||
name: Playwright Tests
|
||||
runs-on: blacksmith-8vcpu-ubuntu-2404-arm
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@v4
|
||||
with:
|
||||
python-version: "3.11"
|
||||
cache: "pip"
|
||||
cache-dependency-path: |
|
||||
backend/requirements/default.txt
|
||||
backend/requirements/dev.txt
|
||||
backend/requirements/model_server.txt
|
||||
- run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
|
||||
aws-access-key-id: ${{ env.AWS_ACCESS_KEY_ID }}
|
||||
aws-secret-access-key: ${{ env.AWS_SECRET_ACCESS_KEY }}
|
||||
aws-region: ${{ env.AWS_REGION }}
|
||||
|
||||
- name: Login to Amazon ECR
|
||||
id: login-ecr
|
||||
uses: aws-actions/amazon-ecr-login@v2
|
||||
|
||||
# needed for pulling Vespa, Redis, Postgres, and Minio images
|
||||
# otherwise, we hit the "Unauthenticated users" limit
|
||||
# https://docs.docker.com/docker-hub/usage/
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Pull Docker images
|
||||
run: |
|
||||
# Pull all images from ECR in parallel
|
||||
echo "Pulling Docker images in parallel..."
|
||||
(docker pull ${{ env.ECR_REGISTRY }}/integration-test-onyx-web-server:playwright-test-${{ github.run_id }}) &
|
||||
(docker pull ${{ env.ECR_REGISTRY }}/integration-test-onyx-backend:playwright-test-${{ github.run_id }}) &
|
||||
(docker pull ${{ env.ECR_REGISTRY }}/integration-test-onyx-model-server:playwright-test-${{ github.run_id }}) &
|
||||
|
||||
# Wait for all background jobs to complete
|
||||
wait
|
||||
echo "All Docker images pulled successfully"
|
||||
|
||||
# Re-tag with expected names for docker-compose
|
||||
docker tag ${{ env.ECR_REGISTRY }}/integration-test-onyx-web-server:playwright-test-${{ github.run_id }} onyxdotapp/onyx-web-server:test
|
||||
docker tag ${{ env.ECR_REGISTRY }}/integration-test-onyx-backend:playwright-test-${{ github.run_id }} onyxdotapp/onyx-backend:test
|
||||
docker tag ${{ env.ECR_REGISTRY }}/integration-test-onyx-model-server:playwright-test-${{ github.run_id }} onyxdotapp/onyx-model-server:test
|
||||
|
||||
- name: Setup node
|
||||
uses: actions/setup-node@v4
|
||||
@@ -58,68 +179,13 @@ jobs:
|
||||
working-directory: ./web
|
||||
run: npx playwright install --with-deps
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
# tag every docker image with "test" so that we can spin up the correct set
|
||||
# of images during testing
|
||||
|
||||
# we use the runs-on cache for docker builds
|
||||
# in conjunction with runs-on runners, it has better speed and unlimited caching
|
||||
# https://runs-on.com/caching/s3-cache-for-github-actions/
|
||||
# https://runs-on.com/caching/docker/
|
||||
# https://github.com/moby/buildkit#s3-cache-experimental
|
||||
|
||||
# images are built and run locally for testing purposes. Not pushed.
|
||||
|
||||
- name: Build Web Docker image
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
context: ./web
|
||||
file: ./web/Dockerfile
|
||||
platforms: linux/amd64
|
||||
tags: onyxdotapp/onyx-web-server:test
|
||||
push: false
|
||||
load: true
|
||||
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/web-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
|
||||
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/web-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
|
||||
|
||||
- name: Build Backend Docker image
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
platforms: linux/amd64
|
||||
tags: onyxdotapp/onyx-backend:test
|
||||
push: false
|
||||
load: true
|
||||
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
|
||||
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
|
||||
|
||||
- name: Build Model Server Docker image
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile.model_server
|
||||
platforms: linux/amd64
|
||||
tags: onyxdotapp/onyx-model-server:test
|
||||
push: false
|
||||
load: true
|
||||
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
|
||||
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
|
||||
|
||||
- name: Start Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
|
||||
AUTH_TYPE=basic \
|
||||
GEN_AI_API_KEY=${{ secrets.OPENAI_API_KEY }} \
|
||||
GEN_AI_API_KEY=${{ env.OPENAI_API_KEY }} \
|
||||
EXA_API_KEY=${{ env.EXA_API_KEY }} \
|
||||
REQUIRE_EMAIL_VERIFICATION=false \
|
||||
DISABLE_TELEMETRY=true \
|
||||
IMAGE_TAG=test \
|
||||
@@ -160,12 +226,6 @@ jobs:
|
||||
done
|
||||
echo "Finished waiting for service."
|
||||
|
||||
- name: Run pytest playwright test init
|
||||
working-directory: ./backend
|
||||
env:
|
||||
PYTEST_IGNORE_SKIP: true
|
||||
run: pytest -s tests/integration/tests/playwright/test_playwright.py
|
||||
|
||||
- name: Run Playwright tests
|
||||
working-directory: ./web
|
||||
run: npx playwright test
|
||||
|
||||
25
.github/workflows/pr-python-checks.yml
vendored
25
.github/workflows/pr-python-checks.yml
vendored
@@ -31,16 +31,31 @@ jobs:
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
|
||||
|
||||
- name: Generate OpenAPI schema
|
||||
working-directory: ./backend
|
||||
env:
|
||||
PYTHONPATH: "."
|
||||
run: |
|
||||
python scripts/onyx_openapi_schema.py --filename generated/openapi.json
|
||||
|
||||
- name: Generate OpenAPI Python client
|
||||
working-directory: ./backend
|
||||
run: |
|
||||
docker run --rm \
|
||||
-v "${{ github.workspace }}/backend/generated:/local" \
|
||||
openapitools/openapi-generator-cli generate \
|
||||
-i /local/openapi.json \
|
||||
-g python \
|
||||
-o /local/onyx_openapi_client \
|
||||
--package-name onyx_openapi_client \
|
||||
--skip-validate-spec \
|
||||
--openapi-normalizer "SIMPLIFY_ONEOF_ANYOF=true,SET_OAS3_NULLABLE=true"
|
||||
|
||||
- name: Run MyPy
|
||||
run: |
|
||||
cd backend
|
||||
mypy .
|
||||
|
||||
- name: Run ruff
|
||||
run: |
|
||||
cd backend
|
||||
ruff .
|
||||
|
||||
- name: Check import order with reorder-python-imports
|
||||
run: |
|
||||
cd backend
|
||||
|
||||
17
.github/workflows/pr-python-connector-tests.yml
vendored
17
.github/workflows/pr-python-connector-tests.yml
vendored
@@ -16,12 +16,13 @@ env:
|
||||
# Confluence
|
||||
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
|
||||
CONFLUENCE_TEST_SPACE: ${{ secrets.CONFLUENCE_TEST_SPACE }}
|
||||
CONFLUENCE_IS_CLOUD: ${{ secrets.CONFLUENCE_IS_CLOUD }}
|
||||
CONFLUENCE_TEST_PAGE_ID: ${{ secrets.CONFLUENCE_TEST_PAGE_ID }}
|
||||
CONFLUENCE_IS_CLOUD: ${{ secrets.CONFLUENCE_IS_CLOUD }}
|
||||
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
|
||||
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
|
||||
|
||||
# Jira
|
||||
JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }}
|
||||
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
|
||||
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
|
||||
|
||||
@@ -49,6 +50,15 @@ env:
|
||||
SF_PASSWORD: ${{ secrets.SF_PASSWORD }}
|
||||
SF_SECURITY_TOKEN: ${{ secrets.SF_SECURITY_TOKEN }}
|
||||
|
||||
# Hubspot
|
||||
HUBSPOT_ACCESS_TOKEN: ${{ secrets.HUBSPOT_ACCESS_TOKEN }}
|
||||
|
||||
# IMAP
|
||||
IMAP_HOST: ${{ secrets.IMAP_HOST }}
|
||||
IMAP_USERNAME: ${{ secrets.IMAP_USERNAME }}
|
||||
IMAP_PASSWORD: ${{ secrets.IMAP_PASSWORD }}
|
||||
IMAP_MAILBOXES: ${{ secrets.IMAP_MAILBOXES }}
|
||||
|
||||
# Airtable
|
||||
AIRTABLE_TEST_BASE_ID: ${{ secrets.AIRTABLE_TEST_BASE_ID }}
|
||||
AIRTABLE_TEST_TABLE_ID: ${{ secrets.AIRTABLE_TEST_TABLE_ID }}
|
||||
@@ -81,6 +91,11 @@ env:
|
||||
# Slack
|
||||
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
|
||||
|
||||
# Teams
|
||||
TEAMS_APPLICATION_ID: ${{ secrets.TEAMS_APPLICATION_ID }}
|
||||
TEAMS_DIRECTORY_ID: ${{ secrets.TEAMS_DIRECTORY_ID }}
|
||||
TEAMS_SECRET: ${{ secrets.TEAMS_SECRET }}
|
||||
|
||||
jobs:
|
||||
connectors-check:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
|
||||
3
.github/workflows/pr-python-tests.yml
vendored
3
.github/workflows/pr-python-tests.yml
vendored
@@ -15,6 +15,9 @@ jobs:
|
||||
env:
|
||||
PYTHONPATH: ./backend
|
||||
REDIS_CLOUD_PYTEST_PASSWORD: ${{ secrets.REDIS_CLOUD_PYTEST_PASSWORD }}
|
||||
SF_USERNAME: ${{ secrets.SF_USERNAME }}
|
||||
SF_PASSWORD: ${{ secrets.SF_PASSWORD }}
|
||||
SF_SECURITY_TOKEN: ${{ secrets.SF_SECURITY_TOKEN }}
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
|
||||
15
.gitignore
vendored
15
.gitignore
vendored
@@ -14,12 +14,27 @@
|
||||
/web/test-results/
|
||||
backend/onyx/agent_search/main/test_data.json
|
||||
backend/tests/regression/answer_quality/test_data.json
|
||||
backend/tests/regression/search_quality/eval-*
|
||||
backend/tests/regression/search_quality/search_eval_config.yaml
|
||||
backend/tests/regression/search_quality/*.json
|
||||
*.log
|
||||
|
||||
# secret files
|
||||
.env
|
||||
jira_test_env
|
||||
settings.json
|
||||
|
||||
# others
|
||||
/deployment/data/nginx/app.conf
|
||||
*.sw?
|
||||
/backend/tests/regression/answer_quality/search_test_config.yaml
|
||||
|
||||
# Local .terraform directories
|
||||
**/.terraform/*
|
||||
|
||||
# Local .tfstate files
|
||||
*.tfstate
|
||||
*.tfstate.*
|
||||
|
||||
# Local .terraform.lock.hcl file
|
||||
.terraform.lock.hcl
|
||||
|
||||
22
.vscode/env_template.txt
vendored
22
.vscode/env_template.txt
vendored
@@ -23,6 +23,9 @@ DISABLE_LLM_DOC_RELEVANCE=False
|
||||
# Useful if you want to toggle auth on/off (google_oauth/OIDC specifically)
|
||||
OAUTH_CLIENT_ID=<REPLACE THIS>
|
||||
OAUTH_CLIENT_SECRET=<REPLACE THIS>
|
||||
OPENID_CONFIG_URL=<REPLACE THIS>
|
||||
SAML_CONF_DIR=/<ABSOLUTE PATH TO ONYX>/onyx/backend/ee/onyx/configs/saml_config
|
||||
|
||||
# Generally not useful for dev, we don't generally want to set up an SMTP server for dev
|
||||
REQUIRE_EMAIL_VERIFICATION=False
|
||||
|
||||
@@ -45,8 +48,8 @@ PYTHONPATH=../backend
|
||||
PYTHONUNBUFFERED=1
|
||||
|
||||
|
||||
# Internet Search
|
||||
BING_API_KEY=<REPLACE THIS>
|
||||
# Internet Search
|
||||
EXA_API_KEY=<REPLACE THIS>
|
||||
|
||||
|
||||
# Enable the full set of Danswer Enterprise Edition features
|
||||
@@ -58,3 +61,18 @@ AGENT_RETRIEVAL_STATS=False # Note: This setting will incur substantial re-ran
|
||||
AGENT_RERANKING_STATS=True
|
||||
AGENT_MAX_QUERY_RETRIEVAL_RESULTS=20
|
||||
AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS=20
|
||||
|
||||
# S3 File Store Configuration (MinIO for local development)
|
||||
S3_ENDPOINT_URL=http://localhost:9004
|
||||
S3_FILE_STORE_BUCKET_NAME=onyx-file-store-bucket
|
||||
S3_AWS_ACCESS_KEY_ID=minioadmin
|
||||
S3_AWS_SECRET_ACCESS_KEY=minioadmin
|
||||
|
||||
# Show extra/uncommon connectors
|
||||
SHOW_EXTRA_CONNECTORS=True
|
||||
|
||||
# Local langsmith tracing
|
||||
LANGSMITH_TRACING="true"
|
||||
LANGSMITH_ENDPOINT="https://api.smith.langchain.com"
|
||||
LANGSMITH_API_KEY=<REPLACE_THIS>
|
||||
LANGSMITH_PROJECT=<REPLACE_THIS>
|
||||
155
.vscode/launch.template.jsonc
vendored
155
.vscode/launch.template.jsonc
vendored
@@ -24,21 +24,23 @@
|
||||
"Celery primary",
|
||||
"Celery light",
|
||||
"Celery heavy",
|
||||
"Celery indexing",
|
||||
"Celery user files indexing",
|
||||
"Celery docfetching",
|
||||
"Celery docprocessing",
|
||||
"Celery beat",
|
||||
"Celery monitoring"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "1"
|
||||
}
|
||||
},
|
||||
"stopAll": true
|
||||
},
|
||||
{
|
||||
"name": "Web / Model / API",
|
||||
"configurations": ["Web Server", "Model Server", "API Server"],
|
||||
"presentation": {
|
||||
"group": "1"
|
||||
}
|
||||
},
|
||||
"stopAll": true
|
||||
},
|
||||
{
|
||||
"name": "Celery (all)",
|
||||
@@ -46,14 +48,15 @@
|
||||
"Celery primary",
|
||||
"Celery light",
|
||||
"Celery heavy",
|
||||
"Celery indexing",
|
||||
"Celery user files indexing",
|
||||
"Celery docfetching",
|
||||
"Celery docprocessing",
|
||||
"Celery beat",
|
||||
"Celery monitoring"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "1"
|
||||
}
|
||||
},
|
||||
"stopAll": true
|
||||
}
|
||||
],
|
||||
"configurations": [
|
||||
@@ -189,7 +192,7 @@
|
||||
"--loglevel=INFO",
|
||||
"--hostname=light@%n",
|
||||
"-Q",
|
||||
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert"
|
||||
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert,index_attempt_cleanup"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
@@ -226,35 +229,66 @@
|
||||
"consoleTitle": "Celery heavy Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery indexing",
|
||||
"name": "Celery docfetching",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"ENABLE_MULTIPASS_INDEXING": "false",
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.indexing",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=1",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=indexing@%n",
|
||||
"-Q",
|
||||
"connector_indexing"
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.docfetching",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=1",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=docfetching@%n",
|
||||
"-Q",
|
||||
"connector_doc_fetching,user_files_indexing"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery indexing Console"
|
||||
},
|
||||
"consoleTitle": "Celery docfetching Console",
|
||||
"justMyCode": false
|
||||
},
|
||||
{
|
||||
"name": "Celery docprocessing",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"ENABLE_MULTIPASS_INDEXING": "false",
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.docprocessing",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=6",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=docprocessing@%n",
|
||||
"-Q",
|
||||
"docprocessing"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery docprocessing Console",
|
||||
"justMyCode": false
|
||||
},
|
||||
{
|
||||
"name": "Celery monitoring",
|
||||
"type": "debugpy",
|
||||
@@ -303,35 +337,6 @@
|
||||
},
|
||||
"consoleTitle": "Celery beat Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery user files indexing",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.indexing",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=1",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=user_files_indexing@%n",
|
||||
"-Q",
|
||||
"user_files_indexing"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery user files indexing Console"
|
||||
},
|
||||
{
|
||||
"name": "Pytest",
|
||||
"consoleName": "Pytest",
|
||||
@@ -412,6 +417,46 @@
|
||||
"group": "3"
|
||||
}
|
||||
},
|
||||
{
|
||||
// script to generate the openapi schema
|
||||
"name": "Onyx OpenAPI Schema Generator",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "scripts/onyx_openapi_schema.py",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.env",
|
||||
"env": {
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"--filename",
|
||||
"generated/openapi.json"
|
||||
]
|
||||
},
|
||||
{
|
||||
// script to debug multi tenant db issues
|
||||
"name": "Onyx DB Manager (Top Chunks)",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "scripts/debugging/onyx_db.py",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.env",
|
||||
"env": {
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"--password",
|
||||
"your_password_here",
|
||||
"--port",
|
||||
"5433",
|
||||
"--report",
|
||||
"top-chunks",
|
||||
"--filename",
|
||||
"generated/tenants_by_num_docs.csv"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Debug React Web App in Chrome",
|
||||
"type": "chrome",
|
||||
|
||||
101
.vscode/tasks.template.jsonc
vendored
Normal file
101
.vscode/tasks.template.jsonc
vendored
Normal file
@@ -0,0 +1,101 @@
|
||||
{
|
||||
"version": "2.0.0",
|
||||
"tasks": [
|
||||
{
|
||||
"type": "austin",
|
||||
"label": "Profile celery beat",
|
||||
"envFile": "${workspaceFolder}/.env",
|
||||
"options": {
|
||||
"cwd": "${workspaceFolder}/backend"
|
||||
},
|
||||
"command": [
|
||||
"sudo",
|
||||
"-E"
|
||||
],
|
||||
"args": [
|
||||
"celery",
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.beat",
|
||||
"beat",
|
||||
"--loglevel=INFO"
|
||||
]
|
||||
},
|
||||
{
|
||||
"type": "shell",
|
||||
"label": "Generate Onyx OpenAPI Python client",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.env",
|
||||
"options": {
|
||||
"cwd": "${workspaceFolder}/backend"
|
||||
},
|
||||
"command": [
|
||||
"openapi-generator"
|
||||
],
|
||||
"args": [
|
||||
"generate",
|
||||
"-i",
|
||||
"generated/openapi.json",
|
||||
"-g",
|
||||
"python",
|
||||
"-o",
|
||||
"generated/onyx_openapi_client",
|
||||
"--package-name",
|
||||
"onyx_openapi_client",
|
||||
]
|
||||
},
|
||||
{
|
||||
"type": "shell",
|
||||
"label": "Generate Typescript Fetch client (openapi-generator)",
|
||||
"envFile": "${workspaceFolder}/.env",
|
||||
"options": {
|
||||
"cwd": "${workspaceFolder}"
|
||||
},
|
||||
"command": [
|
||||
"openapi-generator"
|
||||
],
|
||||
"args": [
|
||||
"generate",
|
||||
"-i",
|
||||
"backend/generated/openapi.json",
|
||||
"-g",
|
||||
"typescript-fetch",
|
||||
"-o",
|
||||
"${workspaceFolder}/web/src/lib/generated/onyx_api",
|
||||
"--additional-properties=disallowAdditionalPropertiesIfNotPresent=false,legacyDiscriminatorBehavior=false,supportsES6=true",
|
||||
]
|
||||
},
|
||||
{
|
||||
"type": "shell",
|
||||
"label": "Generate TypeScript Client (openapi-ts)",
|
||||
"envFile": "${workspaceFolder}/.env",
|
||||
"options": {
|
||||
"cwd": "${workspaceFolder}/web"
|
||||
},
|
||||
"command": [
|
||||
"npx"
|
||||
],
|
||||
"args": [
|
||||
"openapi-typescript",
|
||||
"../backend/generated/openapi.json",
|
||||
"--output",
|
||||
"./src/lib/generated/onyx-schema.ts",
|
||||
]
|
||||
},
|
||||
{
|
||||
"type": "shell",
|
||||
"label": "Generate TypeScript Client (orval)",
|
||||
"envFile": "${workspaceFolder}/.env",
|
||||
"options": {
|
||||
"cwd": "${workspaceFolder}/web"
|
||||
},
|
||||
"command": [
|
||||
"npx"
|
||||
],
|
||||
"args": [
|
||||
"orval",
|
||||
"--config",
|
||||
"orval.config.js",
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
295
AGENTS.md
Normal file
295
AGENTS.md
Normal file
@@ -0,0 +1,295 @@
|
||||
# AGENTS.md
|
||||
|
||||
This file provides guidance to Codex when working with code in this repository.
|
||||
|
||||
## KEY NOTES
|
||||
|
||||
- If you run into any missing python dependency errors, try running your command with `workon onyx &&` in front
|
||||
to assume the python venv.
|
||||
- To make tests work, check the `.env` file at the root of the project to find an OpenAI key.
|
||||
- If using `playwright` to explore the frontend, you can usually log in with username `a@test.com` and password
|
||||
`a`. The app can be accessed at `http://localhost:3000`.
|
||||
- You should assume that all Onyx services are running. To verify, you can check the `backend/log` directory to
|
||||
make sure we see logs coming out from the relevant service.
|
||||
- To connect to the Postgres database, use: `docker exec -it onyx-stack-relational_db-1 psql -U postgres -c "<SQL>"`
|
||||
- When making calls to the backend, always go through the frontend. E.g. make a call to `http://localhost:3000/api/persona` not `http://localhost:8080/api/persona`
|
||||
- Put ALL db operations under the `backend/onyx/db` / `backend/ee/onyx/db` directories. Don't run queries
|
||||
outside of those directories.
|
||||
|
||||
## Project Overview
|
||||
|
||||
**Onyx** (formerly Danswer) is an open-source Gen-AI and Enterprise Search platform that connects to company documents, apps, and people. It features a modular architecture with both Community Edition (MIT licensed) and Enterprise Edition offerings.
|
||||
|
||||
|
||||
### Background Workers (Celery)
|
||||
|
||||
Onyx uses Celery for asynchronous task processing with multiple specialized workers:
|
||||
|
||||
#### Worker Types
|
||||
|
||||
1. **Primary Worker** (`celery_app.py`)
|
||||
- Coordinates core background tasks and system-wide operations
|
||||
- Handles connector management, document sync, pruning, and periodic checks
|
||||
- Runs with 4 threads concurrency
|
||||
- Tasks: connector deletion, vespa sync, pruning, LLM model updates, user file sync
|
||||
|
||||
2. **Docfetching Worker** (`docfetching`)
|
||||
- Fetches documents from external data sources (connectors)
|
||||
- Spawns docprocessing tasks for each document batch
|
||||
- Implements watchdog monitoring for stuck connectors
|
||||
- Configurable concurrency (default from env)
|
||||
|
||||
3. **Docprocessing Worker** (`docprocessing`)
|
||||
- Processes fetched documents through the indexing pipeline:
|
||||
- Upserts documents to PostgreSQL
|
||||
- Chunks documents and adds contextual information
|
||||
- Embeds chunks via model server
|
||||
- Writes chunks to Vespa vector database
|
||||
- Updates document metadata
|
||||
- Configurable concurrency (default from env)
|
||||
|
||||
4. **Light Worker** (`light`)
|
||||
- Handles lightweight, fast operations
|
||||
- Tasks: vespa operations, document permissions sync, external group sync
|
||||
- Higher concurrency for quick tasks
|
||||
|
||||
5. **Heavy Worker** (`heavy`)
|
||||
- Handles resource-intensive operations
|
||||
- Primary task: document pruning operations
|
||||
- Runs with 4 threads concurrency
|
||||
|
||||
6. **KG Processing Worker** (`kg_processing`)
|
||||
- Handles Knowledge Graph processing and clustering
|
||||
- Builds relationships between documents
|
||||
- Runs clustering algorithms
|
||||
- Configurable concurrency
|
||||
|
||||
7. **Monitoring Worker** (`monitoring`)
|
||||
- System health monitoring and metrics collection
|
||||
- Monitors Celery queues, process memory, and system status
|
||||
- Single thread (monitoring doesn't need parallelism)
|
||||
- Cloud-specific monitoring tasks
|
||||
|
||||
8. **Beat Worker** (`beat`)
|
||||
- Celery's scheduler for periodic tasks
|
||||
- Uses DynamicTenantScheduler for multi-tenant support
|
||||
- Schedules tasks like:
|
||||
- Indexing checks (every 15 seconds)
|
||||
- Connector deletion checks (every 20 seconds)
|
||||
- Vespa sync checks (every 20 seconds)
|
||||
- Pruning checks (every 20 seconds)
|
||||
- KG processing (every 60 seconds)
|
||||
- Monitoring tasks (every 5 minutes)
|
||||
- Cleanup tasks (hourly)
|
||||
|
||||
#### Key Features
|
||||
|
||||
- **Thread-based Workers**: All workers use thread pools (not processes) for stability
|
||||
- **Tenant Awareness**: Multi-tenant support with per-tenant task isolation. There is a
|
||||
middleware layer that automatically finds the appropriate tenant ID when sending tasks
|
||||
via Celery Beat.
|
||||
- **Task Prioritization**: High, Medium, Low priority queues
|
||||
- **Monitoring**: Built-in heartbeat and liveness checking
|
||||
- **Failure Handling**: Automatic retry and failure recovery mechanisms
|
||||
- **Redis Coordination**: Inter-process communication via Redis
|
||||
- **PostgreSQL State**: Task state and metadata stored in PostgreSQL
|
||||
|
||||
|
||||
#### Important Notes
|
||||
|
||||
**Defining Tasks**:
|
||||
- Always use `@shared_task` rather than `@celery_app`
|
||||
- Put tasks under `background/celery/tasks/` or `ee/background/celery/tasks`
|
||||
|
||||
**Defining APIs**:
|
||||
When creating new FastAPI APIs, do NOT use the `response_model` field. Instead, just type the
|
||||
function.
|
||||
|
||||
**Testing Updates**:
|
||||
If you make any updates to a celery worker and you want to test these changes, you will need
|
||||
to ask me to restart the celery worker. There is no auto-restart on code-change mechanism.
|
||||
|
||||
### Code Quality
|
||||
```bash
|
||||
# Install and run pre-commit hooks
|
||||
pre-commit install
|
||||
pre-commit run --all-files
|
||||
```
|
||||
|
||||
NOTE: Always make sure everything is strictly typed (both in Python and Typescript).
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
### Technology Stack
|
||||
- **Backend**: Python 3.11, FastAPI, SQLAlchemy, Alembic, Celery
|
||||
- **Frontend**: Next.js 15+, React 18, TypeScript, Tailwind CSS
|
||||
- **Database**: PostgreSQL with Redis caching
|
||||
- **Search**: Vespa vector database
|
||||
- **Auth**: OAuth2, SAML, multi-provider support
|
||||
- **AI/ML**: LangChain, LiteLLM, multiple embedding models
|
||||
|
||||
### Directory Structure
|
||||
|
||||
```
|
||||
backend/
|
||||
├── onyx/
|
||||
│ ├── auth/ # Authentication & authorization
|
||||
│ ├── chat/ # Chat functionality & LLM interactions
|
||||
│ ├── connectors/ # Data source connectors
|
||||
│ ├── db/ # Database models & operations
|
||||
│ ├── document_index/ # Vespa integration
|
||||
│ ├── federated_connectors/ # External search connectors
|
||||
│ ├── llm/ # LLM provider integrations
|
||||
│ └── server/ # API endpoints & routers
|
||||
├── ee/ # Enterprise Edition features
|
||||
├── alembic/ # Database migrations
|
||||
└── tests/ # Test suites
|
||||
|
||||
web/
|
||||
├── src/app/ # Next.js app router pages
|
||||
├── src/components/ # Reusable React components
|
||||
└── src/lib/ # Utilities & business logic
|
||||
```
|
||||
|
||||
## Database & Migrations
|
||||
|
||||
### Running Migrations
|
||||
```bash
|
||||
# Standard migrations
|
||||
alembic upgrade head
|
||||
|
||||
# Multi-tenant (Enterprise)
|
||||
alembic -n schema_private upgrade head
|
||||
```
|
||||
|
||||
### Creating Migrations
|
||||
```bash
|
||||
# Auto-generate migration
|
||||
alembic revision --autogenerate -m "description"
|
||||
|
||||
# Multi-tenant migration
|
||||
alembic -n schema_private revision --autogenerate -m "description"
|
||||
```
|
||||
|
||||
## Testing Strategy
|
||||
|
||||
There are 4 main types of tests within Onyx:
|
||||
|
||||
### Unit Tests
|
||||
These should not assume any Onyx/external services are available to be called.
|
||||
Interactions with the outside world should be mocked using `unittest.mock`. Generally, only
|
||||
write these for complex, isolated modules e.g. `citation_processing.py`.
|
||||
|
||||
To run them:
|
||||
|
||||
```bash
|
||||
python -m dotenv -f .vscode/.env run -- pytest -xv backend/tests/unit
|
||||
```
|
||||
|
||||
### External Dependency Unit Tests
|
||||
These tests assume that all external dependencies of Onyx are available and callable (e.g. Postgres, Redis,
|
||||
MinIO/S3, Vespa are running + OpenAI can be called + any request to the internet is fine + etc.).
|
||||
|
||||
However, the actual Onyx containers are not running and with these tests we call the function to test directly.
|
||||
We can also mock components/calls at will.
|
||||
|
||||
The goal with these tests are to minimize mocking while giving some flexibility to mock things that are flakey,
|
||||
need strictly controlled behavior, or need to have their internal behavior validated (e.g. verify a function is called
|
||||
with certain args, something that would be impossible with proper integration tests).
|
||||
|
||||
A great example of this type of test is `backend/tests/external_dependency_unit/connectors/confluence/test_confluence_group_sync.py`.
|
||||
|
||||
To run them:
|
||||
|
||||
```bash
|
||||
python -m dotenv -f .vscode/.env run -- pytest backend/tests/external_dependency_unit
|
||||
```
|
||||
|
||||
### Integration Tests
|
||||
Standard integration tests. Every test in `backend/tests/integration` runs against a real Onyx deployment. We cannot
|
||||
mock anything in these tests. Prefer writing integration tests (or External Dependency Unit Tests if mocking/internal
|
||||
verification is necessary) over any other type of test.
|
||||
|
||||
Tests are parallelized at a directory level.
|
||||
|
||||
When writing integration tests, make sure to check the root `conftest.py` for useful fixtures + the `backend/tests/integration/common_utils` directory for utilities. Prefer (if one exists), calling the appropriate Manager
|
||||
class in the utils over directly calling the APIs with a library like `requests`. Prefer using fixtures rather than
|
||||
calling the utilities directly (e.g. do NOT create admin users with
|
||||
`admin_user = UserManager.create(name="admin_user")`, instead use the `admin_user` fixture).
|
||||
|
||||
A great example of this type of test is `backend/tests/integration/dev_apis/test_simple_chat_api.py`.
|
||||
|
||||
To run them:
|
||||
|
||||
```bash
|
||||
python -m dotenv -f .vscode/.env run -- pytest backend/tests/integration
|
||||
```
|
||||
|
||||
### Playwright (E2E) Tests
|
||||
These tests are an even more complete version of the Integration Tests mentioned above. Has all services of Onyx
|
||||
running, *including* the Web Server.
|
||||
|
||||
Use these tests for anything that requires significant frontend <-> backend coordination.
|
||||
|
||||
Tests are located at `web/tests/e2e`. Tests are written in TypeScript.
|
||||
|
||||
To run them:
|
||||
|
||||
```bash
|
||||
npx playwright test <TEST_NAME>
|
||||
```
|
||||
|
||||
|
||||
## Logs
|
||||
|
||||
When (1) writing integration tests or (2) doing live tests (e.g. curl / playwright) you can get access
|
||||
to logs via the `backend/log/<service_name>_debug.log` file. All Onyx services (api_server, web_server, celery_X)
|
||||
will be tailing their logs to this file.
|
||||
|
||||
|
||||
## Security Considerations
|
||||
|
||||
- Never commit API keys or secrets to repository
|
||||
- Use encrypted credential storage for connector credentials
|
||||
- Follow RBAC patterns for new features
|
||||
- Implement proper input validation with Pydantic models
|
||||
- Use parameterized queries to prevent SQL injection
|
||||
|
||||
## AI/LLM Integration
|
||||
|
||||
- Multiple LLM providers supported via LiteLLM
|
||||
- Configurable models per feature (chat, search, embeddings)
|
||||
- Streaming support for real-time responses
|
||||
- Token management and rate limiting
|
||||
- Custom prompts and agent actions
|
||||
|
||||
## UI/UX Patterns
|
||||
|
||||
- Tailwind CSS with design system in `web/src/components/ui/`
|
||||
- Radix UI and Headless UI for accessible components
|
||||
- SWR for data fetching and caching
|
||||
- Form validation with react-hook-form
|
||||
- Error handling with popup notifications
|
||||
|
||||
## Creating a Plan
|
||||
When creating a plan in the `plans` directory, make sure to include at least these elements:
|
||||
|
||||
**Issues to Address**
|
||||
What the change is meant to do.
|
||||
|
||||
**Important Notes**
|
||||
Things you come across in your research that are important to the implementation.
|
||||
|
||||
**Implementation strategy**
|
||||
How you are going to make the changes happen. High level approach.
|
||||
|
||||
**Tests**
|
||||
What unit (use rarely), external dependency unit, integration, and playwright tests you plan to write to
|
||||
verify the correct behavior. Don't overtest. Usually, a given change only needs one type of test.
|
||||
|
||||
Do NOT include these: *Timeline*, *Rollback plan*
|
||||
|
||||
This is a minimal list - feel free to include more. Do NOT write code as part of your plan.
|
||||
Keep it high level. You can reference certain files or functions though.
|
||||
|
||||
Before writing your plan, make sure to do research. Explore the relevant sections in the codebase.
|
||||
295
CLAUDE.md
Normal file
295
CLAUDE.md
Normal file
@@ -0,0 +1,295 @@
|
||||
# CLAUDE.md
|
||||
|
||||
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
||||
|
||||
## KEY NOTES
|
||||
|
||||
- If you run into any missing python dependency errors, try running your command with `workon onyx &&` in front
|
||||
to assume the python venv.
|
||||
- To make tests work, check the `.env` file at the root of the project to find an OpenAI key.
|
||||
- If using `playwright` to explore the frontend, you can usually log in with username `a@test.com` and password
|
||||
`a`. The app can be accessed at `http://localhost:3000`.
|
||||
- You should assume that all Onyx services are running. To verify, you can check the `backend/log` directory to
|
||||
make sure we see logs coming out from the relevant service.
|
||||
- To connect to the Postgres database, use: `docker exec -it onyx-stack-relational_db-1 psql -U postgres -c "<SQL>"`
|
||||
- When making calls to the backend, always go through the frontend. E.g. make a call to `http://localhost:3000/api/persona` not `http://localhost:8080/api/persona`
|
||||
- Put ALL db operations under the `backend/onyx/db` / `backend/ee/onyx/db` directories. Don't run queries
|
||||
outside of those directories.
|
||||
|
||||
## Project Overview
|
||||
|
||||
**Onyx** (formerly Danswer) is an open-source Gen-AI and Enterprise Search platform that connects to company documents, apps, and people. It features a modular architecture with both Community Edition (MIT licensed) and Enterprise Edition offerings.
|
||||
|
||||
|
||||
### Background Workers (Celery)
|
||||
|
||||
Onyx uses Celery for asynchronous task processing with multiple specialized workers:
|
||||
|
||||
#### Worker Types
|
||||
|
||||
1. **Primary Worker** (`celery_app.py`)
|
||||
- Coordinates core background tasks and system-wide operations
|
||||
- Handles connector management, document sync, pruning, and periodic checks
|
||||
- Runs with 4 threads concurrency
|
||||
- Tasks: connector deletion, vespa sync, pruning, LLM model updates, user file sync
|
||||
|
||||
2. **Docfetching Worker** (`docfetching`)
|
||||
- Fetches documents from external data sources (connectors)
|
||||
- Spawns docprocessing tasks for each document batch
|
||||
- Implements watchdog monitoring for stuck connectors
|
||||
- Configurable concurrency (default from env)
|
||||
|
||||
3. **Docprocessing Worker** (`docprocessing`)
|
||||
- Processes fetched documents through the indexing pipeline:
|
||||
- Upserts documents to PostgreSQL
|
||||
- Chunks documents and adds contextual information
|
||||
- Embeds chunks via model server
|
||||
- Writes chunks to Vespa vector database
|
||||
- Updates document metadata
|
||||
- Configurable concurrency (default from env)
|
||||
|
||||
4. **Light Worker** (`light`)
|
||||
- Handles lightweight, fast operations
|
||||
- Tasks: vespa operations, document permissions sync, external group sync
|
||||
- Higher concurrency for quick tasks
|
||||
|
||||
5. **Heavy Worker** (`heavy`)
|
||||
- Handles resource-intensive operations
|
||||
- Primary task: document pruning operations
|
||||
- Runs with 4 threads concurrency
|
||||
|
||||
6. **KG Processing Worker** (`kg_processing`)
|
||||
- Handles Knowledge Graph processing and clustering
|
||||
- Builds relationships between documents
|
||||
- Runs clustering algorithms
|
||||
- Configurable concurrency
|
||||
|
||||
7. **Monitoring Worker** (`monitoring`)
|
||||
- System health monitoring and metrics collection
|
||||
- Monitors Celery queues, process memory, and system status
|
||||
- Single thread (monitoring doesn't need parallelism)
|
||||
- Cloud-specific monitoring tasks
|
||||
|
||||
8. **Beat Worker** (`beat`)
|
||||
- Celery's scheduler for periodic tasks
|
||||
- Uses DynamicTenantScheduler for multi-tenant support
|
||||
- Schedules tasks like:
|
||||
- Indexing checks (every 15 seconds)
|
||||
- Connector deletion checks (every 20 seconds)
|
||||
- Vespa sync checks (every 20 seconds)
|
||||
- Pruning checks (every 20 seconds)
|
||||
- KG processing (every 60 seconds)
|
||||
- Monitoring tasks (every 5 minutes)
|
||||
- Cleanup tasks (hourly)
|
||||
|
||||
#### Key Features
|
||||
|
||||
- **Thread-based Workers**: All workers use thread pools (not processes) for stability
|
||||
- **Tenant Awareness**: Multi-tenant support with per-tenant task isolation. There is a
|
||||
middleware layer that automatically finds the appropriate tenant ID when sending tasks
|
||||
via Celery Beat.
|
||||
- **Task Prioritization**: High, Medium, Low priority queues
|
||||
- **Monitoring**: Built-in heartbeat and liveness checking
|
||||
- **Failure Handling**: Automatic retry and failure recovery mechanisms
|
||||
- **Redis Coordination**: Inter-process communication via Redis
|
||||
- **PostgreSQL State**: Task state and metadata stored in PostgreSQL
|
||||
|
||||
|
||||
#### Important Notes
|
||||
|
||||
**Defining Tasks**:
|
||||
- Always use `@shared_task` rather than `@celery_app`
|
||||
- Put tasks under `background/celery/tasks/` or `ee/background/celery/tasks`
|
||||
|
||||
**Defining APIs**:
|
||||
When creating new FastAPI APIs, do NOT use the `response_model` field. Instead, just type the
|
||||
function.
|
||||
|
||||
**Testing Updates**:
|
||||
If you make any updates to a celery worker and you want to test these changes, you will need
|
||||
to ask me to restart the celery worker. There is no auto-restart on code-change mechanism.
|
||||
|
||||
### Code Quality
|
||||
```bash
|
||||
# Install and run pre-commit hooks
|
||||
pre-commit install
|
||||
pre-commit run --all-files
|
||||
```
|
||||
|
||||
NOTE: Always make sure everything is strictly typed (both in Python and Typescript).
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
### Technology Stack
|
||||
- **Backend**: Python 3.11, FastAPI, SQLAlchemy, Alembic, Celery
|
||||
- **Frontend**: Next.js 15+, React 18, TypeScript, Tailwind CSS
|
||||
- **Database**: PostgreSQL with Redis caching
|
||||
- **Search**: Vespa vector database
|
||||
- **Auth**: OAuth2, SAML, multi-provider support
|
||||
- **AI/ML**: LangChain, LiteLLM, multiple embedding models
|
||||
|
||||
### Directory Structure
|
||||
|
||||
```
|
||||
backend/
|
||||
├── onyx/
|
||||
│ ├── auth/ # Authentication & authorization
|
||||
│ ├── chat/ # Chat functionality & LLM interactions
|
||||
│ ├── connectors/ # Data source connectors
|
||||
│ ├── db/ # Database models & operations
|
||||
│ ├── document_index/ # Vespa integration
|
||||
│ ├── federated_connectors/ # External search connectors
|
||||
│ ├── llm/ # LLM provider integrations
|
||||
│ └── server/ # API endpoints & routers
|
||||
├── ee/ # Enterprise Edition features
|
||||
├── alembic/ # Database migrations
|
||||
└── tests/ # Test suites
|
||||
|
||||
web/
|
||||
├── src/app/ # Next.js app router pages
|
||||
├── src/components/ # Reusable React components
|
||||
└── src/lib/ # Utilities & business logic
|
||||
```
|
||||
|
||||
## Database & Migrations
|
||||
|
||||
### Running Migrations
|
||||
```bash
|
||||
# Standard migrations
|
||||
alembic upgrade head
|
||||
|
||||
# Multi-tenant (Enterprise)
|
||||
alembic -n schema_private upgrade head
|
||||
```
|
||||
|
||||
### Creating Migrations
|
||||
```bash
|
||||
# Auto-generate migration
|
||||
alembic revision --autogenerate -m "description"
|
||||
|
||||
# Multi-tenant migration
|
||||
alembic -n schema_private revision --autogenerate -m "description"
|
||||
```
|
||||
|
||||
## Testing Strategy
|
||||
|
||||
There are 4 main types of tests within Onyx:
|
||||
|
||||
### Unit Tests
|
||||
These should not assume any Onyx/external services are available to be called.
|
||||
Interactions with the outside world should be mocked using `unittest.mock`. Generally, only
|
||||
write these for complex, isolated modules e.g. `citation_processing.py`.
|
||||
|
||||
To run them:
|
||||
|
||||
```bash
|
||||
python -m dotenv -f .vscode/.env run -- pytest -xv backend/tests/unit
|
||||
```
|
||||
|
||||
### External Dependency Unit Tests
|
||||
These tests assume that all external dependencies of Onyx are available and callable (e.g. Postgres, Redis,
|
||||
MinIO/S3, Vespa are running + OpenAI can be called + any request to the internet is fine + etc.).
|
||||
|
||||
However, the actual Onyx containers are not running and with these tests we call the function to test directly.
|
||||
We can also mock components/calls at will.
|
||||
|
||||
The goal with these tests are to minimize mocking while giving some flexibility to mock things that are flakey,
|
||||
need strictly controlled behavior, or need to have their internal behavior validated (e.g. verify a function is called
|
||||
with certain args, something that would be impossible with proper integration tests).
|
||||
|
||||
A great example of this type of test is `backend/tests/external_dependency_unit/connectors/confluence/test_confluence_group_sync.py`.
|
||||
|
||||
To run them:
|
||||
|
||||
```bash
|
||||
python -m dotenv -f .vscode/.env run -- pytest backend/tests/external_dependency_unit
|
||||
```
|
||||
|
||||
### Integration Tests
|
||||
Standard integration tests. Every test in `backend/tests/integration` runs against a real Onyx deployment. We cannot
|
||||
mock anything in these tests. Prefer writing integration tests (or External Dependency Unit Tests if mocking/internal
|
||||
verification is necessary) over any other type of test.
|
||||
|
||||
Tests are parallelized at a directory level.
|
||||
|
||||
When writing integration tests, make sure to check the root `conftest.py` for useful fixtures + the `backend/tests/integration/common_utils` directory for utilities. Prefer (if one exists), calling the appropriate Manager
|
||||
class in the utils over directly calling the APIs with a library like `requests`. Prefer using fixtures rather than
|
||||
calling the utilities directly (e.g. do NOT create admin users with
|
||||
`admin_user = UserManager.create(name="admin_user")`, instead use the `admin_user` fixture).
|
||||
|
||||
A great example of this type of test is `backend/tests/integration/dev_apis/test_simple_chat_api.py`.
|
||||
|
||||
To run them:
|
||||
|
||||
```bash
|
||||
python -m dotenv -f .vscode/.env run -- pytest backend/tests/integration
|
||||
```
|
||||
|
||||
### Playwright (E2E) Tests
|
||||
These tests are an even more complete version of the Integration Tests mentioned above. Has all services of Onyx
|
||||
running, *including* the Web Server.
|
||||
|
||||
Use these tests for anything that requires significant frontend <-> backend coordination.
|
||||
|
||||
Tests are located at `web/tests/e2e`. Tests are written in TypeScript.
|
||||
|
||||
To run them:
|
||||
|
||||
```bash
|
||||
npx playwright test <TEST_NAME>
|
||||
```
|
||||
|
||||
|
||||
## Logs
|
||||
|
||||
When (1) writing integration tests or (2) doing live tests (e.g. curl / playwright) you can get access
|
||||
to logs via the `backend/log/<service_name>_debug.log` file. All Onyx services (api_server, web_server, celery_X)
|
||||
will be tailing their logs to this file.
|
||||
|
||||
|
||||
## Security Considerations
|
||||
|
||||
- Never commit API keys or secrets to repository
|
||||
- Use encrypted credential storage for connector credentials
|
||||
- Follow RBAC patterns for new features
|
||||
- Implement proper input validation with Pydantic models
|
||||
- Use parameterized queries to prevent SQL injection
|
||||
|
||||
## AI/LLM Integration
|
||||
|
||||
- Multiple LLM providers supported via LiteLLM
|
||||
- Configurable models per feature (chat, search, embeddings)
|
||||
- Streaming support for real-time responses
|
||||
- Token management and rate limiting
|
||||
- Custom prompts and agent actions
|
||||
|
||||
## UI/UX Patterns
|
||||
|
||||
- Tailwind CSS with design system in `web/src/components/ui/`
|
||||
- Radix UI and Headless UI for accessible components
|
||||
- SWR for data fetching and caching
|
||||
- Form validation with react-hook-form
|
||||
- Error handling with popup notifications
|
||||
|
||||
## Creating a Plan
|
||||
When creating a plan in the `plans` directory, make sure to include at least these elements:
|
||||
|
||||
**Issues to Address**
|
||||
What the change is meant to do.
|
||||
|
||||
**Important Notes**
|
||||
Things you come across in your research that are important to the implementation.
|
||||
|
||||
**Implementation strategy**
|
||||
How you are going to make the changes happen. High level approach.
|
||||
|
||||
**Tests**
|
||||
What unit (use rarely), external dependency unit, integration, and playwright tests you plan to write to
|
||||
verify the correct behavior. Don't overtest. Usually, a given change only needs one type of test.
|
||||
|
||||
Do NOT include these: *Timeline*, *Rollback plan*
|
||||
|
||||
This is a minimal list - feel free to include more. Do NOT write code as part of your plan.
|
||||
Keep it high level. You can reference certain files or functions though.
|
||||
|
||||
Before writing your plan, make sure to do research. Explore the relevant sections in the codebase.
|
||||
@@ -12,8 +12,8 @@ As an open source project in a rapidly changing space, we welcome all contributi
|
||||
|
||||
The [GitHub Issues](https://github.com/onyx-dot-app/onyx/issues) page is a great place to start for contribution ideas.
|
||||
|
||||
To ensure that your contribution is aligned with the project's direction, please reach out to Hagen (or any other maintainer) on the Onyx team
|
||||
via [Slack](https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA) /
|
||||
To ensure that your contribution is aligned with the project's direction, please reach out to any maintainer on the Onyx team
|
||||
via [Slack](https://join.slack.com/t/onyx-dot-app/shared_invite/zt-34lu4m7xg-TsKGO6h8PDvR5W27zTdyhA) /
|
||||
[Discord](https://discord.gg/TDJ59cGV2X) or [email](mailto:founders@onyx.app).
|
||||
|
||||
Issues that have been explicitly approved by the maintainers (aligned with the direction of the project)
|
||||
@@ -28,7 +28,7 @@ Your input is vital to making sure that Onyx moves in the right direction.
|
||||
Before starting on implementation, please raise a GitHub issue.
|
||||
|
||||
Also, always feel free to message the founders (Chris Weaver / Yuhong Sun) on
|
||||
[Slack](https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA) /
|
||||
[Slack](https://join.slack.com/t/onyx-dot-app/shared_invite/zt-34lu4m7xg-TsKGO6h8PDvR5W27zTdyhA) /
|
||||
[Discord](https://discord.gg/TDJ59cGV2X) directly about anything at all.
|
||||
|
||||
### Contributing Code
|
||||
@@ -59,6 +59,7 @@ Onyx being a fully functional app, relies on some external software, specificall
|
||||
- [Postgres](https://www.postgresql.org/) (Relational DB)
|
||||
- [Vespa](https://vespa.ai/) (Vector DB/Search Engine)
|
||||
- [Redis](https://redis.io/) (Cache)
|
||||
- [MinIO](https://min.io/) (File Store)
|
||||
- [Nginx](https://nginx.org/) (Not needed for development flows generally)
|
||||
|
||||
> **Note:**
|
||||
@@ -102,10 +103,10 @@ If using PowerShell, the command slightly differs:
|
||||
Install the required python dependencies:
|
||||
|
||||
```bash
|
||||
pip install -r onyx/backend/requirements/default.txt
|
||||
pip install -r onyx/backend/requirements/dev.txt
|
||||
pip install -r onyx/backend/requirements/ee.txt
|
||||
pip install -r onyx/backend/requirements/model_server.txt
|
||||
pip install -r backend/requirements/default.txt
|
||||
pip install -r backend/requirements/dev.txt
|
||||
pip install -r backend/requirements/ee.txt
|
||||
pip install -r backend/requirements/model_server.txt
|
||||
```
|
||||
|
||||
Install Playwright for Python (headless browser required by the Web Connector)
|
||||
@@ -171,10 +172,10 @@ Otherwise, you can follow the instructions below to run the application for deve
|
||||
|
||||
You will need Docker installed to run these containers.
|
||||
|
||||
First navigate to `onyx/deployment/docker_compose`, then start up Postgres/Vespa/Redis with:
|
||||
First navigate to `onyx/deployment/docker_compose`, then start up Postgres/Vespa/Redis/MinIO with:
|
||||
|
||||
```bash
|
||||
docker compose -f docker-compose.dev.yml -p onyx-stack up -d index relational_db cache
|
||||
docker compose -f docker-compose.dev.yml -p onyx-stack up -d index relational_db cache minio
|
||||
```
|
||||
|
||||
(index refers to Vespa, relational_db refers to Postgres, and cache refers to Redis)
|
||||
|
||||
@@ -5,7 +5,7 @@ This guide explains how to set up and use VSCode's debugging capabilities with t
|
||||
## Initial Setup
|
||||
|
||||
1. **Environment Setup**:
|
||||
- Copy `.vscode/.env.template` to `.vscode/.env`
|
||||
- Copy `.vscode/env_template.txt` to `.vscode/.env`
|
||||
- Fill in the necessary environment variables in `.vscode/.env`
|
||||
2. **launch.json**:
|
||||
- Copy `.vscode/launch.template.jsonc` to `.vscode/launch.json`
|
||||
@@ -17,10 +17,9 @@ Before starting, make sure the Docker Daemon is running.
|
||||
1. Open the Debug view in VSCode (Cmd+Shift+D on macOS)
|
||||
2. From the dropdown at the top, select "Clear and Restart External Volumes and Containers" and press the green play button
|
||||
3. From the dropdown at the top, select "Run All Onyx Services" and press the green play button
|
||||
4. CD into web, run "npm i" followed by npm run dev.
|
||||
5. Now, you can navigate to onyx in your browser (default is http://localhost:3000) and start using the app
|
||||
6. You can set breakpoints by clicking to the left of line numbers to help debug while the app is running
|
||||
7. Use the debug toolbar to step through code, inspect variables, etc.
|
||||
4. Now, you can navigate to onyx in your browser (default is http://localhost:3000) and start using the app
|
||||
5. You can set breakpoints by clicking to the left of line numbers to help debug while the app is running
|
||||
6. Use the debug toolbar to step through code, inspect variables, etc.
|
||||
|
||||
## Features
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ https://private-user-images.githubusercontent.com/32520769/414509312-48392e83-95
|
||||
**To try it out for free and get started in seconds, check out [Onyx Cloud](https://cloud.onyx.app/signup)**.
|
||||
|
||||
Onyx can also be run locally (even on a laptop) or deployed on a virtual machine with a single
|
||||
`docker compose` command. Checkout our [docs](https://docs.onyx.app/quickstart) to learn more.
|
||||
`docker compose` command. Checkout our [docs](https://docs.onyx.app/deployment/getting_started/quickstart) to learn more.
|
||||
|
||||
We also have built-in support for high-availability/scalable deployment on Kubernetes.
|
||||
References [here](https://github.com/onyx-dot-app/onyx/tree/main/deployment).
|
||||
@@ -97,7 +97,7 @@ Keep knowledge and access up to sync across 40+ connectors:
|
||||
- Websites
|
||||
- And more ...
|
||||
|
||||
See the full list [here](https://docs.onyx.app/connectors).
|
||||
See the full list [here](https://docs.onyx.app/admin/connectors/overview).
|
||||
|
||||
|
||||
## 📚 Licensing
|
||||
|
||||
2
backend/.gitignore
vendored
2
backend/.gitignore
vendored
@@ -11,4 +11,4 @@ dynamic_config_storage/
|
||||
celerybeat-schedule*
|
||||
onyx/connectors/salesforce/data/
|
||||
.test.env
|
||||
|
||||
/generated
|
||||
|
||||
@@ -12,7 +12,8 @@ ARG ONYX_VERSION=0.0.0-dev
|
||||
# DO_NOT_TRACK is used to disable telemetry for Unstructured
|
||||
ENV ONYX_VERSION=${ONYX_VERSION} \
|
||||
DANSWER_RUNNING_IN_DOCKER="true" \
|
||||
DO_NOT_TRACK="true"
|
||||
DO_NOT_TRACK="true" \
|
||||
PLAYWRIGHT_BROWSERS_PATH="/app/.cache/ms-playwright"
|
||||
|
||||
|
||||
RUN echo "ONYX_VERSION: ${ONYX_VERSION}"
|
||||
@@ -77,6 +78,9 @@ RUN apt-get update && \
|
||||
rm -rf /var/lib/apt/lists/* && \
|
||||
rm -f /usr/local/lib/python3.11/site-packages/tornado/test/test.key
|
||||
|
||||
# Install postgresql-client for easy manual tests
|
||||
# Install it here to avoid it being cleaned up above
|
||||
RUN apt-get update && apt-get install -y postgresql-client
|
||||
|
||||
# Pre-downloading models for setups with limited egress
|
||||
RUN python -c "from tokenizers import Tokenizer; \
|
||||
@@ -85,7 +89,7 @@ Tokenizer.from_pretrained('nomic-ai/nomic-embed-text-v1')"
|
||||
# Pre-downloading NLTK for setups with limited egress
|
||||
RUN python -c "import nltk; \
|
||||
nltk.download('stopwords', quiet=True); \
|
||||
nltk.download('punkt', quiet=True);"
|
||||
nltk.download('punkt_tab', quiet=True);"
|
||||
# nltk.download('wordnet', quiet=True); introduce this back if lemmatization is needed
|
||||
|
||||
# Set up application files
|
||||
@@ -113,6 +117,14 @@ COPY ./assets /app/assets
|
||||
|
||||
ENV PYTHONPATH=/app
|
||||
|
||||
# Create non-root user for security best practices
|
||||
RUN groupadd -g 1001 onyx && \
|
||||
useradd -u 1001 -g onyx -m -s /bin/bash onyx && \
|
||||
chown -R onyx:onyx /app && \
|
||||
mkdir -p /var/log/onyx && \
|
||||
chmod 755 /var/log/onyx && \
|
||||
chown onyx:onyx /var/log/onyx
|
||||
|
||||
# Default command which does nothing
|
||||
# This container is used by api server and background which specify their own CMD
|
||||
CMD ["tail", "-f", "/dev/null"]
|
||||
|
||||
@@ -9,11 +9,36 @@ visit https://github.com/onyx-dot-app/onyx."
|
||||
# Default ONYX_VERSION, typically overriden during builds by GitHub Actions.
|
||||
ARG ONYX_VERSION=0.0.0-dev
|
||||
ENV ONYX_VERSION=${ONYX_VERSION} \
|
||||
DANSWER_RUNNING_IN_DOCKER="true"
|
||||
|
||||
DANSWER_RUNNING_IN_DOCKER="true" \
|
||||
HF_HOME=/app/.cache/huggingface
|
||||
|
||||
RUN echo "ONYX_VERSION: ${ONYX_VERSION}"
|
||||
|
||||
# Create non-root user for security best practices
|
||||
RUN mkdir -p /app && \
|
||||
groupadd -g 1001 onyx && \
|
||||
useradd -u 1001 -g onyx -m -s /bin/bash onyx && \
|
||||
chown -R onyx:onyx /app && \
|
||||
mkdir -p /var/log/onyx && \
|
||||
chmod 755 /var/log/onyx && \
|
||||
chown onyx:onyx /var/log/onyx
|
||||
|
||||
# --- add toolchain needed for Rust/Python builds (fastuuid) ---
|
||||
ENV RUSTUP_HOME=/usr/local/rustup \
|
||||
CARGO_HOME=/usr/local/cargo \
|
||||
PATH=/usr/local/cargo/bin:$PATH
|
||||
|
||||
RUN set -eux; \
|
||||
apt-get update && apt-get install -y --no-install-recommends \
|
||||
build-essential \
|
||||
pkg-config \
|
||||
curl \
|
||||
ca-certificates \
|
||||
&& rm -rf /var/lib/apt/lists/* \
|
||||
# Install latest stable Rust (supports Cargo.lock v4)
|
||||
&& curl -sSf https://sh.rustup.rs | sh -s -- -y --profile minimal --default-toolchain stable \
|
||||
&& rustc --version && cargo --version
|
||||
|
||||
COPY ./requirements/model_server.txt /tmp/requirements.txt
|
||||
RUN pip install --no-cache-dir --upgrade \
|
||||
--retries 5 \
|
||||
@@ -38,9 +63,11 @@ snapshot_download('mixedbread-ai/mxbai-rerank-xsmall-v1'); \
|
||||
from sentence_transformers import SentenceTransformer; \
|
||||
SentenceTransformer(model_name_or_path='nomic-ai/nomic-embed-text-v1', trust_remote_code=True);"
|
||||
|
||||
# In case the user has volumes mounted to /root/.cache/huggingface that they've downloaded while
|
||||
# running Onyx, don't overwrite it with the built in cache folder
|
||||
RUN mv /root/.cache/huggingface /root/.cache/temp_huggingface
|
||||
# In case the user has volumes mounted to /app/.cache/huggingface that they've downloaded while
|
||||
# running Onyx, move the current contents of the cache folder to a temporary location to ensure
|
||||
# it's preserved in order to combine with the user's cache contents
|
||||
RUN mv /app/.cache/huggingface /app/.cache/temp_huggingface && \
|
||||
chown -R onyx:onyx /app
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
|
||||
@@ -20,3 +20,44 @@ To run all un-applied migrations:
|
||||
To undo migrations:
|
||||
`alembic downgrade -X`
|
||||
where X is the number of migrations you want to undo from the current state
|
||||
|
||||
### Multi-tenant migrations
|
||||
|
||||
For multi-tenant deployments, you can use additional options:
|
||||
|
||||
**Upgrade all tenants:**
|
||||
```bash
|
||||
alembic -x upgrade_all_tenants=true upgrade head
|
||||
```
|
||||
|
||||
**Upgrade specific schemas:**
|
||||
```bash
|
||||
# Single schema
|
||||
alembic -x schemas=tenant_12345678-1234-1234-1234-123456789012 upgrade head
|
||||
|
||||
# Multiple schemas (comma-separated)
|
||||
alembic -x schemas=tenant_12345678-1234-1234-1234-123456789012,public,another_tenant upgrade head
|
||||
```
|
||||
|
||||
**Upgrade tenants within an alphabetical range:**
|
||||
```bash
|
||||
# Upgrade tenants 100-200 when sorted alphabetically (positions 100 to 200)
|
||||
alembic -x upgrade_all_tenants=true -x tenant_range_start=100 -x tenant_range_end=200 upgrade head
|
||||
|
||||
# Upgrade tenants starting from position 1000 alphabetically
|
||||
alembic -x upgrade_all_tenants=true -x tenant_range_start=1000 upgrade head
|
||||
|
||||
# Upgrade first 500 tenants alphabetically
|
||||
alembic -x upgrade_all_tenants=true -x tenant_range_end=500 upgrade head
|
||||
```
|
||||
|
||||
**Continue on error (for batch operations):**
|
||||
```bash
|
||||
alembic -x upgrade_all_tenants=true -x continue=true upgrade head
|
||||
```
|
||||
|
||||
The tenant range filtering works by:
|
||||
1. Sorting tenant IDs alphabetically
|
||||
2. Using 1-based position numbers (1st, 2nd, 3rd tenant, etc.)
|
||||
3. Filtering to the specified range of positions
|
||||
4. Non-tenant schemas (like 'public') are always included
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
from typing import Any, Literal
|
||||
from onyx.db.engine import get_iam_auth_token
|
||||
from onyx.db.engine.iam_auth import get_iam_auth_token
|
||||
from onyx.configs.app_configs import USE_IAM_AUTH
|
||||
from onyx.configs.app_configs import POSTGRES_HOST
|
||||
from onyx.configs.app_configs import POSTGRES_PORT
|
||||
from onyx.configs.app_configs import POSTGRES_USER
|
||||
from onyx.configs.app_configs import AWS_REGION_NAME
|
||||
from onyx.db.engine import build_connection_string
|
||||
from onyx.db.engine import get_all_tenant_ids
|
||||
from onyx.db.engine.sql_engine import build_connection_string
|
||||
from onyx.db.engine.tenant_utils import get_all_tenant_ids
|
||||
from sqlalchemy import event
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy import text
|
||||
@@ -21,10 +21,14 @@ from alembic import context
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy.sql.schema import SchemaItem
|
||||
from onyx.configs.constants import SSL_CERT_FILE
|
||||
from shared_configs.configs import MULTI_TENANT, POSTGRES_DEFAULT_SCHEMA
|
||||
from shared_configs.configs import (
|
||||
MULTI_TENANT,
|
||||
POSTGRES_DEFAULT_SCHEMA,
|
||||
TENANT_ID_PREFIX,
|
||||
)
|
||||
from onyx.db.models import Base
|
||||
from celery.backends.database.session import ResultModelBase # type: ignore
|
||||
from onyx.db.engine import SqlEngine
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
|
||||
# Make sure in alembic.ini [logger_root] level=INFO is set or most logging will be
|
||||
# hidden! (defaults to level=WARN)
|
||||
@@ -69,15 +73,67 @@ def include_object(
|
||||
return True
|
||||
|
||||
|
||||
def get_schema_options() -> tuple[str, bool, bool, bool]:
|
||||
def filter_tenants_by_range(
|
||||
tenant_ids: list[str], start_range: int | None = None, end_range: int | None = None
|
||||
) -> list[str]:
|
||||
"""
|
||||
Filter tenant IDs by alphabetical position range.
|
||||
|
||||
Args:
|
||||
tenant_ids: List of tenant IDs to filter
|
||||
start_range: Starting position in alphabetically sorted list (1-based, inclusive)
|
||||
end_range: Ending position in alphabetically sorted list (1-based, inclusive)
|
||||
|
||||
Returns:
|
||||
Filtered list of tenant IDs in their original order
|
||||
"""
|
||||
if start_range is None and end_range is None:
|
||||
return tenant_ids
|
||||
|
||||
# Separate tenant IDs from non-tenant schemas
|
||||
tenant_schemas = [tid for tid in tenant_ids if tid.startswith(TENANT_ID_PREFIX)]
|
||||
non_tenant_schemas = [
|
||||
tid for tid in tenant_ids if not tid.startswith(TENANT_ID_PREFIX)
|
||||
]
|
||||
|
||||
# Sort tenant schemas alphabetically.
|
||||
# NOTE: can cause missed schemas if a schema is created in between workers
|
||||
# fetching of all tenant IDs. We accept this risk for now. Just re-running
|
||||
# the migration will fix the issue.
|
||||
sorted_tenant_schemas = sorted(tenant_schemas)
|
||||
|
||||
# Apply range filtering (0-based indexing)
|
||||
start_idx = start_range if start_range is not None else 0
|
||||
end_idx = end_range if end_range is not None else len(sorted_tenant_schemas)
|
||||
|
||||
# Ensure indices are within bounds
|
||||
start_idx = max(0, start_idx)
|
||||
end_idx = min(len(sorted_tenant_schemas), end_idx)
|
||||
|
||||
# Get the filtered tenant schemas
|
||||
filtered_tenant_schemas = sorted_tenant_schemas[start_idx:end_idx]
|
||||
|
||||
# Combine with non-tenant schemas and preserve original order
|
||||
filtered_tenants = []
|
||||
for tenant_id in tenant_ids:
|
||||
if tenant_id in filtered_tenant_schemas or tenant_id in non_tenant_schemas:
|
||||
filtered_tenants.append(tenant_id)
|
||||
|
||||
return filtered_tenants
|
||||
|
||||
|
||||
def get_schema_options() -> (
|
||||
tuple[bool, bool, bool, int | None, int | None, list[str] | None]
|
||||
):
|
||||
x_args_raw = context.get_x_argument()
|
||||
x_args = {}
|
||||
for arg in x_args_raw:
|
||||
for pair in arg.split(","):
|
||||
if "=" in pair:
|
||||
key, value = pair.split("=", 1)
|
||||
x_args[key.strip()] = value.strip()
|
||||
schema_name = x_args.get("schema", POSTGRES_DEFAULT_SCHEMA)
|
||||
if "=" in arg:
|
||||
key, value = arg.split("=", 1)
|
||||
x_args[key.strip()] = value.strip()
|
||||
else:
|
||||
raise ValueError(f"Invalid argument: {arg}")
|
||||
|
||||
create_schema = x_args.get("create_schema", "true").lower() == "true"
|
||||
upgrade_all_tenants = x_args.get("upgrade_all_tenants", "false").lower() == "true"
|
||||
|
||||
@@ -85,17 +141,81 @@ def get_schema_options() -> tuple[str, bool, bool, bool]:
|
||||
# only applies to online migrations
|
||||
continue_on_error = x_args.get("continue", "false").lower() == "true"
|
||||
|
||||
if (
|
||||
MULTI_TENANT
|
||||
and schema_name == POSTGRES_DEFAULT_SCHEMA
|
||||
and not upgrade_all_tenants
|
||||
):
|
||||
# Tenant range filtering
|
||||
tenant_range_start = None
|
||||
tenant_range_end = None
|
||||
|
||||
if "tenant_range_start" in x_args:
|
||||
try:
|
||||
tenant_range_start = int(x_args["tenant_range_start"])
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
f"Invalid tenant_range_start value: {x_args['tenant_range_start']}. Must be an integer."
|
||||
)
|
||||
|
||||
if "tenant_range_end" in x_args:
|
||||
try:
|
||||
tenant_range_end = int(x_args["tenant_range_end"])
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
f"Invalid tenant_range_end value: {x_args['tenant_range_end']}. Must be an integer."
|
||||
)
|
||||
|
||||
# Validate range
|
||||
if tenant_range_start is not None and tenant_range_end is not None:
|
||||
if tenant_range_start > tenant_range_end:
|
||||
raise ValueError(
|
||||
f"tenant_range_start ({tenant_range_start}) cannot be greater than tenant_range_end ({tenant_range_end})"
|
||||
)
|
||||
|
||||
# Specific schema names filtering (replaces both schema_name and the old tenant_ids approach)
|
||||
schemas = None
|
||||
if "schemas" in x_args:
|
||||
schema_names_str = x_args["schemas"].strip()
|
||||
if schema_names_str:
|
||||
# Split by comma and strip whitespace
|
||||
schemas = [
|
||||
name.strip() for name in schema_names_str.split(",") if name.strip()
|
||||
]
|
||||
if schemas:
|
||||
logger.info(f"Specific schema names specified: {schemas}")
|
||||
|
||||
# Validate that only one method is used at a time
|
||||
range_filtering = tenant_range_start is not None or tenant_range_end is not None
|
||||
specific_filtering = schemas is not None and len(schemas) > 0
|
||||
|
||||
if range_filtering and specific_filtering:
|
||||
raise ValueError(
|
||||
"Cannot run default migrations in public schema when multi-tenancy is enabled. "
|
||||
"Please specify a tenant-specific schema."
|
||||
"Cannot use both tenant range filtering (tenant_range_start/tenant_range_end) "
|
||||
"and specific schema filtering (schemas) at the same time. "
|
||||
"Please use only one filtering method."
|
||||
)
|
||||
|
||||
return schema_name, create_schema, upgrade_all_tenants, continue_on_error
|
||||
if upgrade_all_tenants and specific_filtering:
|
||||
raise ValueError(
|
||||
"Cannot use both upgrade_all_tenants=true and schemas at the same time. "
|
||||
"Use either upgrade_all_tenants=true for all tenants, or schemas for specific schemas."
|
||||
)
|
||||
|
||||
# If any filtering parameters are specified, we're not doing the default single schema migration
|
||||
if range_filtering:
|
||||
upgrade_all_tenants = True
|
||||
|
||||
# Validate multi-tenant requirements
|
||||
if MULTI_TENANT and not upgrade_all_tenants and not specific_filtering:
|
||||
raise ValueError(
|
||||
"In multi-tenant mode, you must specify either upgrade_all_tenants=true "
|
||||
"or provide schemas. Cannot run default migration."
|
||||
)
|
||||
|
||||
return (
|
||||
create_schema,
|
||||
upgrade_all_tenants,
|
||||
continue_on_error,
|
||||
tenant_range_start,
|
||||
tenant_range_end,
|
||||
schemas,
|
||||
)
|
||||
|
||||
|
||||
def do_run_migrations(
|
||||
@@ -142,12 +262,17 @@ def provide_iam_token_for_alembic(
|
||||
|
||||
async def run_async_migrations() -> None:
|
||||
(
|
||||
schema_name,
|
||||
create_schema,
|
||||
upgrade_all_tenants,
|
||||
continue_on_error,
|
||||
tenant_range_start,
|
||||
tenant_range_end,
|
||||
schemas,
|
||||
) = get_schema_options()
|
||||
|
||||
if not schemas and not MULTI_TENANT:
|
||||
schemas = [POSTGRES_DEFAULT_SCHEMA]
|
||||
|
||||
# without init_engine, subsequent engine calls fail hard intentionally
|
||||
SqlEngine.init_engine(pool_size=20, max_overflow=5)
|
||||
|
||||
@@ -164,12 +289,50 @@ async def run_async_migrations() -> None:
|
||||
) -> None:
|
||||
provide_iam_token_for_alembic(dialect, conn_rec, cargs, cparams)
|
||||
|
||||
if upgrade_all_tenants:
|
||||
if schemas:
|
||||
# Use specific schema names directly without fetching all tenants
|
||||
logger.info(f"Migrating specific schema names: {schemas}")
|
||||
|
||||
i_schema = 0
|
||||
num_schemas = len(schemas)
|
||||
for schema in schemas:
|
||||
i_schema += 1
|
||||
logger.info(
|
||||
f"Migrating schema: index={i_schema} num_schemas={num_schemas} schema={schema}"
|
||||
)
|
||||
try:
|
||||
async with engine.connect() as connection:
|
||||
await connection.run_sync(
|
||||
do_run_migrations,
|
||||
schema_name=schema,
|
||||
create_schema=create_schema,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error migrating schema {schema}: {e}")
|
||||
if not continue_on_error:
|
||||
logger.error("--continue=true is not set, raising exception!")
|
||||
raise
|
||||
|
||||
logger.warning("--continue=true is set, continuing to next schema.")
|
||||
|
||||
elif upgrade_all_tenants:
|
||||
tenant_schemas = get_all_tenant_ids()
|
||||
|
||||
filtered_tenant_schemas = filter_tenants_by_range(
|
||||
tenant_schemas, tenant_range_start, tenant_range_end
|
||||
)
|
||||
|
||||
if tenant_range_start is not None or tenant_range_end is not None:
|
||||
logger.info(
|
||||
f"Filtering tenants by range: start={tenant_range_start}, end={tenant_range_end}"
|
||||
)
|
||||
logger.info(
|
||||
f"Total tenants: {len(tenant_schemas)}, Filtered tenants: {len(filtered_tenant_schemas)}"
|
||||
)
|
||||
|
||||
i_tenant = 0
|
||||
num_tenants = len(tenant_schemas)
|
||||
for schema in tenant_schemas:
|
||||
num_tenants = len(filtered_tenant_schemas)
|
||||
for schema in filtered_tenant_schemas:
|
||||
i_tenant += 1
|
||||
logger.info(
|
||||
f"Migrating schema: index={i_tenant} num_tenants={num_tenants} schema={schema}"
|
||||
@@ -190,17 +353,13 @@ async def run_async_migrations() -> None:
|
||||
logger.warning("--continue=true is set, continuing to next schema.")
|
||||
|
||||
else:
|
||||
try:
|
||||
logger.info(f"Migrating schema: {schema_name}")
|
||||
async with engine.connect() as connection:
|
||||
await connection.run_sync(
|
||||
do_run_migrations,
|
||||
schema_name=schema_name,
|
||||
create_schema=create_schema,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error migrating schema {schema_name}: {e}")
|
||||
raise
|
||||
# This should not happen in the new design since we require either
|
||||
# upgrade_all_tenants=true or schemas in multi-tenant mode
|
||||
# and for non-multi-tenant mode, we should use schemas with the default schema
|
||||
raise ValueError(
|
||||
"No migration target specified. Use either upgrade_all_tenants=true for all tenants "
|
||||
"or schemas for specific schemas."
|
||||
)
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
@@ -221,10 +380,37 @@ def run_migrations_offline() -> None:
|
||||
# without init_engine, subsequent engine calls fail hard intentionally
|
||||
SqlEngine.init_engine(pool_size=20, max_overflow=5)
|
||||
|
||||
schema_name, _, upgrade_all_tenants, continue_on_error = get_schema_options()
|
||||
(
|
||||
create_schema,
|
||||
upgrade_all_tenants,
|
||||
continue_on_error,
|
||||
tenant_range_start,
|
||||
tenant_range_end,
|
||||
schemas,
|
||||
) = get_schema_options()
|
||||
url = build_connection_string()
|
||||
|
||||
if upgrade_all_tenants:
|
||||
if schemas:
|
||||
# Use specific schema names directly without fetching all tenants
|
||||
logger.info(f"Migrating specific schema names: {schemas}")
|
||||
|
||||
for schema in schemas:
|
||||
logger.info(f"Migrating schema: {schema}")
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
literal_binds=True,
|
||||
include_object=include_object,
|
||||
version_table_schema=schema,
|
||||
include_schemas=True,
|
||||
script_location=config.get_main_option("script_location"),
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
elif upgrade_all_tenants:
|
||||
engine = create_async_engine(url)
|
||||
|
||||
if USE_IAM_AUTH:
|
||||
@@ -238,7 +424,19 @@ def run_migrations_offline() -> None:
|
||||
tenant_schemas = get_all_tenant_ids()
|
||||
engine.sync_engine.dispose()
|
||||
|
||||
for schema in tenant_schemas:
|
||||
filtered_tenant_schemas = filter_tenants_by_range(
|
||||
tenant_schemas, tenant_range_start, tenant_range_end
|
||||
)
|
||||
|
||||
if tenant_range_start is not None or tenant_range_end is not None:
|
||||
logger.info(
|
||||
f"Filtering tenants by range: start={tenant_range_start}, end={tenant_range_end}"
|
||||
)
|
||||
logger.info(
|
||||
f"Total tenants: {len(tenant_schemas)}, Filtered tenants: {len(filtered_tenant_schemas)}"
|
||||
)
|
||||
|
||||
for schema in filtered_tenant_schemas:
|
||||
logger.info(f"Migrating schema: {schema}")
|
||||
context.configure(
|
||||
url=url,
|
||||
@@ -254,21 +452,12 @@ def run_migrations_offline() -> None:
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
else:
|
||||
logger.info(f"Migrating schema: {schema_name}")
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
literal_binds=True,
|
||||
include_object=include_object,
|
||||
version_table_schema=schema_name,
|
||||
include_schemas=True,
|
||||
script_location=config.get_main_option("script_location"),
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
# This should not happen in the new design
|
||||
raise ValueError(
|
||||
"No migration target specified. Use either upgrade_all_tenants=true for all tenants "
|
||||
"or schemas for specific schemas."
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
|
||||
logger.info("run_migrations_online starting.")
|
||||
|
||||
121
backend/alembic/versions/03bf8be6b53a_rework_kg_config.py
Normal file
121
backend/alembic/versions/03bf8be6b53a_rework_kg_config.py
Normal file
@@ -0,0 +1,121 @@
|
||||
"""rework-kg-config
|
||||
|
||||
Revision ID: 03bf8be6b53a
|
||||
Revises: 65bc6e0f8500
|
||||
Create Date: 2025-06-16 10:52:34.815335
|
||||
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from sqlalchemy import text
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "03bf8be6b53a"
|
||||
down_revision = "65bc6e0f8500"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# get current config
|
||||
current_configs = (
|
||||
op.get_bind()
|
||||
.execute(text("SELECT kg_variable_name, kg_variable_values FROM kg_config"))
|
||||
.all()
|
||||
)
|
||||
current_config_dict = {
|
||||
config.kg_variable_name: (
|
||||
config.kg_variable_values[0]
|
||||
if config.kg_variable_name
|
||||
not in ("KG_VENDOR_DOMAINS", "KG_IGNORE_EMAIL_DOMAINS")
|
||||
else config.kg_variable_values
|
||||
)
|
||||
for config in current_configs
|
||||
if config.kg_variable_values
|
||||
}
|
||||
|
||||
# not using the KGConfigSettings model here in case it changes in the future
|
||||
kg_config_settings = json.dumps(
|
||||
{
|
||||
"KG_EXPOSED": current_config_dict.get("KG_EXPOSED", False),
|
||||
"KG_ENABLED": current_config_dict.get("KG_ENABLED", False),
|
||||
"KG_VENDOR": current_config_dict.get("KG_VENDOR", None),
|
||||
"KG_VENDOR_DOMAINS": current_config_dict.get("KG_VENDOR_DOMAINS", []),
|
||||
"KG_IGNORE_EMAIL_DOMAINS": current_config_dict.get(
|
||||
"KG_IGNORE_EMAIL_DOMAINS", []
|
||||
),
|
||||
"KG_COVERAGE_START": current_config_dict.get(
|
||||
"KG_COVERAGE_START",
|
||||
(datetime.now() - timedelta(days=90)).strftime("%Y-%m-%d"),
|
||||
),
|
||||
"KG_MAX_COVERAGE_DAYS": current_config_dict.get("KG_MAX_COVERAGE_DAYS", 90),
|
||||
"KG_MAX_PARENT_RECURSION_DEPTH": current_config_dict.get(
|
||||
"KG_MAX_PARENT_RECURSION_DEPTH", 2
|
||||
),
|
||||
"KG_BETA_PERSONA_ID": current_config_dict.get("KG_BETA_PERSONA_ID", None),
|
||||
}
|
||||
)
|
||||
op.execute(
|
||||
f"INSERT INTO key_value_store (key, value) VALUES ('kg_config', '{kg_config_settings}')"
|
||||
)
|
||||
|
||||
# drop kg config table
|
||||
op.drop_table("kg_config")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# get current config
|
||||
current_config_dict = {
|
||||
"KG_EXPOSED": False,
|
||||
"KG_ENABLED": False,
|
||||
"KG_VENDOR": [],
|
||||
"KG_VENDOR_DOMAINS": [],
|
||||
"KG_IGNORE_EMAIL_DOMAINS": [],
|
||||
"KG_COVERAGE_START": (datetime.now() - timedelta(days=90)).strftime("%Y-%m-%d"),
|
||||
"KG_MAX_COVERAGE_DAYS": 90,
|
||||
"KG_MAX_PARENT_RECURSION_DEPTH": 2,
|
||||
}
|
||||
current_configs = (
|
||||
op.get_bind()
|
||||
.execute(text("SELECT value FROM key_value_store WHERE key = 'kg_config'"))
|
||||
.one_or_none()
|
||||
)
|
||||
if current_configs is not None:
|
||||
current_config_dict.update(current_configs[0])
|
||||
insert_values = [
|
||||
{
|
||||
"kg_variable_name": name,
|
||||
"kg_variable_values": (
|
||||
[str(val).lower() if isinstance(val, bool) else str(val)]
|
||||
if not isinstance(val, list)
|
||||
else val
|
||||
),
|
||||
}
|
||||
for name, val in current_config_dict.items()
|
||||
]
|
||||
|
||||
op.create_table(
|
||||
"kg_config",
|
||||
sa.Column("id", sa.Integer(), primary_key=True, nullable=False, index=True),
|
||||
sa.Column("kg_variable_name", sa.String(), nullable=False, index=True),
|
||||
sa.Column("kg_variable_values", postgresql.ARRAY(sa.String()), nullable=False),
|
||||
sa.UniqueConstraint("kg_variable_name", name="uq_kg_config_variable_name"),
|
||||
)
|
||||
op.bulk_insert(
|
||||
sa.table(
|
||||
"kg_config",
|
||||
sa.column("kg_variable_name", sa.String),
|
||||
sa.column("kg_variable_values", postgresql.ARRAY(sa.String)),
|
||||
),
|
||||
insert_values,
|
||||
)
|
||||
|
||||
op.execute("DELETE FROM key_value_store WHERE key = 'kg_config'")
|
||||
@@ -0,0 +1,72 @@
|
||||
"""add federated connector tables
|
||||
|
||||
Revision ID: 0816326d83aa
|
||||
Revises: 12635f6655b7
|
||||
Create Date: 2025-06-29 14:09:45.109518
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "0816326d83aa"
|
||||
down_revision = "12635f6655b7"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create federated_connector table
|
||||
op.create_table(
|
||||
"federated_connector",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("source", sa.String(), nullable=False),
|
||||
sa.Column("credentials", sa.LargeBinary(), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
# Create federated_connector_oauth_token table
|
||||
op.create_table(
|
||||
"federated_connector_oauth_token",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("federated_connector_id", sa.Integer(), nullable=False),
|
||||
sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column("token", sa.LargeBinary(), nullable=False),
|
||||
sa.Column("expires_at", sa.DateTime(), nullable=True),
|
||||
sa.ForeignKeyConstraint(
|
||||
["federated_connector_id"], ["federated_connector.id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
# Create federated_connector__document_set table
|
||||
op.create_table(
|
||||
"federated_connector__document_set",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("federated_connector_id", sa.Integer(), nullable=False),
|
||||
sa.Column("document_set_id", sa.Integer(), nullable=False),
|
||||
sa.Column("entities", postgresql.JSONB(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["federated_connector_id"], ["federated_connector.id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["document_set_id"], ["document_set.id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint(
|
||||
"federated_connector_id",
|
||||
"document_set_id",
|
||||
name="uq_federated_connector_document_set",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop tables in reverse order due to foreign key dependencies
|
||||
op.drop_table("federated_connector__document_set")
|
||||
op.drop_table("federated_connector_oauth_token")
|
||||
op.drop_table("federated_connector")
|
||||
596
backend/alembic/versions/12635f6655b7_drive_canonical_ids.py
Normal file
596
backend/alembic/versions/12635f6655b7_drive_canonical_ids.py
Normal file
@@ -0,0 +1,596 @@
|
||||
"""drive-canonical-ids
|
||||
|
||||
Revision ID: 12635f6655b7
|
||||
Revises: 58c50ef19f08
|
||||
Create Date: 2025-06-20 14:44:54.241159
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
from httpx import HTTPStatusError
|
||||
import httpx
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.db.search_settings import SearchSettings
|
||||
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
|
||||
from onyx.document_index.vespa.shared_utils.utils import (
|
||||
replace_invalid_doc_id_characters,
|
||||
)
|
||||
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
|
||||
from onyx.utils.logger import setup_logger
|
||||
import os
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "12635f6655b7"
|
||||
down_revision = "58c50ef19f08"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
SKIP_CANON_DRIVE_IDS = os.environ.get("SKIP_CANON_DRIVE_IDS", "true").lower() == "true"
|
||||
|
||||
|
||||
def active_search_settings() -> tuple[SearchSettings, SearchSettings | None]:
|
||||
result = op.get_bind().execute(
|
||||
sa.text(
|
||||
"""
|
||||
SELECT * FROM search_settings WHERE status = 'PRESENT' ORDER BY id DESC LIMIT 1
|
||||
"""
|
||||
)
|
||||
)
|
||||
search_settings_fetch = result.fetchall()
|
||||
search_settings = (
|
||||
SearchSettings(**search_settings_fetch[0]._asdict())
|
||||
if search_settings_fetch
|
||||
else None
|
||||
)
|
||||
|
||||
result2 = op.get_bind().execute(
|
||||
sa.text(
|
||||
"""
|
||||
SELECT * FROM search_settings WHERE status = 'FUTURE' ORDER BY id DESC LIMIT 1
|
||||
"""
|
||||
)
|
||||
)
|
||||
search_settings_future_fetch = result2.fetchall()
|
||||
search_settings_future = (
|
||||
SearchSettings(**search_settings_future_fetch[0]._asdict())
|
||||
if search_settings_future_fetch
|
||||
else None
|
||||
)
|
||||
|
||||
if not isinstance(search_settings, SearchSettings):
|
||||
raise RuntimeError(
|
||||
"current search settings is of type " + str(type(search_settings))
|
||||
)
|
||||
if (
|
||||
not isinstance(search_settings_future, SearchSettings)
|
||||
and search_settings_future is not None
|
||||
):
|
||||
raise RuntimeError(
|
||||
"future search settings is of type " + str(type(search_settings_future))
|
||||
)
|
||||
|
||||
return search_settings, search_settings_future
|
||||
|
||||
|
||||
def normalize_google_drive_url(url: str) -> str:
|
||||
"""Remove query parameters from Google Drive URLs to create canonical document IDs.
|
||||
NOTE: copied from drive doc_conversion.py
|
||||
"""
|
||||
parsed_url = urlparse(url)
|
||||
parsed_url = parsed_url._replace(query="")
|
||||
spl_path = parsed_url.path.split("/")
|
||||
if spl_path and (spl_path[-1] in ["edit", "view", "preview"]):
|
||||
spl_path.pop()
|
||||
parsed_url = parsed_url._replace(path="/".join(spl_path))
|
||||
# Remove query parameters and reconstruct URL
|
||||
return urlunparse(parsed_url)
|
||||
|
||||
|
||||
def get_google_drive_documents_from_database() -> list[dict]:
|
||||
"""Get all Google Drive documents from the database."""
|
||||
bind = op.get_bind()
|
||||
result = bind.execute(
|
||||
sa.text(
|
||||
"""
|
||||
SELECT d.id
|
||||
FROM document d
|
||||
JOIN document_by_connector_credential_pair dcc ON d.id = dcc.id
|
||||
JOIN connector_credential_pair cc ON dcc.connector_id = cc.connector_id
|
||||
AND dcc.credential_id = cc.credential_id
|
||||
JOIN connector c ON cc.connector_id = c.id
|
||||
WHERE c.source = 'GOOGLE_DRIVE'
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
documents = []
|
||||
for row in result:
|
||||
documents.append({"document_id": row.id})
|
||||
|
||||
return documents
|
||||
|
||||
|
||||
def update_document_id_in_database(
|
||||
old_doc_id: str, new_doc_id: str, index_name: str
|
||||
) -> None:
|
||||
"""Update document IDs in all relevant database tables using copy-and-swap approach."""
|
||||
bind = op.get_bind()
|
||||
|
||||
# print(f"Updating database tables for document {old_doc_id} -> {new_doc_id}")
|
||||
|
||||
# Check if new document ID already exists
|
||||
result = bind.execute(
|
||||
sa.text("SELECT COUNT(*) FROM document WHERE id = :new_id"),
|
||||
{"new_id": new_doc_id},
|
||||
)
|
||||
row = result.fetchone()
|
||||
if row and row[0] > 0:
|
||||
# print(f"Document with ID {new_doc_id} already exists, deleting old one")
|
||||
delete_document_from_db(old_doc_id, index_name)
|
||||
return
|
||||
|
||||
# Step 1: Create a new document row with the new ID (copy all fields from old row)
|
||||
# Use a conservative approach to handle columns that might not exist in all installations
|
||||
try:
|
||||
bind.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO document (id, from_ingestion_api, boost, hidden, semantic_id,
|
||||
link, doc_updated_at, primary_owners, secondary_owners,
|
||||
external_user_emails, external_user_group_ids, is_public,
|
||||
chunk_count, last_modified, last_synced, kg_stage, kg_processing_time)
|
||||
SELECT :new_id, from_ingestion_api, boost, hidden, semantic_id,
|
||||
link, doc_updated_at, primary_owners, secondary_owners,
|
||||
external_user_emails, external_user_group_ids, is_public,
|
||||
chunk_count, last_modified, last_synced, kg_stage, kg_processing_time
|
||||
FROM document
|
||||
WHERE id = :old_id
|
||||
"""
|
||||
),
|
||||
{"new_id": new_doc_id, "old_id": old_doc_id},
|
||||
)
|
||||
# print(f"Successfully updated database tables for document {old_doc_id} -> {new_doc_id}")
|
||||
except Exception as e:
|
||||
# If the full INSERT fails, try a more basic version with only core columns
|
||||
logger.warning(f"Full INSERT failed, trying basic version: {e}")
|
||||
bind.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO document (id, from_ingestion_api, boost, hidden, semantic_id,
|
||||
link, doc_updated_at, primary_owners, secondary_owners)
|
||||
SELECT :new_id, from_ingestion_api, boost, hidden, semantic_id,
|
||||
link, doc_updated_at, primary_owners, secondary_owners
|
||||
FROM document
|
||||
WHERE id = :old_id
|
||||
"""
|
||||
),
|
||||
{"new_id": new_doc_id, "old_id": old_doc_id},
|
||||
)
|
||||
|
||||
# Step 2: Update all foreign key references to point to the new ID
|
||||
|
||||
# Update document_by_connector_credential_pair table
|
||||
bind.execute(
|
||||
sa.text(
|
||||
"UPDATE document_by_connector_credential_pair SET id = :new_id WHERE id = :old_id"
|
||||
),
|
||||
{"new_id": new_doc_id, "old_id": old_doc_id},
|
||||
)
|
||||
# print(f"Successfully updated document_by_connector_credential_pair table for document {old_doc_id} -> {new_doc_id}")
|
||||
|
||||
# Update search_doc table (stores search results for chat replay)
|
||||
# This is critical for agent functionality
|
||||
bind.execute(
|
||||
sa.text(
|
||||
"UPDATE search_doc SET document_id = :new_id WHERE document_id = :old_id"
|
||||
),
|
||||
{"new_id": new_doc_id, "old_id": old_doc_id},
|
||||
)
|
||||
# print(f"Successfully updated search_doc table for document {old_doc_id} -> {new_doc_id}")
|
||||
# Update document_retrieval_feedback table (user feedback on documents)
|
||||
bind.execute(
|
||||
sa.text(
|
||||
"UPDATE document_retrieval_feedback SET document_id = :new_id WHERE document_id = :old_id"
|
||||
),
|
||||
{"new_id": new_doc_id, "old_id": old_doc_id},
|
||||
)
|
||||
# print(f"Successfully updated document_retrieval_feedback table for document {old_doc_id} -> {new_doc_id}")
|
||||
# Update document__tag table (document-tag relationships)
|
||||
bind.execute(
|
||||
sa.text(
|
||||
"UPDATE document__tag SET document_id = :new_id WHERE document_id = :old_id"
|
||||
),
|
||||
{"new_id": new_doc_id, "old_id": old_doc_id},
|
||||
)
|
||||
# print(f"Successfully updated document__tag table for document {old_doc_id} -> {new_doc_id}")
|
||||
# Update user_file table (user uploaded files linked to documents)
|
||||
bind.execute(
|
||||
sa.text(
|
||||
"UPDATE user_file SET document_id = :new_id WHERE document_id = :old_id"
|
||||
),
|
||||
{"new_id": new_doc_id, "old_id": old_doc_id},
|
||||
)
|
||||
# print(f"Successfully updated user_file table for document {old_doc_id} -> {new_doc_id}")
|
||||
# Update KG and chunk_stats tables (these may not exist in all installations)
|
||||
try:
|
||||
# Update kg_entity table
|
||||
bind.execute(
|
||||
sa.text(
|
||||
"UPDATE kg_entity SET document_id = :new_id WHERE document_id = :old_id"
|
||||
),
|
||||
{"new_id": new_doc_id, "old_id": old_doc_id},
|
||||
)
|
||||
# print(f"Successfully updated kg_entity table for document {old_doc_id} -> {new_doc_id}")
|
||||
# Update kg_entity_extraction_staging table
|
||||
bind.execute(
|
||||
sa.text(
|
||||
"UPDATE kg_entity_extraction_staging SET document_id = :new_id WHERE document_id = :old_id"
|
||||
),
|
||||
{"new_id": new_doc_id, "old_id": old_doc_id},
|
||||
)
|
||||
# print(f"Successfully updated kg_entity_extraction_staging table for document {old_doc_id} -> {new_doc_id}")
|
||||
# Update kg_relationship table
|
||||
bind.execute(
|
||||
sa.text(
|
||||
"UPDATE kg_relationship SET source_document = :new_id WHERE source_document = :old_id"
|
||||
),
|
||||
{"new_id": new_doc_id, "old_id": old_doc_id},
|
||||
)
|
||||
# print(f"Successfully updated kg_relationship table for document {old_doc_id} -> {new_doc_id}")
|
||||
# Update kg_relationship_extraction_staging table
|
||||
bind.execute(
|
||||
sa.text(
|
||||
"UPDATE kg_relationship_extraction_staging SET source_document = :new_id WHERE source_document = :old_id"
|
||||
),
|
||||
{"new_id": new_doc_id, "old_id": old_doc_id},
|
||||
)
|
||||
# print(f"Successfully updated kg_relationship_extraction_staging table for document {old_doc_id} -> {new_doc_id}")
|
||||
# Update chunk_stats table
|
||||
bind.execute(
|
||||
sa.text(
|
||||
"UPDATE chunk_stats SET document_id = :new_id WHERE document_id = :old_id"
|
||||
),
|
||||
{"new_id": new_doc_id, "old_id": old_doc_id},
|
||||
)
|
||||
# print(f"Successfully updated chunk_stats table for document {old_doc_id} -> {new_doc_id}")
|
||||
# Update chunk_stats ID field which includes document_id
|
||||
bind.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE chunk_stats
|
||||
SET id = REPLACE(id, :old_id, :new_id)
|
||||
WHERE id LIKE :old_id_pattern
|
||||
"""
|
||||
),
|
||||
{
|
||||
"new_id": new_doc_id,
|
||||
"old_id": old_doc_id,
|
||||
"old_id_pattern": f"{old_doc_id}__%",
|
||||
},
|
||||
)
|
||||
# print(f"Successfully updated chunk_stats ID field for document {old_doc_id} -> {new_doc_id}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Some KG/chunk tables may not exist or failed to update: {e}")
|
||||
|
||||
# Step 3: Delete the old document row (this should now be safe since all FKs point to new row)
|
||||
bind.execute(
|
||||
sa.text("DELETE FROM document WHERE id = :old_id"), {"old_id": old_doc_id}
|
||||
)
|
||||
# print(f"Successfully deleted document {old_doc_id} from database")
|
||||
|
||||
|
||||
def _visit_chunks(
|
||||
*,
|
||||
http_client: httpx.Client,
|
||||
index_name: str,
|
||||
selection: str,
|
||||
continuation: str | None = None,
|
||||
) -> tuple[list[dict], str | None]:
|
||||
"""Helper that calls the /document/v1 visit API once and returns (docs, next_token)."""
|
||||
|
||||
# Use the same URL as the document API, but with visit-specific params
|
||||
base_url = DOCUMENT_ID_ENDPOINT.format(index_name=index_name)
|
||||
|
||||
params: dict[str, str] = {
|
||||
"selection": selection,
|
||||
"wantedDocumentCount": "1000",
|
||||
}
|
||||
if continuation:
|
||||
params["continuation"] = continuation
|
||||
|
||||
# print(f"Visiting chunks for selection '{selection}' with params {params}")
|
||||
resp = http_client.get(base_url, params=params, timeout=None)
|
||||
# print(f"Visited chunks for document {selection}")
|
||||
resp.raise_for_status()
|
||||
|
||||
payload = resp.json()
|
||||
return payload.get("documents", []), payload.get("continuation")
|
||||
|
||||
|
||||
def delete_document_chunks_from_vespa(index_name: str, doc_id: str) -> None:
|
||||
"""Delete all chunks for *doc_id* from Vespa using continuation-token paging (no offset)."""
|
||||
|
||||
total_deleted = 0
|
||||
# Use exact match instead of contains - Document Selector Language doesn't support contains
|
||||
selection = f'{index_name}.document_id=="{doc_id}"'
|
||||
|
||||
with get_vespa_http_client() as http_client:
|
||||
continuation: str | None = None
|
||||
while True:
|
||||
docs, continuation = _visit_chunks(
|
||||
http_client=http_client,
|
||||
index_name=index_name,
|
||||
selection=selection,
|
||||
continuation=continuation,
|
||||
)
|
||||
|
||||
if not docs:
|
||||
break
|
||||
|
||||
for doc in docs:
|
||||
vespa_full_id = doc.get("id")
|
||||
if not vespa_full_id:
|
||||
continue
|
||||
|
||||
vespa_doc_uuid = vespa_full_id.split("::")[-1]
|
||||
delete_url = f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{vespa_doc_uuid}"
|
||||
|
||||
try:
|
||||
resp = http_client.delete(delete_url)
|
||||
resp.raise_for_status()
|
||||
total_deleted += 1
|
||||
except Exception as e:
|
||||
print(f"Failed to delete chunk {vespa_doc_uuid}: {e}")
|
||||
|
||||
if not continuation:
|
||||
break
|
||||
|
||||
|
||||
def update_document_id_in_vespa(
|
||||
index_name: str, old_doc_id: str, new_doc_id: str
|
||||
) -> None:
|
||||
"""Update all chunks' document_id field from *old_doc_id* to *new_doc_id* using continuation paging."""
|
||||
|
||||
clean_new_doc_id = replace_invalid_doc_id_characters(new_doc_id)
|
||||
|
||||
# Use exact match instead of contains - Document Selector Language doesn't support contains
|
||||
selection = f'{index_name}.document_id=="{old_doc_id}"'
|
||||
|
||||
with get_vespa_http_client() as http_client:
|
||||
continuation: str | None = None
|
||||
while True:
|
||||
# print(f"Visiting chunks for document {old_doc_id} -> {new_doc_id}")
|
||||
docs, continuation = _visit_chunks(
|
||||
http_client=http_client,
|
||||
index_name=index_name,
|
||||
selection=selection,
|
||||
continuation=continuation,
|
||||
)
|
||||
|
||||
if not docs:
|
||||
break
|
||||
|
||||
for doc in docs:
|
||||
vespa_full_id = doc.get("id")
|
||||
if not vespa_full_id:
|
||||
continue
|
||||
|
||||
vespa_doc_uuid = vespa_full_id.split("::")[-1]
|
||||
vespa_url = f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{vespa_doc_uuid}"
|
||||
|
||||
update_request = {
|
||||
"fields": {"document_id": {"assign": clean_new_doc_id}}
|
||||
}
|
||||
|
||||
try:
|
||||
resp = http_client.put(vespa_url, json=update_request)
|
||||
resp.raise_for_status()
|
||||
except Exception as e:
|
||||
print(f"Failed to update chunk {vespa_doc_uuid}: {e}")
|
||||
raise
|
||||
|
||||
if not continuation:
|
||||
break
|
||||
|
||||
|
||||
def delete_document_from_db(current_doc_id: str, index_name: str) -> None:
|
||||
# Delete all foreign key references first, then delete the document
|
||||
try:
|
||||
bind = op.get_bind()
|
||||
|
||||
# Delete from agent-related tables first (order matters due to foreign keys)
|
||||
# Delete from agent__sub_query__search_doc first since it references search_doc
|
||||
bind.execute(
|
||||
sa.text(
|
||||
"""
|
||||
DELETE FROM agent__sub_query__search_doc
|
||||
WHERE search_doc_id IN (
|
||||
SELECT id FROM search_doc WHERE document_id = :doc_id
|
||||
)
|
||||
"""
|
||||
),
|
||||
{"doc_id": current_doc_id},
|
||||
)
|
||||
|
||||
# Delete from chat_message__search_doc
|
||||
bind.execute(
|
||||
sa.text(
|
||||
"""
|
||||
DELETE FROM chat_message__search_doc
|
||||
WHERE search_doc_id IN (
|
||||
SELECT id FROM search_doc WHERE document_id = :doc_id
|
||||
)
|
||||
"""
|
||||
),
|
||||
{"doc_id": current_doc_id},
|
||||
)
|
||||
|
||||
# Now we can safely delete from search_doc
|
||||
bind.execute(
|
||||
sa.text("DELETE FROM search_doc WHERE document_id = :doc_id"),
|
||||
{"doc_id": current_doc_id},
|
||||
)
|
||||
|
||||
# Delete from document_by_connector_credential_pair
|
||||
bind.execute(
|
||||
sa.text(
|
||||
"DELETE FROM document_by_connector_credential_pair WHERE id = :doc_id"
|
||||
),
|
||||
{"doc_id": current_doc_id},
|
||||
)
|
||||
|
||||
# Delete from other tables that reference this document
|
||||
bind.execute(
|
||||
sa.text(
|
||||
"DELETE FROM document_retrieval_feedback WHERE document_id = :doc_id"
|
||||
),
|
||||
{"doc_id": current_doc_id},
|
||||
)
|
||||
|
||||
bind.execute(
|
||||
sa.text("DELETE FROM document__tag WHERE document_id = :doc_id"),
|
||||
{"doc_id": current_doc_id},
|
||||
)
|
||||
|
||||
bind.execute(
|
||||
sa.text("DELETE FROM user_file WHERE document_id = :doc_id"),
|
||||
{"doc_id": current_doc_id},
|
||||
)
|
||||
|
||||
# Delete from KG tables if they exist
|
||||
try:
|
||||
bind.execute(
|
||||
sa.text("DELETE FROM kg_entity WHERE document_id = :doc_id"),
|
||||
{"doc_id": current_doc_id},
|
||||
)
|
||||
|
||||
bind.execute(
|
||||
sa.text(
|
||||
"DELETE FROM kg_entity_extraction_staging WHERE document_id = :doc_id"
|
||||
),
|
||||
{"doc_id": current_doc_id},
|
||||
)
|
||||
|
||||
bind.execute(
|
||||
sa.text("DELETE FROM kg_relationship WHERE source_document = :doc_id"),
|
||||
{"doc_id": current_doc_id},
|
||||
)
|
||||
|
||||
bind.execute(
|
||||
sa.text(
|
||||
"DELETE FROM kg_relationship_extraction_staging WHERE source_document = :doc_id"
|
||||
),
|
||||
{"doc_id": current_doc_id},
|
||||
)
|
||||
|
||||
bind.execute(
|
||||
sa.text("DELETE FROM chunk_stats WHERE document_id = :doc_id"),
|
||||
{"doc_id": current_doc_id},
|
||||
)
|
||||
|
||||
bind.execute(
|
||||
sa.text("DELETE FROM chunk_stats WHERE id LIKE :doc_id_pattern"),
|
||||
{"doc_id_pattern": f"{current_doc_id}__%"},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Some KG/chunk tables may not exist or failed to delete from: {e}"
|
||||
)
|
||||
|
||||
# Finally delete the document itself
|
||||
bind.execute(
|
||||
sa.text("DELETE FROM document WHERE id = :doc_id"),
|
||||
{"doc_id": current_doc_id},
|
||||
)
|
||||
|
||||
# Delete chunks from vespa
|
||||
delete_document_chunks_from_vespa(index_name, current_doc_id)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Failed to delete duplicate document {current_doc_id}: {e}")
|
||||
# Continue with other documents instead of failing the entire migration
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
if SKIP_CANON_DRIVE_IDS:
|
||||
return
|
||||
current_search_settings, future_search_settings = active_search_settings()
|
||||
document_index = get_default_document_index(
|
||||
current_search_settings,
|
||||
future_search_settings,
|
||||
)
|
||||
|
||||
# Get the index name
|
||||
if hasattr(document_index, "index_name"):
|
||||
index_name = document_index.index_name
|
||||
else:
|
||||
# Default index name if we can't get it from the document_index
|
||||
index_name = "danswer_index"
|
||||
|
||||
# Get all Google Drive documents from the database (this is faster and more reliable)
|
||||
gdrive_documents = get_google_drive_documents_from_database()
|
||||
|
||||
if not gdrive_documents:
|
||||
return
|
||||
|
||||
# Track normalized document IDs to detect duplicates
|
||||
all_normalized_doc_ids = set()
|
||||
updated_count = 0
|
||||
|
||||
for doc_info in gdrive_documents:
|
||||
current_doc_id = doc_info["document_id"]
|
||||
normalized_doc_id = normalize_google_drive_url(current_doc_id)
|
||||
|
||||
print(f"Processing document {current_doc_id} -> {normalized_doc_id}")
|
||||
# Check for duplicates
|
||||
if normalized_doc_id in all_normalized_doc_ids:
|
||||
# print(f"Deleting duplicate document {current_doc_id}")
|
||||
delete_document_from_db(current_doc_id, index_name)
|
||||
continue
|
||||
|
||||
all_normalized_doc_ids.add(normalized_doc_id)
|
||||
|
||||
# If the document ID already doesn't have query parameters, skip it
|
||||
if current_doc_id == normalized_doc_id:
|
||||
# print(f"Skipping document {current_doc_id} -> {normalized_doc_id} because it already has no query parameters")
|
||||
continue
|
||||
|
||||
try:
|
||||
# Update both database and Vespa in order
|
||||
# Database first to ensure consistency
|
||||
update_document_id_in_database(
|
||||
current_doc_id, normalized_doc_id, index_name
|
||||
)
|
||||
|
||||
# For Vespa, we can now use the original document IDs since we're using contains matching
|
||||
update_document_id_in_vespa(index_name, current_doc_id, normalized_doc_id)
|
||||
updated_count += 1
|
||||
# print(f"Finished updating document {current_doc_id} -> {normalized_doc_id}")
|
||||
except Exception as e:
|
||||
print(f"Failed to update document {current_doc_id}: {e}")
|
||||
|
||||
if isinstance(e, HTTPStatusError):
|
||||
print(f"HTTPStatusError: {e}")
|
||||
print(f"Response: {e.response.text}")
|
||||
print(f"Status: {e.response.status_code}")
|
||||
print(f"Headers: {e.response.headers}")
|
||||
print(f"Request: {e.request.url}")
|
||||
print(f"Request headers: {e.request.headers}")
|
||||
# Note: Rollback is complex with copy-and-swap approach since the old document is already deleted
|
||||
# In case of failure, manual intervention may be required
|
||||
# Continue with other documents instead of failing the entire migration
|
||||
continue
|
||||
|
||||
logger.info(f"Migration complete. Updated {updated_count} Google Drive documents")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# this is a one way migration, so no downgrade.
|
||||
# It wouldn't make sense to store the extra query parameters
|
||||
# and duplicate documents to allow a reversal.
|
||||
pass
|
||||
@@ -0,0 +1,45 @@
|
||||
"""Add foreign key to user__external_user_group_id
|
||||
|
||||
Revision ID: 238b84885828
|
||||
Revises: a7688ab35c45
|
||||
Create Date: 2025-05-19 17:15:33.424584
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "238b84885828"
|
||||
down_revision = "a7688ab35c45"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# First, clean up any entries that don't have a valid cc_pair_id
|
||||
op.execute(
|
||||
"""
|
||||
DELETE FROM user__external_user_group_id
|
||||
WHERE cc_pair_id NOT IN (SELECT id FROM connector_credential_pair)
|
||||
"""
|
||||
)
|
||||
|
||||
# Add foreign key constraint with cascade delete
|
||||
op.create_foreign_key(
|
||||
"fk_user__external_user_group_id_cc_pair_id",
|
||||
"user__external_user_group_id",
|
||||
"connector_credential_pair",
|
||||
["cc_pair_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop the foreign key constraint
|
||||
op.drop_constraint(
|
||||
"fk_user__external_user_group_id_cc_pair_id",
|
||||
"user__external_user_group_id",
|
||||
type_="foreignkey",
|
||||
)
|
||||
@@ -144,27 +144,34 @@ def upgrade() -> None:
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute("TRUNCATE TABLE index_attempt")
|
||||
op.add_column(
|
||||
"index_attempt",
|
||||
sa.Column("input_type", sa.VARCHAR(), autoincrement=False, nullable=False),
|
||||
)
|
||||
op.add_column(
|
||||
"index_attempt",
|
||||
sa.Column("source", sa.VARCHAR(), autoincrement=False, nullable=False),
|
||||
)
|
||||
op.add_column(
|
||||
"index_attempt",
|
||||
sa.Column(
|
||||
"connector_specific_config",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
autoincrement=False,
|
||||
nullable=False,
|
||||
),
|
||||
)
|
||||
|
||||
# Check if the constraint exists before dropping
|
||||
conn = op.get_bind()
|
||||
inspector = sa.inspect(conn)
|
||||
existing_columns = {col["name"] for col in inspector.get_columns("index_attempt")}
|
||||
|
||||
if "input_type" not in existing_columns:
|
||||
op.add_column(
|
||||
"index_attempt",
|
||||
sa.Column("input_type", sa.VARCHAR(), autoincrement=False, nullable=False),
|
||||
)
|
||||
|
||||
if "source" not in existing_columns:
|
||||
op.add_column(
|
||||
"index_attempt",
|
||||
sa.Column("source", sa.VARCHAR(), autoincrement=False, nullable=False),
|
||||
)
|
||||
|
||||
if "connector_specific_config" not in existing_columns:
|
||||
op.add_column(
|
||||
"index_attempt",
|
||||
sa.Column(
|
||||
"connector_specific_config",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
autoincrement=False,
|
||||
nullable=False,
|
||||
),
|
||||
)
|
||||
|
||||
# Check if the constraint exists before dropping
|
||||
constraints = inspector.get_foreign_keys("index_attempt")
|
||||
|
||||
if any(
|
||||
@@ -183,8 +190,12 @@ def downgrade() -> None:
|
||||
"fk_index_attempt_connector_id", "index_attempt", type_="foreignkey"
|
||||
)
|
||||
|
||||
op.drop_column("index_attempt", "credential_id")
|
||||
op.drop_column("index_attempt", "connector_id")
|
||||
op.drop_table("connector_credential_pair")
|
||||
op.drop_table("credential")
|
||||
op.drop_table("connector")
|
||||
if "credential_id" in existing_columns:
|
||||
op.drop_column("index_attempt", "credential_id")
|
||||
|
||||
if "connector_id" in existing_columns:
|
||||
op.drop_column("index_attempt", "connector_id")
|
||||
|
||||
op.execute("DROP TABLE IF EXISTS connector_credential_pair CASCADE")
|
||||
op.execute("DROP TABLE IF EXISTS credential CASCADE")
|
||||
op.execute("DROP TABLE IF EXISTS connector CASCADE")
|
||||
|
||||
@@ -0,0 +1,115 @@
|
||||
"""add_indexing_coordination
|
||||
|
||||
Revision ID: 2f95e36923e6
|
||||
Revises: 0816326d83aa
|
||||
Create Date: 2025-07-10 16:17:57.762182
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "2f95e36923e6"
|
||||
down_revision = "0816326d83aa"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add database-based coordination fields (replacing Redis fencing)
|
||||
op.add_column(
|
||||
"index_attempt", sa.Column("celery_task_id", sa.String(), nullable=True)
|
||||
)
|
||||
op.add_column(
|
||||
"index_attempt",
|
||||
sa.Column(
|
||||
"cancellation_requested",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default="false",
|
||||
),
|
||||
)
|
||||
|
||||
# Add batch coordination fields (replacing FileStore state)
|
||||
op.add_column(
|
||||
"index_attempt", sa.Column("total_batches", sa.Integer(), nullable=True)
|
||||
)
|
||||
op.add_column(
|
||||
"index_attempt",
|
||||
sa.Column(
|
||||
"completed_batches", sa.Integer(), nullable=False, server_default="0"
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"index_attempt",
|
||||
sa.Column(
|
||||
"total_failures_batch_level",
|
||||
sa.Integer(),
|
||||
nullable=False,
|
||||
server_default="0",
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"index_attempt",
|
||||
sa.Column("total_chunks", sa.Integer(), nullable=False, server_default="0"),
|
||||
)
|
||||
|
||||
# Progress tracking for stall detection
|
||||
op.add_column(
|
||||
"index_attempt",
|
||||
sa.Column("last_progress_time", sa.DateTime(timezone=True), nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
"index_attempt",
|
||||
sa.Column(
|
||||
"last_batches_completed_count",
|
||||
sa.Integer(),
|
||||
nullable=False,
|
||||
server_default="0",
|
||||
),
|
||||
)
|
||||
|
||||
# Heartbeat tracking for worker liveness detection
|
||||
op.add_column(
|
||||
"index_attempt",
|
||||
sa.Column(
|
||||
"heartbeat_counter", sa.Integer(), nullable=False, server_default="0"
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"index_attempt",
|
||||
sa.Column(
|
||||
"last_heartbeat_value", sa.Integer(), nullable=False, server_default="0"
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"index_attempt",
|
||||
sa.Column("last_heartbeat_time", sa.DateTime(timezone=True), nullable=True),
|
||||
)
|
||||
|
||||
# Add index for coordination queries
|
||||
op.create_index(
|
||||
"ix_index_attempt_active_coordination",
|
||||
"index_attempt",
|
||||
["connector_credential_pair_id", "search_settings_id", "status"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove the new index
|
||||
op.drop_index("ix_index_attempt_active_coordination", table_name="index_attempt")
|
||||
|
||||
# Remove the new columns
|
||||
op.drop_column("index_attempt", "last_batches_completed_count")
|
||||
op.drop_column("index_attempt", "last_progress_time")
|
||||
op.drop_column("index_attempt", "last_heartbeat_time")
|
||||
op.drop_column("index_attempt", "last_heartbeat_value")
|
||||
op.drop_column("index_attempt", "heartbeat_counter")
|
||||
op.drop_column("index_attempt", "total_chunks")
|
||||
op.drop_column("index_attempt", "total_failures_batch_level")
|
||||
op.drop_column("index_attempt", "completed_batches")
|
||||
op.drop_column("index_attempt", "total_batches")
|
||||
op.drop_column("index_attempt", "cancellation_requested")
|
||||
op.drop_column("index_attempt", "celery_task_id")
|
||||
@@ -0,0 +1,136 @@
|
||||
"""update_kg_trigger_functions
|
||||
|
||||
Revision ID: 36e9220ab794
|
||||
Revises: c9e2cd766c29
|
||||
Create Date: 2025-06-22 17:33:25.833733
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import text
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "36e9220ab794"
|
||||
down_revision = "c9e2cd766c29"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def _get_tenant_contextvar(session: Session) -> str:
|
||||
"""Get the current schema for the migration"""
|
||||
current_tenant = session.execute(text("SELECT current_schema()")).scalar()
|
||||
if isinstance(current_tenant, str):
|
||||
return current_tenant
|
||||
else:
|
||||
raise ValueError("Current tenant is not a string")
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
||||
bind = op.get_bind()
|
||||
session = Session(bind=bind)
|
||||
|
||||
# Create kg_entity trigger to update kg_entity.name and its trigrams
|
||||
tenant_id = _get_tenant_contextvar(session)
|
||||
alphanum_pattern = r"[^a-z0-9]+"
|
||||
truncate_length = 1000
|
||||
function = "update_kg_entity_name"
|
||||
op.execute(
|
||||
text(
|
||||
f"""
|
||||
CREATE OR REPLACE FUNCTION "{tenant_id}".{function}()
|
||||
RETURNS TRIGGER AS $$
|
||||
DECLARE
|
||||
name text;
|
||||
cleaned_name text;
|
||||
BEGIN
|
||||
-- Set name to semantic_id if document_id is not NULL
|
||||
IF NEW.document_id IS NOT NULL THEN
|
||||
SELECT lower(semantic_id) INTO name
|
||||
FROM "{tenant_id}".document
|
||||
WHERE id = NEW.document_id;
|
||||
ELSE
|
||||
name = lower(NEW.name);
|
||||
END IF;
|
||||
|
||||
-- Clean name and truncate if too long
|
||||
cleaned_name = regexp_replace(
|
||||
name,
|
||||
'{alphanum_pattern}', '', 'g'
|
||||
);
|
||||
IF length(cleaned_name) > {truncate_length} THEN
|
||||
cleaned_name = left(cleaned_name, {truncate_length});
|
||||
END IF;
|
||||
|
||||
-- Set name and name trigrams
|
||||
NEW.name = name;
|
||||
NEW.name_trigrams = {POSTGRES_DEFAULT_SCHEMA}.show_trgm(cleaned_name);
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;
|
||||
"""
|
||||
)
|
||||
)
|
||||
trigger = f"{function}_trigger"
|
||||
op.execute(f'DROP TRIGGER IF EXISTS {trigger} ON "{tenant_id}".kg_entity')
|
||||
op.execute(
|
||||
f"""
|
||||
CREATE TRIGGER {trigger}
|
||||
BEFORE INSERT OR UPDATE OF name
|
||||
ON "{tenant_id}".kg_entity
|
||||
FOR EACH ROW
|
||||
EXECUTE FUNCTION "{tenant_id}".{function}();
|
||||
"""
|
||||
)
|
||||
|
||||
# Create kg_entity trigger to update kg_entity.name and its trigrams
|
||||
function = "update_kg_entity_name_from_doc"
|
||||
op.execute(
|
||||
text(
|
||||
f"""
|
||||
CREATE OR REPLACE FUNCTION "{tenant_id}".{function}()
|
||||
RETURNS TRIGGER AS $$
|
||||
DECLARE
|
||||
doc_name text;
|
||||
cleaned_name text;
|
||||
BEGIN
|
||||
doc_name = lower(NEW.semantic_id);
|
||||
|
||||
-- Clean name and truncate if too long
|
||||
cleaned_name = regexp_replace(
|
||||
doc_name,
|
||||
'{alphanum_pattern}', '', 'g'
|
||||
);
|
||||
IF length(cleaned_name) > {truncate_length} THEN
|
||||
cleaned_name = left(cleaned_name, {truncate_length});
|
||||
END IF;
|
||||
|
||||
-- Set name and name trigrams for all entities referencing this document
|
||||
UPDATE "{tenant_id}".kg_entity
|
||||
SET
|
||||
name = doc_name,
|
||||
name_trigrams = {POSTGRES_DEFAULT_SCHEMA}.show_trgm(cleaned_name)
|
||||
WHERE document_id = NEW.id;
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;
|
||||
"""
|
||||
)
|
||||
)
|
||||
trigger = f"{function}_trigger"
|
||||
op.execute(f'DROP TRIGGER IF EXISTS {trigger} ON "{tenant_id}".document')
|
||||
op.execute(
|
||||
f"""
|
||||
CREATE TRIGGER {trigger}
|
||||
AFTER UPDATE OF semantic_id
|
||||
ON "{tenant_id}".document
|
||||
FOR EACH ROW
|
||||
EXECUTE FUNCTION "{tenant_id}".{function}();
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
pass
|
||||
@@ -21,22 +21,14 @@ depends_on = None
|
||||
# an outage by creating an index without using CONCURRENTLY. This migration:
|
||||
#
|
||||
# 1. Creates more efficient full-text search capabilities using tsvector columns and GIN indexes
|
||||
# 2. Uses CONCURRENTLY for all index creation to prevent table locking
|
||||
# 3. Explicitly manages transactions with COMMIT statements to allow CONCURRENTLY to work
|
||||
# (see: https://www.postgresql.org/docs/9.4/sql-createindex.html#SQL-CREATEINDEX-CONCURRENTLY)
|
||||
# (see: https://github.com/sqlalchemy/alembic/issues/277)
|
||||
# 4. Adds indexes to both chat_message and chat_session tables for comprehensive search
|
||||
# 2. Adds indexes to both chat_message and chat_session tables for comprehensive search
|
||||
# 3. Note: CONCURRENTLY was removed due to operational issues
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# First, drop any existing indexes to avoid conflicts
|
||||
op.execute("COMMIT")
|
||||
op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_chat_message_tsv;")
|
||||
|
||||
op.execute("COMMIT")
|
||||
op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_chat_session_desc_tsv;")
|
||||
|
||||
op.execute("COMMIT")
|
||||
op.execute("DROP INDEX IF EXISTS idx_chat_message_tsv;")
|
||||
op.execute("DROP INDEX IF EXISTS idx_chat_session_desc_tsv;")
|
||||
op.execute("DROP INDEX IF EXISTS idx_chat_message_message_lower;")
|
||||
|
||||
# Drop existing columns if they exist
|
||||
@@ -52,12 +44,9 @@ def upgrade() -> None:
|
||||
"""
|
||||
)
|
||||
|
||||
# Commit the current transaction before creating concurrent indexes
|
||||
op.execute("COMMIT")
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_message_tsv
|
||||
CREATE INDEX IF NOT EXISTS idx_chat_message_tsv
|
||||
ON chat_message
|
||||
USING GIN (message_tsv)
|
||||
"""
|
||||
@@ -72,12 +61,9 @@ def upgrade() -> None:
|
||||
"""
|
||||
)
|
||||
|
||||
# Commit again before creating the second concurrent index
|
||||
op.execute("COMMIT")
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_session_desc_tsv
|
||||
CREATE INDEX IF NOT EXISTS idx_chat_session_desc_tsv
|
||||
ON chat_session
|
||||
USING GIN (description_tsv)
|
||||
"""
|
||||
@@ -85,12 +71,9 @@ def upgrade() -> None:
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop the indexes first (use CONCURRENTLY for dropping too)
|
||||
op.execute("COMMIT")
|
||||
op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_chat_message_tsv;")
|
||||
|
||||
op.execute("COMMIT")
|
||||
op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_chat_session_desc_tsv;")
|
||||
# Drop the indexes first
|
||||
op.execute("DROP INDEX IF EXISTS idx_chat_message_tsv;")
|
||||
op.execute("DROP INDEX IF EXISTS idx_chat_session_desc_tsv;")
|
||||
|
||||
# Then drop the columns
|
||||
op.execute("ALTER TABLE chat_message DROP COLUMN IF EXISTS message_tsv;")
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
"""add_doc_metadata_field_in_document_model
|
||||
|
||||
Revision ID: 3fc5d75723b3
|
||||
Revises: 2f95e36923e6
|
||||
Create Date: 2025-07-28 18:45:37.985406
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "3fc5d75723b3"
|
||||
down_revision = "2f95e36923e6"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"document",
|
||||
sa.Column(
|
||||
"doc_metadata", postgresql.JSONB(astext_type=sa.Text()), nullable=True
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("document", "doc_metadata")
|
||||
@@ -0,0 +1,691 @@
|
||||
"""create knowledge graph tables
|
||||
|
||||
Revision ID: 495cb26ce93e
|
||||
Revises: ca04500b9ee8
|
||||
Create Date: 2025-03-19 08:51:14.341989
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from sqlalchemy import text
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from onyx.configs.app_configs import DB_READONLY_USER
|
||||
from onyx.configs.app_configs import DB_READONLY_PASSWORD
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "495cb26ce93e"
|
||||
down_revision = "ca04500b9ee8"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
||||
# Create a new permission-less user to be later used for knowledge graph queries.
|
||||
# The user will later get temporary read privileges for a specific view that will be
|
||||
# ad hoc generated specific to a knowledge graph query.
|
||||
#
|
||||
# Note: in order for the migration to run, the DB_READONLY_USER and DB_READONLY_PASSWORD
|
||||
# environment variables MUST be set. Otherwise, an exception will be raised.
|
||||
|
||||
if not MULTI_TENANT:
|
||||
|
||||
# Enable pg_trgm extension if not already enabled
|
||||
op.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm")
|
||||
|
||||
# Create read-only db user here only in single tenant mode. For multi-tenant mode,
|
||||
# the user is created in the alembic_tenants migration.
|
||||
if not (DB_READONLY_USER and DB_READONLY_PASSWORD):
|
||||
raise Exception("DB_READONLY_USER or DB_READONLY_PASSWORD is not set")
|
||||
|
||||
op.execute(
|
||||
text(
|
||||
f"""
|
||||
DO $$
|
||||
BEGIN
|
||||
-- Check if the read-only user already exists
|
||||
IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
|
||||
-- Create the read-only user with the specified password
|
||||
EXECUTE format('CREATE USER %I WITH PASSWORD %L', '{DB_READONLY_USER}', '{DB_READONLY_PASSWORD}');
|
||||
-- First revoke all privileges to ensure a clean slate
|
||||
EXECUTE format('REVOKE ALL ON DATABASE %I FROM %I', current_database(), '{DB_READONLY_USER}');
|
||||
-- Grant only the CONNECT privilege to allow the user to connect to the database
|
||||
-- but not perform any operations without additional specific grants
|
||||
EXECUTE format('GRANT CONNECT ON DATABASE %I TO %I', current_database(), '{DB_READONLY_USER}');
|
||||
END IF;
|
||||
END
|
||||
$$;
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# Grant usage on current schema to readonly user
|
||||
op.execute(
|
||||
text(
|
||||
f"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
|
||||
EXECUTE format('GRANT USAGE ON SCHEMA %I TO %I', current_schema(), '{DB_READONLY_USER}');
|
||||
END IF;
|
||||
END
|
||||
$$;
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
op.execute("DROP TABLE IF EXISTS kg_config CASCADE")
|
||||
op.create_table(
|
||||
"kg_config",
|
||||
sa.Column("id", sa.Integer(), primary_key=True, nullable=False, index=True),
|
||||
sa.Column("kg_variable_name", sa.String(), nullable=False, index=True),
|
||||
sa.Column("kg_variable_values", postgresql.ARRAY(sa.String()), nullable=False),
|
||||
sa.UniqueConstraint("kg_variable_name", name="uq_kg_config_variable_name"),
|
||||
)
|
||||
|
||||
# Insert initial data into kg_config table
|
||||
op.bulk_insert(
|
||||
sa.table(
|
||||
"kg_config",
|
||||
sa.column("kg_variable_name", sa.String),
|
||||
sa.column("kg_variable_values", postgresql.ARRAY(sa.String)),
|
||||
),
|
||||
[
|
||||
{"kg_variable_name": "KG_EXPOSED", "kg_variable_values": ["false"]},
|
||||
{"kg_variable_name": "KG_ENABLED", "kg_variable_values": ["false"]},
|
||||
{"kg_variable_name": "KG_VENDOR", "kg_variable_values": []},
|
||||
{"kg_variable_name": "KG_VENDOR_DOMAINS", "kg_variable_values": []},
|
||||
{"kg_variable_name": "KG_IGNORE_EMAIL_DOMAINS", "kg_variable_values": []},
|
||||
{
|
||||
"kg_variable_name": "KG_EXTRACTION_IN_PROGRESS",
|
||||
"kg_variable_values": ["false"],
|
||||
},
|
||||
{
|
||||
"kg_variable_name": "KG_CLUSTERING_IN_PROGRESS",
|
||||
"kg_variable_values": ["false"],
|
||||
},
|
||||
{
|
||||
"kg_variable_name": "KG_COVERAGE_START",
|
||||
"kg_variable_values": [
|
||||
(datetime.now() - timedelta(days=90)).strftime("%Y-%m-%d")
|
||||
],
|
||||
},
|
||||
{"kg_variable_name": "KG_MAX_COVERAGE_DAYS", "kg_variable_values": ["90"]},
|
||||
{
|
||||
"kg_variable_name": "KG_MAX_PARENT_RECURSION_DEPTH",
|
||||
"kg_variable_values": ["2"],
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
op.execute("DROP TABLE IF EXISTS kg_entity_type CASCADE")
|
||||
op.create_table(
|
||||
"kg_entity_type",
|
||||
sa.Column("id_name", sa.String(), primary_key=True, nullable=False, index=True),
|
||||
sa.Column("description", sa.String(), nullable=True),
|
||||
sa.Column("grounding", sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
"attributes",
|
||||
postgresql.JSONB,
|
||||
nullable=False,
|
||||
server_default="{}",
|
||||
),
|
||||
sa.Column("occurrences", sa.Integer(), server_default="1", nullable=False),
|
||||
sa.Column("active", sa.Boolean(), nullable=False, default=False),
|
||||
sa.Column("deep_extraction", sa.Boolean(), nullable=False, default=False),
|
||||
sa.Column(
|
||||
"time_updated",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
onupdate=sa.text("now()"),
|
||||
),
|
||||
sa.Column(
|
||||
"time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
|
||||
),
|
||||
sa.Column("grounded_source_name", sa.String(), nullable=True),
|
||||
sa.Column("entity_values", postgresql.ARRAY(sa.String()), nullable=True),
|
||||
sa.Column(
|
||||
"clustering",
|
||||
postgresql.JSONB,
|
||||
nullable=False,
|
||||
server_default="{}",
|
||||
),
|
||||
)
|
||||
|
||||
op.execute("DROP TABLE IF EXISTS kg_relationship_type CASCADE")
|
||||
# Create KGRelationshipType table
|
||||
op.create_table(
|
||||
"kg_relationship_type",
|
||||
sa.Column("id_name", sa.String(), primary_key=True, nullable=False, index=True),
|
||||
sa.Column("name", sa.String(), nullable=False, index=True),
|
||||
sa.Column(
|
||||
"source_entity_type_id_name", sa.String(), nullable=False, index=True
|
||||
),
|
||||
sa.Column(
|
||||
"target_entity_type_id_name", sa.String(), nullable=False, index=True
|
||||
),
|
||||
sa.Column("definition", sa.Boolean(), nullable=False, default=False),
|
||||
sa.Column("occurrences", sa.Integer(), server_default="1", nullable=False),
|
||||
sa.Column("type", sa.String(), nullable=False, index=True),
|
||||
sa.Column("active", sa.Boolean(), nullable=False, default=True),
|
||||
sa.Column(
|
||||
"time_updated",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
onupdate=sa.text("now()"),
|
||||
),
|
||||
sa.Column(
|
||||
"time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
|
||||
),
|
||||
sa.Column(
|
||||
"clustering",
|
||||
postgresql.JSONB,
|
||||
nullable=False,
|
||||
server_default="{}",
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["source_entity_type_id_name"], ["kg_entity_type.id_name"]
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["target_entity_type_id_name"], ["kg_entity_type.id_name"]
|
||||
),
|
||||
)
|
||||
|
||||
op.execute("DROP TABLE IF EXISTS kg_relationship_type_extraction_staging CASCADE")
|
||||
# Create KGRelationshipTypeExtractionStaging table
|
||||
op.create_table(
|
||||
"kg_relationship_type_extraction_staging",
|
||||
sa.Column("id_name", sa.String(), primary_key=True, nullable=False, index=True),
|
||||
sa.Column("name", sa.String(), nullable=False, index=True),
|
||||
sa.Column(
|
||||
"source_entity_type_id_name", sa.String(), nullable=False, index=True
|
||||
),
|
||||
sa.Column(
|
||||
"target_entity_type_id_name", sa.String(), nullable=False, index=True
|
||||
),
|
||||
sa.Column("definition", sa.Boolean(), nullable=False, default=False),
|
||||
sa.Column("occurrences", sa.Integer(), server_default="1", nullable=False),
|
||||
sa.Column("type", sa.String(), nullable=False, index=True),
|
||||
sa.Column("active", sa.Boolean(), nullable=False, default=True),
|
||||
sa.Column(
|
||||
"time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
|
||||
),
|
||||
sa.Column(
|
||||
"clustering",
|
||||
postgresql.JSONB,
|
||||
nullable=False,
|
||||
server_default="{}",
|
||||
),
|
||||
sa.Column("transferred", sa.Boolean(), nullable=False, server_default="false"),
|
||||
sa.ForeignKeyConstraint(
|
||||
["source_entity_type_id_name"], ["kg_entity_type.id_name"]
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["target_entity_type_id_name"], ["kg_entity_type.id_name"]
|
||||
),
|
||||
)
|
||||
|
||||
op.execute("DROP TABLE IF EXISTS kg_entity CASCADE")
|
||||
|
||||
# Create KGEntity table
|
||||
op.create_table(
|
||||
"kg_entity",
|
||||
sa.Column("id_name", sa.String(), primary_key=True, nullable=False, index=True),
|
||||
sa.Column("name", sa.String(), nullable=False, index=True),
|
||||
sa.Column("entity_class", sa.String(), nullable=True, index=True),
|
||||
sa.Column("entity_subtype", sa.String(), nullable=True, index=True),
|
||||
sa.Column("entity_key", sa.String(), nullable=True, index=True),
|
||||
sa.Column("name_trigrams", postgresql.ARRAY(sa.String(3)), nullable=True),
|
||||
sa.Column("document_id", sa.String(), nullable=True, index=True),
|
||||
sa.Column(
|
||||
"alternative_names",
|
||||
postgresql.ARRAY(sa.String()),
|
||||
nullable=False,
|
||||
server_default="{}",
|
||||
),
|
||||
sa.Column("entity_type_id_name", sa.String(), nullable=False, index=True),
|
||||
sa.Column("description", sa.String(), nullable=True),
|
||||
sa.Column(
|
||||
"keywords",
|
||||
postgresql.ARRAY(sa.String()),
|
||||
nullable=False,
|
||||
server_default="{}",
|
||||
),
|
||||
sa.Column("occurrences", sa.Integer(), server_default="1", nullable=False),
|
||||
sa.Column(
|
||||
"acl", postgresql.ARRAY(sa.String()), nullable=False, server_default="{}"
|
||||
),
|
||||
sa.Column("boosts", postgresql.JSONB, nullable=False, server_default="{}"),
|
||||
sa.Column("attributes", postgresql.JSONB, nullable=False, server_default="{}"),
|
||||
sa.Column("event_time", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column(
|
||||
"time_updated",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
onupdate=sa.text("now()"),
|
||||
),
|
||||
sa.Column(
|
||||
"time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
|
||||
),
|
||||
sa.ForeignKeyConstraint(["entity_type_id_name"], ["kg_entity_type.id_name"]),
|
||||
sa.ForeignKeyConstraint(["document_id"], ["document.id"]),
|
||||
sa.UniqueConstraint(
|
||||
"name",
|
||||
"entity_type_id_name",
|
||||
"document_id",
|
||||
name="uq_kg_entity_name_type_doc",
|
||||
),
|
||||
)
|
||||
op.create_index("ix_entity_type_acl", "kg_entity", ["entity_type_id_name", "acl"])
|
||||
op.create_index(
|
||||
"ix_entity_name_search", "kg_entity", ["name", "entity_type_id_name"]
|
||||
)
|
||||
|
||||
op.execute("DROP TABLE IF EXISTS kg_entity_extraction_staging CASCADE")
|
||||
# Create KGEntityExtractionStaging table
|
||||
op.create_table(
|
||||
"kg_entity_extraction_staging",
|
||||
sa.Column("id_name", sa.String(), primary_key=True, nullable=False, index=True),
|
||||
sa.Column("name", sa.String(), nullable=False, index=True),
|
||||
sa.Column("document_id", sa.String(), nullable=True, index=True),
|
||||
sa.Column(
|
||||
"alternative_names",
|
||||
postgresql.ARRAY(sa.String()),
|
||||
nullable=False,
|
||||
server_default="{}",
|
||||
),
|
||||
sa.Column("entity_type_id_name", sa.String(), nullable=False, index=True),
|
||||
sa.Column("description", sa.String(), nullable=True),
|
||||
sa.Column(
|
||||
"keywords",
|
||||
postgresql.ARRAY(sa.String()),
|
||||
nullable=False,
|
||||
server_default="{}",
|
||||
),
|
||||
sa.Column("occurrences", sa.Integer(), server_default="1", nullable=False),
|
||||
sa.Column(
|
||||
"acl", postgresql.ARRAY(sa.String()), nullable=False, server_default="{}"
|
||||
),
|
||||
sa.Column("boosts", postgresql.JSONB, nullable=False, server_default="{}"),
|
||||
sa.Column("attributes", postgresql.JSONB, nullable=False, server_default="{}"),
|
||||
sa.Column("transferred_id_name", sa.String(), nullable=True, default=None),
|
||||
sa.Column("entity_class", sa.String(), nullable=True, index=True),
|
||||
sa.Column("entity_key", sa.String(), nullable=True, index=True),
|
||||
sa.Column("entity_subtype", sa.String(), nullable=True, index=True),
|
||||
sa.Column("parent_key", sa.String(), nullable=True, index=True),
|
||||
sa.Column("event_time", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column(
|
||||
"time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
|
||||
),
|
||||
sa.ForeignKeyConstraint(["entity_type_id_name"], ["kg_entity_type.id_name"]),
|
||||
sa.ForeignKeyConstraint(["document_id"], ["document.id"]),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_entity_extraction_staging_acl",
|
||||
"kg_entity_extraction_staging",
|
||||
["entity_type_id_name", "acl"],
|
||||
)
|
||||
op.create_index(
|
||||
"ix_entity_extraction_staging_name_search",
|
||||
"kg_entity_extraction_staging",
|
||||
["name", "entity_type_id_name"],
|
||||
)
|
||||
|
||||
op.execute("DROP TABLE IF EXISTS kg_relationship CASCADE")
|
||||
# Create KGRelationship table
|
||||
op.create_table(
|
||||
"kg_relationship",
|
||||
sa.Column("id_name", sa.String(), nullable=False, index=True),
|
||||
sa.Column("source_node", sa.String(), nullable=False, index=True),
|
||||
sa.Column("target_node", sa.String(), nullable=False, index=True),
|
||||
sa.Column("source_node_type", sa.String(), nullable=False, index=True),
|
||||
sa.Column("target_node_type", sa.String(), nullable=False, index=True),
|
||||
sa.Column("source_document", sa.String(), nullable=True, index=True),
|
||||
sa.Column("type", sa.String(), nullable=False, index=True),
|
||||
sa.Column("relationship_type_id_name", sa.String(), nullable=False, index=True),
|
||||
sa.Column("occurrences", sa.Integer(), server_default="1", nullable=False),
|
||||
sa.Column(
|
||||
"time_updated",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
onupdate=sa.text("now()"),
|
||||
),
|
||||
sa.Column(
|
||||
"time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
|
||||
),
|
||||
sa.ForeignKeyConstraint(["source_node"], ["kg_entity.id_name"]),
|
||||
sa.ForeignKeyConstraint(["target_node"], ["kg_entity.id_name"]),
|
||||
sa.ForeignKeyConstraint(["source_node_type"], ["kg_entity_type.id_name"]),
|
||||
sa.ForeignKeyConstraint(["target_node_type"], ["kg_entity_type.id_name"]),
|
||||
sa.ForeignKeyConstraint(["source_document"], ["document.id"]),
|
||||
sa.ForeignKeyConstraint(
|
||||
["relationship_type_id_name"], ["kg_relationship_type.id_name"]
|
||||
),
|
||||
sa.UniqueConstraint(
|
||||
"source_node",
|
||||
"target_node",
|
||||
"type",
|
||||
name="uq_kg_relationship_source_target_type",
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id_name", "source_document"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_kg_relationship_nodes", "kg_relationship", ["source_node", "target_node"]
|
||||
)
|
||||
|
||||
op.execute("DROP TABLE IF EXISTS kg_relationship_extraction_staging CASCADE")
|
||||
# Create KGRelationshipExtractionStaging table
|
||||
op.create_table(
|
||||
"kg_relationship_extraction_staging",
|
||||
sa.Column("id_name", sa.String(), nullable=False, index=True),
|
||||
sa.Column("source_node", sa.String(), nullable=False, index=True),
|
||||
sa.Column("target_node", sa.String(), nullable=False, index=True),
|
||||
sa.Column("source_node_type", sa.String(), nullable=False, index=True),
|
||||
sa.Column("target_node_type", sa.String(), nullable=False, index=True),
|
||||
sa.Column("source_document", sa.String(), nullable=True, index=True),
|
||||
sa.Column("type", sa.String(), nullable=False, index=True),
|
||||
sa.Column("relationship_type_id_name", sa.String(), nullable=False, index=True),
|
||||
sa.Column("occurrences", sa.Integer(), server_default="1", nullable=False),
|
||||
sa.Column("transferred", sa.Boolean(), nullable=False, server_default="false"),
|
||||
sa.Column(
|
||||
"time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["source_node"], ["kg_entity_extraction_staging.id_name"]
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["target_node"], ["kg_entity_extraction_staging.id_name"]
|
||||
),
|
||||
sa.ForeignKeyConstraint(["source_node_type"], ["kg_entity_type.id_name"]),
|
||||
sa.ForeignKeyConstraint(["target_node_type"], ["kg_entity_type.id_name"]),
|
||||
sa.ForeignKeyConstraint(["source_document"], ["document.id"]),
|
||||
sa.ForeignKeyConstraint(
|
||||
["relationship_type_id_name"],
|
||||
["kg_relationship_type_extraction_staging.id_name"],
|
||||
),
|
||||
sa.UniqueConstraint(
|
||||
"source_node",
|
||||
"target_node",
|
||||
"type",
|
||||
name="uq_kg_relationship_extraction_staging_source_target_type",
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id_name", "source_document"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_kg_relationship_extraction_staging_nodes",
|
||||
"kg_relationship_extraction_staging",
|
||||
["source_node", "target_node"],
|
||||
)
|
||||
|
||||
op.execute("DROP TABLE IF EXISTS kg_term CASCADE")
|
||||
# Create KGTerm table
|
||||
op.create_table(
|
||||
"kg_term",
|
||||
sa.Column("id_term", sa.String(), primary_key=True, nullable=False, index=True),
|
||||
sa.Column(
|
||||
"entity_types",
|
||||
postgresql.ARRAY(sa.String()),
|
||||
nullable=False,
|
||||
server_default="{}",
|
||||
),
|
||||
sa.Column(
|
||||
"time_updated",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
onupdate=sa.text("now()"),
|
||||
),
|
||||
sa.Column(
|
||||
"time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
|
||||
),
|
||||
)
|
||||
op.create_index("ix_search_term_entities", "kg_term", ["entity_types"])
|
||||
op.create_index("ix_search_term_term", "kg_term", ["id_term"])
|
||||
|
||||
op.add_column(
|
||||
"document",
|
||||
sa.Column("kg_stage", sa.String(), nullable=True, index=True),
|
||||
)
|
||||
op.add_column(
|
||||
"document",
|
||||
sa.Column("kg_processing_time", sa.DateTime(timezone=True), nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
"connector",
|
||||
sa.Column(
|
||||
"kg_processing_enabled",
|
||||
sa.Boolean(),
|
||||
nullable=True,
|
||||
server_default="false",
|
||||
),
|
||||
)
|
||||
|
||||
op.add_column(
|
||||
"connector",
|
||||
sa.Column(
|
||||
"kg_coverage_days",
|
||||
sa.Integer(),
|
||||
nullable=True,
|
||||
server_default=None,
|
||||
),
|
||||
)
|
||||
|
||||
# Create GIN index for clustering and normalization
|
||||
op.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_kg_entity_clustering_trigrams "
|
||||
f"ON kg_entity USING GIN (name {POSTGRES_DEFAULT_SCHEMA}.gin_trgm_ops)"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_kg_entity_normalization_trigrams "
|
||||
"ON kg_entity USING GIN (name_trigrams)"
|
||||
)
|
||||
|
||||
# Create kg_entity trigger to update kg_entity.name and its trigrams
|
||||
alphanum_pattern = r"[^a-z0-9]+"
|
||||
truncate_length = 1000
|
||||
function = "update_kg_entity_name"
|
||||
op.execute(
|
||||
text(
|
||||
f"""
|
||||
CREATE OR REPLACE FUNCTION {function}()
|
||||
RETURNS TRIGGER AS $$
|
||||
DECLARE
|
||||
name text;
|
||||
cleaned_name text;
|
||||
BEGIN
|
||||
-- Set name to semantic_id if document_id is not NULL
|
||||
IF NEW.document_id IS NOT NULL THEN
|
||||
SELECT lower(semantic_id) INTO name
|
||||
FROM document
|
||||
WHERE id = NEW.document_id;
|
||||
ELSE
|
||||
name = lower(NEW.name);
|
||||
END IF;
|
||||
|
||||
-- Clean name and truncate if too long
|
||||
cleaned_name = regexp_replace(
|
||||
name,
|
||||
'{alphanum_pattern}', '', 'g'
|
||||
);
|
||||
IF length(cleaned_name) > {truncate_length} THEN
|
||||
cleaned_name = left(cleaned_name, {truncate_length});
|
||||
END IF;
|
||||
|
||||
-- Set name and name trigrams
|
||||
NEW.name = name;
|
||||
NEW.name_trigrams = {POSTGRES_DEFAULT_SCHEMA}.show_trgm(cleaned_name);
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;
|
||||
"""
|
||||
)
|
||||
)
|
||||
trigger = f"{function}_trigger"
|
||||
op.execute(f"DROP TRIGGER IF EXISTS {trigger} ON kg_entity")
|
||||
op.execute(
|
||||
f"""
|
||||
CREATE TRIGGER {trigger}
|
||||
BEFORE INSERT OR UPDATE OF name
|
||||
ON kg_entity
|
||||
FOR EACH ROW
|
||||
EXECUTE FUNCTION {function}();
|
||||
"""
|
||||
)
|
||||
|
||||
# Create kg_entity trigger to update kg_entity.name and its trigrams
|
||||
function = "update_kg_entity_name_from_doc"
|
||||
op.execute(
|
||||
text(
|
||||
f"""
|
||||
CREATE OR REPLACE FUNCTION {function}()
|
||||
RETURNS TRIGGER AS $$
|
||||
DECLARE
|
||||
doc_name text;
|
||||
cleaned_name text;
|
||||
BEGIN
|
||||
doc_name = lower(NEW.semantic_id);
|
||||
|
||||
-- Clean name and truncate if too long
|
||||
cleaned_name = regexp_replace(
|
||||
doc_name,
|
||||
'{alphanum_pattern}', '', 'g'
|
||||
);
|
||||
IF length(cleaned_name) > {truncate_length} THEN
|
||||
cleaned_name = left(cleaned_name, {truncate_length});
|
||||
END IF;
|
||||
|
||||
-- Set name and name trigrams for all entities referencing this document
|
||||
UPDATE kg_entity
|
||||
SET
|
||||
name = doc_name,
|
||||
name_trigrams = {POSTGRES_DEFAULT_SCHEMA}.show_trgm(cleaned_name)
|
||||
WHERE document_id = NEW.id;
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;
|
||||
"""
|
||||
)
|
||||
)
|
||||
trigger = f"{function}_trigger"
|
||||
op.execute(f"DROP TRIGGER IF EXISTS {trigger} ON document")
|
||||
op.execute(
|
||||
f"""
|
||||
CREATE TRIGGER {trigger}
|
||||
AFTER UPDATE OF semantic_id
|
||||
ON document
|
||||
FOR EACH ROW
|
||||
EXECUTE FUNCTION {function}();
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
|
||||
# Drop all views that start with 'kg_'
|
||||
op.execute(
|
||||
"""
|
||||
DO $$
|
||||
DECLARE
|
||||
view_name text;
|
||||
BEGIN
|
||||
FOR view_name IN
|
||||
SELECT c.relname
|
||||
FROM pg_catalog.pg_class c
|
||||
JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
|
||||
WHERE c.relkind = 'v'
|
||||
AND n.nspname = current_schema()
|
||||
AND c.relname LIKE 'kg_relationships_with_access%'
|
||||
LOOP
|
||||
EXECUTE 'DROP VIEW IF EXISTS ' || quote_ident(view_name);
|
||||
END LOOP;
|
||||
END $$;
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
DO $$
|
||||
DECLARE
|
||||
view_name text;
|
||||
BEGIN
|
||||
FOR view_name IN
|
||||
SELECT c.relname
|
||||
FROM pg_catalog.pg_class c
|
||||
JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
|
||||
WHERE c.relkind = 'v'
|
||||
AND n.nspname = current_schema()
|
||||
AND c.relname LIKE 'allowed_docs%'
|
||||
LOOP
|
||||
EXECUTE 'DROP VIEW IF EXISTS ' || quote_ident(view_name);
|
||||
END LOOP;
|
||||
END $$;
|
||||
"""
|
||||
)
|
||||
|
||||
for table, function in (
|
||||
("kg_entity", "update_kg_entity_name"),
|
||||
("document", "update_kg_entity_name_from_doc"),
|
||||
):
|
||||
op.execute(f"DROP TRIGGER IF EXISTS {function}_trigger ON {table}")
|
||||
op.execute(f"DROP FUNCTION IF EXISTS {function}()")
|
||||
|
||||
# Drop index
|
||||
op.execute("DROP INDEX IF EXISTS idx_kg_entity_clustering_trigrams")
|
||||
op.execute("DROP INDEX IF EXISTS idx_kg_entity_normalization_trigrams")
|
||||
|
||||
# Drop tables in reverse order of creation to handle dependencies
|
||||
op.drop_table("kg_term")
|
||||
op.drop_table("kg_relationship")
|
||||
op.drop_table("kg_entity")
|
||||
op.drop_table("kg_relationship_type")
|
||||
op.drop_table("kg_relationship_extraction_staging")
|
||||
op.drop_table("kg_relationship_type_extraction_staging")
|
||||
op.drop_table("kg_entity_extraction_staging")
|
||||
op.drop_table("kg_entity_type")
|
||||
op.drop_column("connector", "kg_processing_enabled")
|
||||
op.drop_column("connector", "kg_coverage_days")
|
||||
op.drop_column("document", "kg_stage")
|
||||
op.drop_column("document", "kg_processing_time")
|
||||
op.drop_table("kg_config")
|
||||
|
||||
# Revoke usage on current schema for the readonly user
|
||||
op.execute(
|
||||
text(
|
||||
f"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
|
||||
EXECUTE format('REVOKE ALL ON SCHEMA %I FROM %I', current_schema(), '{DB_READONLY_USER}');
|
||||
END IF;
|
||||
END
|
||||
$$;
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
if not MULTI_TENANT:
|
||||
# Drop read-only db user here only in single tenant mode. For multi-tenant mode,
|
||||
# the user is dropped in the alembic_tenants migration.
|
||||
|
||||
op.execute(
|
||||
text(
|
||||
f"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
|
||||
-- First revoke all privileges from the database
|
||||
EXECUTE format('REVOKE ALL ON DATABASE %I FROM %I', current_database(), '{DB_READONLY_USER}');
|
||||
-- Then drop the user
|
||||
EXECUTE format('DROP USER %I', '{DB_READONLY_USER}');
|
||||
END IF;
|
||||
END
|
||||
$$;
|
||||
"""
|
||||
)
|
||||
)
|
||||
op.execute(text("DROP EXTENSION IF EXISTS pg_trgm"))
|
||||
@@ -0,0 +1,380 @@
|
||||
"""merge_default_assistants_into_unified
|
||||
|
||||
Revision ID: 505c488f6662
|
||||
Revises: d09fc20a3c66
|
||||
Create Date: 2025-09-09 19:00:56.816626
|
||||
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
from typing import NamedTuple
|
||||
from uuid import UUID
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "505c488f6662"
|
||||
down_revision = "d09fc20a3c66"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
# Constants for the unified assistant
|
||||
UNIFIED_ASSISTANT_NAME = "Assistant"
|
||||
UNIFIED_ASSISTANT_DESCRIPTION = (
|
||||
"Your AI assistant with search, web browsing, and image generation capabilities."
|
||||
)
|
||||
UNIFIED_ASSISTANT_NUM_CHUNKS = 25
|
||||
UNIFIED_ASSISTANT_DISPLAY_PRIORITY = 0
|
||||
UNIFIED_ASSISTANT_LLM_FILTER_EXTRACTION = True
|
||||
UNIFIED_ASSISTANT_LLM_RELEVANCE_FILTER = False
|
||||
UNIFIED_ASSISTANT_RECENCY_BIAS = "AUTO" # NOTE: needs to be capitalized
|
||||
UNIFIED_ASSISTANT_CHUNKS_ABOVE = 0
|
||||
UNIFIED_ASSISTANT_CHUNKS_BELOW = 0
|
||||
UNIFIED_ASSISTANT_DATETIME_AWARE = True
|
||||
|
||||
# NOTE: tool specific prompts are handled on the fly and automatically injected
|
||||
# into the prompt before passing to the LLM.
|
||||
DEFAULT_SYSTEM_PROMPT = """
|
||||
You are a highly capable, thoughtful, and precise assistant. Your goal is to deeply understand the \
|
||||
user's intent, ask clarifying questions when needed, think step-by-step through complex problems, \
|
||||
provide clear and accurate answers, and proactively anticipate helpful follow-up information. Always \
|
||||
prioritize being truthful, nuanced, insightful, and efficient.
|
||||
The current date is [[CURRENT_DATETIME]]
|
||||
|
||||
You use different text styles, bolding, emojis (sparingly), block quotes, and other formatting to make \
|
||||
your responses more readable and engaging.
|
||||
You use proper Markdown and LaTeX to format your responses for math, scientific, and chemical formulas, \
|
||||
symbols, etc.: '$$\\n[expression]\\n$$' for standalone cases and '\\( [expression] \\)' when inline.
|
||||
For code you prefer to use Markdown and specify the language.
|
||||
You can use Markdown horizontal rules (---) to separate sections of your responses.
|
||||
You can use Markdown tables to format your responses for data, lists, and other structured information.
|
||||
""".strip()
|
||||
|
||||
|
||||
INSERT_DICT: dict[str, Any] = {
|
||||
"name": UNIFIED_ASSISTANT_NAME,
|
||||
"description": UNIFIED_ASSISTANT_DESCRIPTION,
|
||||
"system_prompt": DEFAULT_SYSTEM_PROMPT,
|
||||
"num_chunks": UNIFIED_ASSISTANT_NUM_CHUNKS,
|
||||
"display_priority": UNIFIED_ASSISTANT_DISPLAY_PRIORITY,
|
||||
"llm_filter_extraction": UNIFIED_ASSISTANT_LLM_FILTER_EXTRACTION,
|
||||
"llm_relevance_filter": UNIFIED_ASSISTANT_LLM_RELEVANCE_FILTER,
|
||||
"recency_bias": UNIFIED_ASSISTANT_RECENCY_BIAS,
|
||||
"chunks_above": UNIFIED_ASSISTANT_CHUNKS_ABOVE,
|
||||
"chunks_below": UNIFIED_ASSISTANT_CHUNKS_BELOW,
|
||||
"datetime_aware": UNIFIED_ASSISTANT_DATETIME_AWARE,
|
||||
}
|
||||
|
||||
GENERAL_ASSISTANT_ID = -1
|
||||
ART_ASSISTANT_ID = -3
|
||||
|
||||
|
||||
class UserRow(NamedTuple):
|
||||
"""Typed representation of user row from database query."""
|
||||
|
||||
id: UUID
|
||||
chosen_assistants: list[int] | None
|
||||
visible_assistants: list[int] | None
|
||||
hidden_assistants: list[int] | None
|
||||
pinned_assistants: list[int] | None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# Start transaction
|
||||
conn.execute(sa.text("BEGIN"))
|
||||
|
||||
try:
|
||||
# Step 1: Create or update the unified assistant (ID 0)
|
||||
search_assistant = conn.execute(
|
||||
sa.text("SELECT * FROM persona WHERE id = 0")
|
||||
).fetchone()
|
||||
|
||||
if search_assistant:
|
||||
# Update existing Search assistant to be the unified assistant
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET name = :name,
|
||||
description = :description,
|
||||
system_prompt = :system_prompt,
|
||||
num_chunks = :num_chunks,
|
||||
is_default_persona = true,
|
||||
is_visible = true,
|
||||
deleted = false,
|
||||
display_priority = :display_priority,
|
||||
llm_filter_extraction = :llm_filter_extraction,
|
||||
llm_relevance_filter = :llm_relevance_filter,
|
||||
recency_bias = :recency_bias,
|
||||
chunks_above = :chunks_above,
|
||||
chunks_below = :chunks_below,
|
||||
datetime_aware = :datetime_aware,
|
||||
starter_messages = null
|
||||
WHERE id = 0
|
||||
"""
|
||||
),
|
||||
INSERT_DICT,
|
||||
)
|
||||
else:
|
||||
# Create new unified assistant with ID 0
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO persona (
|
||||
id, name, description, system_prompt, num_chunks,
|
||||
is_default_persona, is_visible, deleted, display_priority,
|
||||
llm_filter_extraction, llm_relevance_filter, recency_bias,
|
||||
chunks_above, chunks_below, datetime_aware, starter_messages,
|
||||
builtin_persona
|
||||
) VALUES (
|
||||
0, :name, :description, :system_prompt, :num_chunks,
|
||||
true, true, false, :display_priority, :llm_filter_extraction,
|
||||
:llm_relevance_filter, :recency_bias, :chunks_above, :chunks_below,
|
||||
:datetime_aware, null, true
|
||||
)
|
||||
"""
|
||||
),
|
||||
INSERT_DICT,
|
||||
)
|
||||
|
||||
# Step 2: Mark ALL builtin assistants as deleted (except the unified assistant ID 0)
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET deleted = true, is_visible = false, is_default_persona = false
|
||||
WHERE builtin_persona = true AND id != 0
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# Step 3: Add all built-in tools to the unified assistant
|
||||
# First, get the tool IDs for SearchTool, ImageGenerationTool, and WebSearchTool
|
||||
search_tool = conn.execute(
|
||||
sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'SearchTool'")
|
||||
).fetchone()
|
||||
|
||||
if not search_tool:
|
||||
raise ValueError(
|
||||
"SearchTool not found in database. Ensure tools migration has run first."
|
||||
)
|
||||
|
||||
image_gen_tool = conn.execute(
|
||||
sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'ImageGenerationTool'")
|
||||
).fetchone()
|
||||
|
||||
if not image_gen_tool:
|
||||
raise ValueError(
|
||||
"ImageGenerationTool not found in database. Ensure tools migration has run first."
|
||||
)
|
||||
|
||||
# WebSearchTool is optional - may not be configured
|
||||
web_search_tool = conn.execute(
|
||||
sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'WebSearchTool'")
|
||||
).fetchone()
|
||||
|
||||
# Clear existing tool associations for persona 0
|
||||
conn.execute(sa.text("DELETE FROM persona__tool WHERE persona_id = 0"))
|
||||
|
||||
# Add tools to the unified assistant
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO persona__tool (persona_id, tool_id)
|
||||
VALUES (0, :tool_id)
|
||||
ON CONFLICT DO NOTHING
|
||||
"""
|
||||
),
|
||||
{"tool_id": search_tool[0]},
|
||||
)
|
||||
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO persona__tool (persona_id, tool_id)
|
||||
VALUES (0, :tool_id)
|
||||
ON CONFLICT DO NOTHING
|
||||
"""
|
||||
),
|
||||
{"tool_id": image_gen_tool[0]},
|
||||
)
|
||||
|
||||
if web_search_tool:
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO persona__tool (persona_id, tool_id)
|
||||
VALUES (0, :tool_id)
|
||||
ON CONFLICT DO NOTHING
|
||||
"""
|
||||
),
|
||||
{"tool_id": web_search_tool[0]},
|
||||
)
|
||||
|
||||
# Step 4: Migrate existing chat sessions from all builtin assistants to unified assistant
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE chat_session
|
||||
SET persona_id = 0
|
||||
WHERE persona_id IN (
|
||||
SELECT id FROM persona WHERE builtin_persona = true AND id != 0
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# Step 5: Migrate user preferences - remove references to all builtin assistants
|
||||
# First, get all builtin assistant IDs (except 0)
|
||||
builtin_assistants_result = conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
SELECT id FROM persona
|
||||
WHERE builtin_persona = true AND id != 0
|
||||
"""
|
||||
)
|
||||
).fetchall()
|
||||
builtin_assistant_ids = [row[0] for row in builtin_assistants_result]
|
||||
|
||||
# Get all users with preferences
|
||||
users_result = conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
SELECT id, chosen_assistants, visible_assistants,
|
||||
hidden_assistants, pinned_assistants
|
||||
FROM "user"
|
||||
"""
|
||||
)
|
||||
).fetchall()
|
||||
|
||||
for user_row in users_result:
|
||||
user = UserRow(*user_row)
|
||||
user_id: UUID = user.id
|
||||
updates: dict[str, Any] = {}
|
||||
|
||||
# Remove all builtin assistants from chosen_assistants
|
||||
if user.chosen_assistants:
|
||||
new_chosen: list[int] = [
|
||||
assistant_id
|
||||
for assistant_id in user.chosen_assistants
|
||||
if assistant_id not in builtin_assistant_ids
|
||||
]
|
||||
if new_chosen != user.chosen_assistants:
|
||||
updates["chosen_assistants"] = json.dumps(new_chosen)
|
||||
|
||||
# Remove all builtin assistants from visible_assistants
|
||||
if user.visible_assistants:
|
||||
new_visible: list[int] = [
|
||||
assistant_id
|
||||
for assistant_id in user.visible_assistants
|
||||
if assistant_id not in builtin_assistant_ids
|
||||
]
|
||||
if new_visible != user.visible_assistants:
|
||||
updates["visible_assistants"] = json.dumps(new_visible)
|
||||
|
||||
# Add all builtin assistants to hidden_assistants
|
||||
if user.hidden_assistants:
|
||||
new_hidden: list[int] = list(user.hidden_assistants)
|
||||
for old_id in builtin_assistant_ids:
|
||||
if old_id not in new_hidden:
|
||||
new_hidden.append(old_id)
|
||||
if new_hidden != user.hidden_assistants:
|
||||
updates["hidden_assistants"] = json.dumps(new_hidden)
|
||||
else:
|
||||
updates["hidden_assistants"] = json.dumps(builtin_assistant_ids)
|
||||
|
||||
# Remove all builtin assistants from pinned_assistants
|
||||
if user.pinned_assistants:
|
||||
new_pinned: list[int] = [
|
||||
assistant_id
|
||||
for assistant_id in user.pinned_assistants
|
||||
if assistant_id not in builtin_assistant_ids
|
||||
]
|
||||
if new_pinned != user.pinned_assistants:
|
||||
updates["pinned_assistants"] = json.dumps(new_pinned)
|
||||
|
||||
# Apply updates if any
|
||||
if updates:
|
||||
set_clause = ", ".join([f"{k} = :{k}" for k in updates.keys()])
|
||||
updates["user_id"] = str(user_id) # Convert UUID to string for SQL
|
||||
conn.execute(
|
||||
sa.text(f'UPDATE "user" SET {set_clause} WHERE id = :user_id'),
|
||||
updates,
|
||||
)
|
||||
|
||||
# Commit transaction
|
||||
conn.execute(sa.text("COMMIT"))
|
||||
|
||||
except Exception as e:
|
||||
# Rollback on error
|
||||
conn.execute(sa.text("ROLLBACK"))
|
||||
raise e
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# Start transaction
|
||||
conn.execute(sa.text("BEGIN"))
|
||||
|
||||
try:
|
||||
# Only restore General (ID -1) and Art (ID -3) assistants
|
||||
# Step 1: Keep Search assistant (ID 0) as default but restore original state
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET is_default_persona = true,
|
||||
is_visible = true,
|
||||
deleted = false
|
||||
WHERE id = 0
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# Step 2: Restore General assistant (ID -1)
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET deleted = false,
|
||||
is_visible = true,
|
||||
is_default_persona = true
|
||||
WHERE id = :general_assistant_id
|
||||
"""
|
||||
),
|
||||
{"general_assistant_id": GENERAL_ASSISTANT_ID},
|
||||
)
|
||||
|
||||
# Step 3: Restore Art assistant (ID -3)
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET deleted = false,
|
||||
is_visible = true,
|
||||
is_default_persona = true
|
||||
WHERE id = :art_assistant_id
|
||||
"""
|
||||
),
|
||||
{"art_assistant_id": ART_ASSISTANT_ID},
|
||||
)
|
||||
|
||||
# Note: We don't restore the original tool associations, names, or descriptions
|
||||
# as those would require more complex logic to determine original state.
|
||||
# We also cannot restore original chat session persona_ids as we don't
|
||||
# have the original mappings.
|
||||
# Other builtin assistants remain deleted as per the requirement.
|
||||
|
||||
# Commit transaction
|
||||
conn.execute(sa.text("COMMIT"))
|
||||
|
||||
except Exception as e:
|
||||
# Rollback on error
|
||||
conn.execute(sa.text("ROLLBACK"))
|
||||
raise e
|
||||
@@ -0,0 +1,90 @@
|
||||
"""add stale column to external user group tables
|
||||
|
||||
Revision ID: 58c50ef19f08
|
||||
Revises: 7b9b952abdf6
|
||||
Create Date: 2025-06-25 14:08:14.162380
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "58c50ef19f08"
|
||||
down_revision = "7b9b952abdf6"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add the stale column with default value False to user__external_user_group_id
|
||||
op.add_column(
|
||||
"user__external_user_group_id",
|
||||
sa.Column("stale", sa.Boolean(), nullable=False, server_default="false"),
|
||||
)
|
||||
|
||||
# Create index for efficient querying of stale rows by cc_pair_id
|
||||
op.create_index(
|
||||
"ix_user__external_user_group_id_cc_pair_id_stale",
|
||||
"user__external_user_group_id",
|
||||
["cc_pair_id", "stale"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
# Create index for efficient querying of all stale rows
|
||||
op.create_index(
|
||||
"ix_user__external_user_group_id_stale",
|
||||
"user__external_user_group_id",
|
||||
["stale"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
# Add the stale column with default value False to public_external_user_group
|
||||
op.add_column(
|
||||
"public_external_user_group",
|
||||
sa.Column("stale", sa.Boolean(), nullable=False, server_default="false"),
|
||||
)
|
||||
|
||||
# Create index for efficient querying of stale rows by cc_pair_id
|
||||
op.create_index(
|
||||
"ix_public_external_user_group_cc_pair_id_stale",
|
||||
"public_external_user_group",
|
||||
["cc_pair_id", "stale"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
# Create index for efficient querying of all stale rows
|
||||
op.create_index(
|
||||
"ix_public_external_user_group_stale",
|
||||
"public_external_user_group",
|
||||
["stale"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop the indices for public_external_user_group first
|
||||
op.drop_index(
|
||||
"ix_public_external_user_group_stale", table_name="public_external_user_group"
|
||||
)
|
||||
op.drop_index(
|
||||
"ix_public_external_user_group_cc_pair_id_stale",
|
||||
table_name="public_external_user_group",
|
||||
)
|
||||
|
||||
# Drop the stale column from public_external_user_group
|
||||
op.drop_column("public_external_user_group", "stale")
|
||||
|
||||
# Drop the indices for user__external_user_group_id
|
||||
op.drop_index(
|
||||
"ix_user__external_user_group_id_stale",
|
||||
table_name="user__external_user_group_id",
|
||||
)
|
||||
op.drop_index(
|
||||
"ix_user__external_user_group_id_cc_pair_id_stale",
|
||||
table_name="user__external_user_group_id",
|
||||
)
|
||||
|
||||
# Drop the stale column from user__external_user_group_id
|
||||
op.drop_column("user__external_user_group_id", "stale")
|
||||
@@ -0,0 +1,115 @@
|
||||
"""add research agent database tables and chat message research fields
|
||||
|
||||
Revision ID: 5ae8240accb3
|
||||
Revises: b558f51620b4
|
||||
Create Date: 2025-08-06 14:29:24.691388
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "5ae8240accb3"
|
||||
down_revision = "b558f51620b4"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add research_type and research_plan columns to chat_message table
|
||||
op.add_column(
|
||||
"chat_message",
|
||||
sa.Column("research_type", sa.String(), nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
"chat_message",
|
||||
sa.Column("research_plan", postgresql.JSONB(), nullable=True),
|
||||
)
|
||||
|
||||
# Create research_agent_iteration table
|
||||
op.create_table(
|
||||
"research_agent_iteration",
|
||||
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
|
||||
sa.Column(
|
||||
"primary_question_id",
|
||||
sa.Integer(),
|
||||
sa.ForeignKey("chat_message.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("iteration_nr", sa.Integer(), nullable=False),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("purpose", sa.String(), nullable=True),
|
||||
sa.Column("reasoning", sa.String(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint(
|
||||
"primary_question_id",
|
||||
"iteration_nr",
|
||||
name="_research_agent_iteration_unique_constraint",
|
||||
),
|
||||
)
|
||||
|
||||
# Create research_agent_iteration_sub_step table
|
||||
op.create_table(
|
||||
"research_agent_iteration_sub_step",
|
||||
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
|
||||
sa.Column(
|
||||
"primary_question_id",
|
||||
sa.Integer(),
|
||||
sa.ForeignKey("chat_message.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"parent_question_id",
|
||||
sa.Integer(),
|
||||
sa.ForeignKey("research_agent_iteration_sub_step.id", ondelete="CASCADE"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("iteration_nr", sa.Integer(), nullable=False),
|
||||
sa.Column("iteration_sub_step_nr", sa.Integer(), nullable=False),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("sub_step_instructions", sa.String(), nullable=True),
|
||||
sa.Column(
|
||||
"sub_step_tool_id",
|
||||
sa.Integer(),
|
||||
sa.ForeignKey("tool.id"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("reasoning", sa.String(), nullable=True),
|
||||
sa.Column("sub_answer", sa.String(), nullable=True),
|
||||
sa.Column("cited_doc_results", postgresql.JSONB(), nullable=True),
|
||||
sa.Column("claims", postgresql.JSONB(), nullable=True),
|
||||
sa.Column("generated_images", postgresql.JSONB(), nullable=True),
|
||||
sa.Column("additional_data", postgresql.JSONB(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(
|
||||
["primary_question_id", "iteration_nr"],
|
||||
[
|
||||
"research_agent_iteration.primary_question_id",
|
||||
"research_agent_iteration.iteration_nr",
|
||||
],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop tables in reverse order
|
||||
op.drop_table("research_agent_iteration_sub_step")
|
||||
op.drop_table("research_agent_iteration")
|
||||
|
||||
# Remove columns from chat_message table
|
||||
op.drop_column("chat_message", "research_plan")
|
||||
op.drop_column("chat_message", "research_type")
|
||||
@@ -0,0 +1,132 @@
|
||||
"""add file names to file connector config
|
||||
|
||||
Revision ID: 62c3a055a141
|
||||
Revises: 3fc5d75723b3
|
||||
Create Date: 2025-07-30 17:01:24.417551
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import json
|
||||
import os
|
||||
import logging
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "62c3a055a141"
|
||||
down_revision = "3fc5d75723b3"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
SKIP_FILE_NAME_MIGRATION = (
|
||||
os.environ.get("SKIP_FILE_NAME_MIGRATION", "true").lower() == "true"
|
||||
)
|
||||
|
||||
logger = logging.getLogger("alembic.runtime.migration")
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
if SKIP_FILE_NAME_MIGRATION:
|
||||
logger.info(
|
||||
"Skipping file name migration. Hint: set SKIP_FILE_NAME_MIGRATION=false to run this migration"
|
||||
)
|
||||
return
|
||||
logger.info("Running file name migration")
|
||||
# Get connection
|
||||
conn = op.get_bind()
|
||||
|
||||
# Get all FILE connectors with their configs
|
||||
file_connectors = conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
SELECT id, connector_specific_config
|
||||
FROM connector
|
||||
WHERE source = 'FILE'
|
||||
"""
|
||||
)
|
||||
).fetchall()
|
||||
|
||||
for connector_id, config in file_connectors:
|
||||
# Parse config if it's a string
|
||||
if isinstance(config, str):
|
||||
config = json.loads(config)
|
||||
|
||||
# Get file_locations list
|
||||
file_locations = config.get("file_locations", [])
|
||||
|
||||
# Get display names for each file_id
|
||||
file_names = []
|
||||
for file_id in file_locations:
|
||||
result = conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
SELECT display_name
|
||||
FROM file_record
|
||||
WHERE file_id = :file_id
|
||||
"""
|
||||
),
|
||||
{"file_id": file_id},
|
||||
).fetchone()
|
||||
|
||||
if result:
|
||||
file_names.append(result[0])
|
||||
else:
|
||||
file_names.append(file_id) # Should not happen
|
||||
|
||||
# Add file_names to config
|
||||
new_config = dict(config)
|
||||
new_config["file_names"] = file_names
|
||||
|
||||
# Update the connector
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE connector
|
||||
SET connector_specific_config = :new_config
|
||||
WHERE id = :connector_id
|
||||
"""
|
||||
),
|
||||
{"connector_id": connector_id, "new_config": json.dumps(new_config)},
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Get connection
|
||||
conn = op.get_bind()
|
||||
|
||||
# Remove file_names from all FILE connectors
|
||||
file_connectors = conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
SELECT id, connector_specific_config
|
||||
FROM connector
|
||||
WHERE source = 'FILE'
|
||||
"""
|
||||
)
|
||||
).fetchall()
|
||||
|
||||
for connector_id, config in file_connectors:
|
||||
# Parse config if it's a string
|
||||
if isinstance(config, str):
|
||||
config = json.loads(config)
|
||||
|
||||
# Remove file_names if it exists
|
||||
if "file_names" in config:
|
||||
new_config = dict(config)
|
||||
del new_config["file_names"]
|
||||
|
||||
# Update the connector
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE connector
|
||||
SET connector_specific_config = :new_config
|
||||
WHERE id = :connector_id
|
||||
"""
|
||||
),
|
||||
{
|
||||
"connector_id": connector_id,
|
||||
"new_config": json.dumps(new_config),
|
||||
},
|
||||
)
|
||||
@@ -0,0 +1,41 @@
|
||||
"""remove kg subtype from db
|
||||
|
||||
Revision ID: 65bc6e0f8500
|
||||
Revises: cec7ec36c505
|
||||
Create Date: 2025-06-13 10:04:27.705976
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "65bc6e0f8500"
|
||||
down_revision = "cec7ec36c505"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.drop_column("kg_entity", "entity_class")
|
||||
op.drop_column("kg_entity", "entity_subtype")
|
||||
op.drop_column("kg_entity_extraction_staging", "entity_class")
|
||||
op.drop_column("kg_entity_extraction_staging", "entity_subtype")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.add_column(
|
||||
"kg_entity_extraction_staging",
|
||||
sa.Column("entity_subtype", sa.String(), nullable=True, index=True),
|
||||
)
|
||||
op.add_column(
|
||||
"kg_entity_extraction_staging",
|
||||
sa.Column("entity_class", sa.String(), nullable=True, index=True),
|
||||
)
|
||||
op.add_column(
|
||||
"kg_entity", sa.Column("entity_subtype", sa.String(), nullable=True, index=True)
|
||||
)
|
||||
op.add_column(
|
||||
"kg_entity", sa.Column("entity_class", sa.String(), nullable=True, index=True)
|
||||
)
|
||||
@@ -6,11 +6,8 @@ Create Date: 2024-04-15 01:36:02.952809
|
||||
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import cast
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "703313b75876"
|
||||
@@ -54,27 +51,10 @@ def upgrade() -> None:
|
||||
sa.PrimaryKeyConstraint("rate_limit_id", "user_group_id"),
|
||||
)
|
||||
|
||||
try:
|
||||
settings_json = cast(str, get_kv_store().load("token_budget_settings"))
|
||||
settings = json.loads(settings_json)
|
||||
|
||||
is_enabled = settings.get("enable_token_budget", False)
|
||||
token_budget = settings.get("token_budget", -1)
|
||||
period_hours = settings.get("period_hours", -1)
|
||||
|
||||
if is_enabled and token_budget > 0 and period_hours > 0:
|
||||
op.execute(
|
||||
f"INSERT INTO token_rate_limit \
|
||||
(enabled, token_budget, period_hours, scope) VALUES \
|
||||
({is_enabled}, {token_budget}, {period_hours}, 'GLOBAL')"
|
||||
)
|
||||
|
||||
# Delete the dynamic config
|
||||
get_kv_store().delete("token_budget_settings")
|
||||
|
||||
except Exception:
|
||||
# Ignore if the dynamic config is not found
|
||||
pass
|
||||
# NOTE: rate limit settings used to be stored in the "token_budget_settings" key in the
|
||||
# KeyValueStore. This will now be lost. The KV store works differently than it used to
|
||||
# so the migration is fairly complicated and likely not worth it to support (pretty much
|
||||
# nobody will have it set)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
|
||||
318
backend/alembic/versions/7b9b952abdf6_update_entities.py
Normal file
318
backend/alembic/versions/7b9b952abdf6_update_entities.py
Normal file
@@ -0,0 +1,318 @@
|
||||
"""update-entities
|
||||
|
||||
Revision ID: 7b9b952abdf6
|
||||
Revises: 36e9220ab794
|
||||
Create Date: 2025-06-23 20:24:08.139201
|
||||
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "7b9b952abdf6"
|
||||
down_revision = "36e9220ab794"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# new entity type metadata_attribute_conversion
|
||||
new_entity_type_conversion = {
|
||||
"LINEAR": {
|
||||
"team": {"name": "team", "keep": True, "implication_property": None},
|
||||
"state": {"name": "state", "keep": True, "implication_property": None},
|
||||
"priority": {
|
||||
"name": "priority",
|
||||
"keep": True,
|
||||
"implication_property": None,
|
||||
},
|
||||
"estimate": {
|
||||
"name": "estimate",
|
||||
"keep": True,
|
||||
"implication_property": None,
|
||||
},
|
||||
"created_at": {
|
||||
"name": "created_at",
|
||||
"keep": True,
|
||||
"implication_property": None,
|
||||
},
|
||||
"started_at": {
|
||||
"name": "started_at",
|
||||
"keep": True,
|
||||
"implication_property": None,
|
||||
},
|
||||
"completed_at": {
|
||||
"name": "completed_at",
|
||||
"keep": True,
|
||||
"implication_property": None,
|
||||
},
|
||||
"due_date": {
|
||||
"name": "due_date",
|
||||
"keep": True,
|
||||
"implication_property": None,
|
||||
},
|
||||
"creator": {
|
||||
"name": "creator",
|
||||
"keep": False,
|
||||
"implication_property": {
|
||||
"implied_entity_type": "from_email",
|
||||
"implied_relationship_name": "is_creator_of",
|
||||
},
|
||||
},
|
||||
"assignee": {
|
||||
"name": "assignee",
|
||||
"keep": False,
|
||||
"implication_property": {
|
||||
"implied_entity_type": "from_email",
|
||||
"implied_relationship_name": "is_assignee_of",
|
||||
},
|
||||
},
|
||||
},
|
||||
"JIRA": {
|
||||
"issuetype": {
|
||||
"name": "subtype",
|
||||
"keep": True,
|
||||
"implication_property": None,
|
||||
},
|
||||
"status": {"name": "status", "keep": True, "implication_property": None},
|
||||
"priority": {
|
||||
"name": "priority",
|
||||
"keep": True,
|
||||
"implication_property": None,
|
||||
},
|
||||
"project_name": {
|
||||
"name": "project",
|
||||
"keep": True,
|
||||
"implication_property": None,
|
||||
},
|
||||
"created": {
|
||||
"name": "created_at",
|
||||
"keep": True,
|
||||
"implication_property": None,
|
||||
},
|
||||
"updated": {
|
||||
"name": "updated_at",
|
||||
"keep": True,
|
||||
"implication_property": None,
|
||||
},
|
||||
"resolution_date": {
|
||||
"name": "completed_at",
|
||||
"keep": True,
|
||||
"implication_property": None,
|
||||
},
|
||||
"duedate": {"name": "due_date", "keep": True, "implication_property": None},
|
||||
"reporter_email": {
|
||||
"name": "creator",
|
||||
"keep": False,
|
||||
"implication_property": {
|
||||
"implied_entity_type": "from_email",
|
||||
"implied_relationship_name": "is_creator_of",
|
||||
},
|
||||
},
|
||||
"assignee_email": {
|
||||
"name": "assignee",
|
||||
"keep": False,
|
||||
"implication_property": {
|
||||
"implied_entity_type": "from_email",
|
||||
"implied_relationship_name": "is_assignee_of",
|
||||
},
|
||||
},
|
||||
"key": {"name": "key", "keep": True, "implication_property": None},
|
||||
"parent": {"name": "parent", "keep": True, "implication_property": None},
|
||||
},
|
||||
"GITHUB_PR": {
|
||||
"repo": {"name": "repository", "keep": True, "implication_property": None},
|
||||
"state": {"name": "state", "keep": True, "implication_property": None},
|
||||
"num_commits": {
|
||||
"name": "num_commits",
|
||||
"keep": True,
|
||||
"implication_property": None,
|
||||
},
|
||||
"num_files_changed": {
|
||||
"name": "num_files_changed",
|
||||
"keep": True,
|
||||
"implication_property": None,
|
||||
},
|
||||
"labels": {"name": "labels", "keep": True, "implication_property": None},
|
||||
"merged": {"name": "merged", "keep": True, "implication_property": None},
|
||||
"merged_at": {
|
||||
"name": "merged_at",
|
||||
"keep": True,
|
||||
"implication_property": None,
|
||||
},
|
||||
"closed_at": {
|
||||
"name": "closed_at",
|
||||
"keep": True,
|
||||
"implication_property": None,
|
||||
},
|
||||
"created_at": {
|
||||
"name": "created_at",
|
||||
"keep": True,
|
||||
"implication_property": None,
|
||||
},
|
||||
"updated_at": {
|
||||
"name": "updated_at",
|
||||
"keep": True,
|
||||
"implication_property": None,
|
||||
},
|
||||
"user": {
|
||||
"name": "creator",
|
||||
"keep": False,
|
||||
"implication_property": {
|
||||
"implied_entity_type": "from_email",
|
||||
"implied_relationship_name": "is_creator_of",
|
||||
},
|
||||
},
|
||||
"assignees": {
|
||||
"name": "assignees",
|
||||
"keep": False,
|
||||
"implication_property": {
|
||||
"implied_entity_type": "from_email",
|
||||
"implied_relationship_name": "is_assignee_of",
|
||||
},
|
||||
},
|
||||
},
|
||||
"GITHUB_ISSUE": {
|
||||
"repo": {"name": "repository", "keep": True, "implication_property": None},
|
||||
"state": {"name": "state", "keep": True, "implication_property": None},
|
||||
"labels": {"name": "labels", "keep": True, "implication_property": None},
|
||||
"closed_at": {
|
||||
"name": "closed_at",
|
||||
"keep": True,
|
||||
"implication_property": None,
|
||||
},
|
||||
"created_at": {
|
||||
"name": "created_at",
|
||||
"keep": True,
|
||||
"implication_property": None,
|
||||
},
|
||||
"updated_at": {
|
||||
"name": "updated_at",
|
||||
"keep": True,
|
||||
"implication_property": None,
|
||||
},
|
||||
"user": {
|
||||
"name": "creator",
|
||||
"keep": False,
|
||||
"implication_property": {
|
||||
"implied_entity_type": "from_email",
|
||||
"implied_relationship_name": "is_creator_of",
|
||||
},
|
||||
},
|
||||
"assignees": {
|
||||
"name": "assignees",
|
||||
"keep": False,
|
||||
"implication_property": {
|
||||
"implied_entity_type": "from_email",
|
||||
"implied_relationship_name": "is_assignee_of",
|
||||
},
|
||||
},
|
||||
},
|
||||
"FIREFLIES": {},
|
||||
"ACCOUNT": {},
|
||||
"OPPORTUNITY": {
|
||||
"name": {"name": "name", "keep": True, "implication_property": None},
|
||||
"stage_name": {"name": "stage", "keep": True, "implication_property": None},
|
||||
"type": {"name": "type", "keep": True, "implication_property": None},
|
||||
"amount": {"name": "amount", "keep": True, "implication_property": None},
|
||||
"fiscal_year": {
|
||||
"name": "fiscal_year",
|
||||
"keep": True,
|
||||
"implication_property": None,
|
||||
},
|
||||
"fiscal_quarter": {
|
||||
"name": "fiscal_quarter",
|
||||
"keep": True,
|
||||
"implication_property": None,
|
||||
},
|
||||
"is_closed": {
|
||||
"name": "is_closed",
|
||||
"keep": True,
|
||||
"implication_property": None,
|
||||
},
|
||||
"close_date": {
|
||||
"name": "close_date",
|
||||
"keep": True,
|
||||
"implication_property": None,
|
||||
},
|
||||
"probability": {
|
||||
"name": "close_probability",
|
||||
"keep": True,
|
||||
"implication_property": None,
|
||||
},
|
||||
"created_date": {
|
||||
"name": "created_at",
|
||||
"keep": True,
|
||||
"implication_property": None,
|
||||
},
|
||||
"last_modified_date": {
|
||||
"name": "updated_at",
|
||||
"keep": True,
|
||||
"implication_property": None,
|
||||
},
|
||||
"account": {
|
||||
"name": "account",
|
||||
"keep": False,
|
||||
"implication_property": {
|
||||
"implied_entity_type": "ACCOUNT",
|
||||
"implied_relationship_name": "is_account_of",
|
||||
},
|
||||
},
|
||||
},
|
||||
"VENDOR": {},
|
||||
"EMPLOYEE": {},
|
||||
}
|
||||
|
||||
current_entity_types = conn.execute(
|
||||
sa.text("SELECT id_name, attributes from kg_entity_type")
|
||||
).all()
|
||||
for entity_type, attributes in current_entity_types:
|
||||
# delete removed entity types
|
||||
if entity_type not in new_entity_type_conversion:
|
||||
op.execute(
|
||||
sa.text(f"DELETE FROM kg_entity_type WHERE id_name = '{entity_type}'")
|
||||
)
|
||||
continue
|
||||
|
||||
# update entity type attributes
|
||||
if "metadata_attributes" in attributes:
|
||||
del attributes["metadata_attributes"]
|
||||
attributes["metadata_attribute_conversion"] = new_entity_type_conversion[
|
||||
entity_type
|
||||
]
|
||||
attributes_str = json.dumps(attributes).replace("'", "''")
|
||||
op.execute(
|
||||
sa.text(
|
||||
f"UPDATE kg_entity_type SET attributes = '{attributes_str}'"
|
||||
f"WHERE id_name = '{entity_type}'"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
current_entity_types = conn.execute(
|
||||
sa.text("SELECT id_name, attributes from kg_entity_type")
|
||||
).all()
|
||||
for entity_type, attributes in current_entity_types:
|
||||
conversion = {}
|
||||
if "metadata_attribute_conversion" in attributes:
|
||||
conversion = attributes.pop("metadata_attribute_conversion")
|
||||
attributes["metadata_attributes"] = {
|
||||
attr: prop["name"] for attr, prop in conversion.items() if prop["keep"]
|
||||
}
|
||||
|
||||
attributes_str = json.dumps(attributes).replace("'", "''")
|
||||
op.execute(
|
||||
sa.text(
|
||||
f"UPDATE kg_entity_type SET attributes = '{attributes_str}'"
|
||||
f"WHERE id_name = '{entity_type}'"
|
||||
),
|
||||
)
|
||||
@@ -0,0 +1,249 @@
|
||||
"""add_mcp_server_and_connection_config_models
|
||||
|
||||
Revision ID: 7ed603b64d5a
|
||||
Revises: b329d00a9ea6
|
||||
Create Date: 2025-07-28 17:35:59.900680
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from onyx.db.enums import MCPAuthenticationType
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "7ed603b64d5a"
|
||||
down_revision = "b329d00a9ea6"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Create tables and columns for MCP Server support"""
|
||||
|
||||
# 1. MCP Server main table (no FK constraints yet to avoid circular refs)
|
||||
op.create_table(
|
||||
"mcp_server",
|
||||
sa.Column("id", sa.Integer(), primary_key=True),
|
||||
sa.Column("owner", sa.String(), nullable=False),
|
||||
sa.Column("name", sa.String(), nullable=False),
|
||||
sa.Column("description", sa.String(), nullable=True),
|
||||
sa.Column("server_url", sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
"auth_type",
|
||||
sa.Enum(
|
||||
MCPAuthenticationType,
|
||||
name="mcp_authentication_type",
|
||||
native_enum=False,
|
||||
),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("admin_connection_config_id", sa.Integer(), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"), # type: ignore
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"), # type: ignore
|
||||
nullable=False,
|
||||
),
|
||||
)
|
||||
|
||||
# 2. MCP Connection Config table (can reference mcp_server now that it exists)
|
||||
op.create_table(
|
||||
"mcp_connection_config",
|
||||
sa.Column("id", sa.Integer(), primary_key=True),
|
||||
sa.Column("mcp_server_id", sa.Integer(), nullable=True),
|
||||
sa.Column("user_email", sa.String(), nullable=False, default=""),
|
||||
sa.Column("config", sa.LargeBinary(), nullable=False),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"), # type: ignore
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"), # type: ignore
|
||||
nullable=False,
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["mcp_server_id"], ["mcp_server.id"], ondelete="CASCADE"
|
||||
),
|
||||
)
|
||||
|
||||
# Helpful indexes
|
||||
op.create_index(
|
||||
"ix_mcp_connection_config_server_user",
|
||||
"mcp_connection_config",
|
||||
["mcp_server_id", "user_email"],
|
||||
)
|
||||
op.create_index(
|
||||
"ix_mcp_connection_config_user_email",
|
||||
"mcp_connection_config",
|
||||
["user_email"],
|
||||
)
|
||||
|
||||
# 3. Add the back-references from mcp_server to connection configs
|
||||
op.create_foreign_key(
|
||||
"mcp_server_admin_config_fk",
|
||||
"mcp_server",
|
||||
"mcp_connection_config",
|
||||
["admin_connection_config_id"],
|
||||
["id"],
|
||||
ondelete="SET NULL",
|
||||
)
|
||||
|
||||
# 4. Association / access-control tables
|
||||
op.create_table(
|
||||
"mcp_server__user",
|
||||
sa.Column("mcp_server_id", sa.Integer(), primary_key=True),
|
||||
sa.Column("user_id", sa.UUID(), primary_key=True),
|
||||
sa.ForeignKeyConstraint(
|
||||
["mcp_server_id"], ["mcp_server.id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
|
||||
)
|
||||
|
||||
op.create_table(
|
||||
"mcp_server__user_group",
|
||||
sa.Column("mcp_server_id", sa.Integer(), primary_key=True),
|
||||
sa.Column("user_group_id", sa.Integer(), primary_key=True),
|
||||
sa.ForeignKeyConstraint(
|
||||
["mcp_server_id"], ["mcp_server.id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.ForeignKeyConstraint(["user_group_id"], ["user_group.id"]),
|
||||
)
|
||||
|
||||
# 5. Update existing `tool` table – allow tools to belong to an MCP server
|
||||
op.add_column(
|
||||
"tool",
|
||||
sa.Column("mcp_server_id", sa.Integer(), nullable=True),
|
||||
)
|
||||
# Add column for MCP tool input schema
|
||||
op.add_column(
|
||||
"tool",
|
||||
sa.Column("mcp_input_schema", postgresql.JSONB(), nullable=True),
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"tool_mcp_server_fk",
|
||||
"tool",
|
||||
"mcp_server",
|
||||
["mcp_server_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
# 6. Update persona__tool foreign keys to cascade delete
|
||||
# This ensures that when a tool is deleted (including via MCP server deletion),
|
||||
# the corresponding persona__tool rows are also deleted
|
||||
op.drop_constraint(
|
||||
"persona__tool_tool_id_fkey", "persona__tool", type_="foreignkey"
|
||||
)
|
||||
op.drop_constraint(
|
||||
"persona__tool_persona_id_fkey", "persona__tool", type_="foreignkey"
|
||||
)
|
||||
|
||||
op.create_foreign_key(
|
||||
"persona__tool_persona_id_fkey",
|
||||
"persona__tool",
|
||||
"persona",
|
||||
["persona_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"persona__tool_tool_id_fkey",
|
||||
"persona__tool",
|
||||
"tool",
|
||||
["tool_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
# 7. Update research_agent_iteration_sub_step foreign key to SET NULL on delete
|
||||
# This ensures that when a tool is deleted, the sub_step_tool_id is set to NULL
|
||||
# instead of causing a foreign key constraint violation
|
||||
op.drop_constraint(
|
||||
"research_agent_iteration_sub_step_sub_step_tool_id_fkey",
|
||||
"research_agent_iteration_sub_step",
|
||||
type_="foreignkey",
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"research_agent_iteration_sub_step_sub_step_tool_id_fkey",
|
||||
"research_agent_iteration_sub_step",
|
||||
"tool",
|
||||
["sub_step_tool_id"],
|
||||
["id"],
|
||||
ondelete="SET NULL",
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Drop all MCP-related tables / columns"""
|
||||
|
||||
# # # 1. Drop FK & columns from tool
|
||||
# op.drop_constraint("tool_mcp_server_fk", "tool", type_="foreignkey")
|
||||
op.execute("DELETE FROM tool WHERE mcp_server_id IS NOT NULL")
|
||||
|
||||
op.drop_constraint(
|
||||
"research_agent_iteration_sub_step_sub_step_tool_id_fkey",
|
||||
"research_agent_iteration_sub_step",
|
||||
type_="foreignkey",
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"research_agent_iteration_sub_step_sub_step_tool_id_fkey",
|
||||
"research_agent_iteration_sub_step",
|
||||
"tool",
|
||||
["sub_step_tool_id"],
|
||||
["id"],
|
||||
)
|
||||
|
||||
# Restore original persona__tool foreign keys (without CASCADE)
|
||||
op.drop_constraint(
|
||||
"persona__tool_persona_id_fkey", "persona__tool", type_="foreignkey"
|
||||
)
|
||||
op.drop_constraint(
|
||||
"persona__tool_tool_id_fkey", "persona__tool", type_="foreignkey"
|
||||
)
|
||||
|
||||
op.create_foreign_key(
|
||||
"persona__tool_persona_id_fkey",
|
||||
"persona__tool",
|
||||
"persona",
|
||||
["persona_id"],
|
||||
["id"],
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"persona__tool_tool_id_fkey",
|
||||
"persona__tool",
|
||||
"tool",
|
||||
["tool_id"],
|
||||
["id"],
|
||||
)
|
||||
op.drop_column("tool", "mcp_input_schema")
|
||||
op.drop_column("tool", "mcp_server_id")
|
||||
|
||||
# 2. Drop association tables
|
||||
op.drop_table("mcp_server__user_group")
|
||||
op.drop_table("mcp_server__user")
|
||||
|
||||
# 3. Drop FK from mcp_server to connection configs
|
||||
op.drop_constraint("mcp_server_admin_config_fk", "mcp_server", type_="foreignkey")
|
||||
|
||||
# 4. Drop connection config indexes & table
|
||||
op.drop_index(
|
||||
"ix_mcp_connection_config_user_email", table_name="mcp_connection_config"
|
||||
)
|
||||
op.drop_index(
|
||||
"ix_mcp_connection_config_server_user", table_name="mcp_connection_config"
|
||||
)
|
||||
op.drop_table("mcp_connection_config")
|
||||
|
||||
# 5. Finally drop mcp_server table
|
||||
op.drop_table("mcp_server")
|
||||
@@ -0,0 +1,38 @@
|
||||
"""drop include citations
|
||||
|
||||
Revision ID: 8818cf73fa1a
|
||||
Revises: 7ed603b64d5a
|
||||
Create Date: 2025-09-02 19:43:50.060680
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "8818cf73fa1a"
|
||||
down_revision = "7ed603b64d5a"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.drop_column("prompt", "include_citations")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.add_column(
|
||||
"prompt",
|
||||
sa.Column(
|
||||
"include_citations",
|
||||
sa.BOOLEAN(),
|
||||
autoincrement=False,
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
# Set include_citations based on prompt name: FALSE for ImageGeneration, TRUE for others
|
||||
op.execute(
|
||||
sa.text(
|
||||
"UPDATE prompt SET include_citations = CASE WHEN name = 'ImageGeneration' THEN FALSE ELSE TRUE END"
|
||||
)
|
||||
)
|
||||
341
backend/alembic/versions/90e3b9af7da4_tag_fix.py
Normal file
341
backend/alembic/versions/90e3b9af7da4_tag_fix.py
Normal file
@@ -0,0 +1,341 @@
|
||||
"""tag-fix
|
||||
|
||||
Revision ID: 90e3b9af7da4
|
||||
Revises: 62c3a055a141
|
||||
Create Date: 2025-08-01 20:58:14.607624
|
||||
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
|
||||
from typing import cast
|
||||
from typing import Generator
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
|
||||
from onyx.db.search_settings import SearchSettings
|
||||
from onyx.configs.app_configs import AUTH_TYPE
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
|
||||
|
||||
logger = logging.getLogger("alembic.runtime.migration")
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "90e3b9af7da4"
|
||||
down_revision = "62c3a055a141"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
SKIP_TAG_FIX = os.environ.get("SKIP_TAG_FIX", "true").lower() == "true"
|
||||
|
||||
# override for cloud
|
||||
if AUTH_TYPE == AuthType.CLOUD:
|
||||
SKIP_TAG_FIX = True
|
||||
|
||||
|
||||
def set_is_list_for_known_tags() -> None:
|
||||
"""
|
||||
Sets is_list to true for all tags that are known to be lists.
|
||||
"""
|
||||
LIST_METADATA: list[tuple[str, str]] = [
|
||||
("CLICKUP", "tags"),
|
||||
("CONFLUENCE", "labels"),
|
||||
("DISCOURSE", "tags"),
|
||||
("FRESHDESK", "emails"),
|
||||
("GITHUB", "assignees"),
|
||||
("GITHUB", "labels"),
|
||||
("GURU", "tags"),
|
||||
("GURU", "folders"),
|
||||
("HUBSPOT", "associated_contact_ids"),
|
||||
("HUBSPOT", "associated_company_ids"),
|
||||
("HUBSPOT", "associated_deal_ids"),
|
||||
("HUBSPOT", "associated_ticket_ids"),
|
||||
("JIRA", "labels"),
|
||||
("MEDIAWIKI", "categories"),
|
||||
("ZENDESK", "labels"),
|
||||
("ZENDESK", "content_tags"),
|
||||
]
|
||||
|
||||
bind = op.get_bind()
|
||||
for source, key in LIST_METADATA:
|
||||
bind.execute(
|
||||
sa.text(
|
||||
f"""
|
||||
UPDATE tag
|
||||
SET is_list = true
|
||||
WHERE tag_key = '{key}'
|
||||
AND source = '{source}'
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def set_is_list_for_list_tags() -> None:
|
||||
"""
|
||||
Sets is_list to true for all tags which have multiple values for a given
|
||||
document, key, and source triplet. This only works if we remove old tags
|
||||
from the database.
|
||||
"""
|
||||
bind = op.get_bind()
|
||||
bind.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE tag
|
||||
SET is_list = true
|
||||
FROM (
|
||||
SELECT DISTINCT tag.tag_key, tag.source
|
||||
FROM tag
|
||||
JOIN document__tag ON tag.id = document__tag.tag_id
|
||||
GROUP BY tag.tag_key, tag.source, document__tag.document_id
|
||||
HAVING count(*) > 1
|
||||
) AS list_tags
|
||||
WHERE tag.tag_key = list_tags.tag_key
|
||||
AND tag.source = list_tags.source
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def log_list_tags() -> None:
|
||||
bind = op.get_bind()
|
||||
result = bind.execute(
|
||||
sa.text(
|
||||
"""
|
||||
SELECT DISTINCT source, tag_key
|
||||
FROM tag
|
||||
WHERE is_list
|
||||
ORDER BY source, tag_key
|
||||
"""
|
||||
)
|
||||
).fetchall()
|
||||
logger.info(
|
||||
"List tags:\n" + "\n".join(f" {source}: {key}" for source, key in result)
|
||||
)
|
||||
|
||||
|
||||
def remove_old_tags() -> None:
|
||||
"""
|
||||
Removes old tags from the database.
|
||||
Previously, there was a bug where if a document got indexed with a tag and then
|
||||
the document got reindexed, the old tag would not be removed.
|
||||
This function removes those old tags by comparing it against the tags in vespa.
|
||||
"""
|
||||
current_search_settings, future_search_settings = active_search_settings()
|
||||
document_index = get_default_document_index(
|
||||
current_search_settings, future_search_settings
|
||||
)
|
||||
|
||||
# Get the index name
|
||||
if hasattr(document_index, "index_name"):
|
||||
index_name = document_index.index_name
|
||||
else:
|
||||
# Default index name if we can't get it from the document_index
|
||||
index_name = "danswer_index"
|
||||
|
||||
for batch in _get_batch_documents_with_multiple_tags():
|
||||
n_deleted = 0
|
||||
|
||||
for document_id in batch:
|
||||
true_metadata = _get_vespa_metadata(document_id, index_name)
|
||||
tags = _get_document_tags(document_id)
|
||||
|
||||
# identify document__tags to delete
|
||||
to_delete: list[str] = []
|
||||
for tag_id, tag_key, tag_value in tags:
|
||||
true_val = true_metadata.get(tag_key, "")
|
||||
if (isinstance(true_val, list) and tag_value not in true_val) or (
|
||||
isinstance(true_val, str) and tag_value != true_val
|
||||
):
|
||||
to_delete.append(str(tag_id))
|
||||
|
||||
if not to_delete:
|
||||
continue
|
||||
|
||||
# delete old document__tags
|
||||
bind = op.get_bind()
|
||||
result = bind.execute(
|
||||
sa.text(
|
||||
f"""
|
||||
DELETE FROM document__tag
|
||||
WHERE document_id = '{document_id}'
|
||||
AND tag_id IN ({','.join(to_delete)})
|
||||
"""
|
||||
)
|
||||
)
|
||||
n_deleted += result.rowcount
|
||||
logger.info(f"Processed {len(batch)} documents and deleted {n_deleted} tags")
|
||||
|
||||
|
||||
def active_search_settings() -> tuple[SearchSettings, SearchSettings | None]:
|
||||
result = op.get_bind().execute(
|
||||
sa.text(
|
||||
"""
|
||||
SELECT * FROM search_settings WHERE status = 'PRESENT' ORDER BY id DESC LIMIT 1
|
||||
"""
|
||||
)
|
||||
)
|
||||
search_settings_fetch = result.fetchall()
|
||||
search_settings = (
|
||||
SearchSettings(**search_settings_fetch[0]._asdict())
|
||||
if search_settings_fetch
|
||||
else None
|
||||
)
|
||||
|
||||
result2 = op.get_bind().execute(
|
||||
sa.text(
|
||||
"""
|
||||
SELECT * FROM search_settings WHERE status = 'FUTURE' ORDER BY id DESC LIMIT 1
|
||||
"""
|
||||
)
|
||||
)
|
||||
search_settings_future_fetch = result2.fetchall()
|
||||
search_settings_future = (
|
||||
SearchSettings(**search_settings_future_fetch[0]._asdict())
|
||||
if search_settings_future_fetch
|
||||
else None
|
||||
)
|
||||
|
||||
if not isinstance(search_settings, SearchSettings):
|
||||
raise RuntimeError(
|
||||
"current search settings is of type " + str(type(search_settings))
|
||||
)
|
||||
if (
|
||||
not isinstance(search_settings_future, SearchSettings)
|
||||
and search_settings_future is not None
|
||||
):
|
||||
raise RuntimeError(
|
||||
"future search settings is of type " + str(type(search_settings_future))
|
||||
)
|
||||
|
||||
return search_settings, search_settings_future
|
||||
|
||||
|
||||
def _get_batch_documents_with_multiple_tags(
|
||||
batch_size: int = 128,
|
||||
) -> Generator[list[str], None, None]:
|
||||
"""
|
||||
Returns a list of document ids which contain a one to many tag.
|
||||
The document may either contain a list metadata value, or may contain leftover
|
||||
old tags from reindexing.
|
||||
"""
|
||||
offset_clause = ""
|
||||
bind = op.get_bind()
|
||||
|
||||
while True:
|
||||
batch = bind.execute(
|
||||
sa.text(
|
||||
f"""
|
||||
SELECT DISTINCT document__tag.document_id
|
||||
FROM tag
|
||||
JOIN document__tag ON tag.id = document__tag.tag_id
|
||||
GROUP BY tag.tag_key, tag.source, document__tag.document_id
|
||||
HAVING count(*) > 1 {offset_clause}
|
||||
ORDER BY document__tag.document_id
|
||||
LIMIT {batch_size}
|
||||
"""
|
||||
)
|
||||
).fetchall()
|
||||
if not batch:
|
||||
break
|
||||
doc_ids = [document_id for document_id, in batch]
|
||||
yield doc_ids
|
||||
offset_clause = f"AND document__tag.document_id > '{doc_ids[-1]}'"
|
||||
|
||||
|
||||
def _get_vespa_metadata(
|
||||
document_id: str, index_name: str
|
||||
) -> dict[str, str | list[str]]:
|
||||
url = DOCUMENT_ID_ENDPOINT.format(index_name=index_name)
|
||||
|
||||
# Document-Selector language
|
||||
selection = (
|
||||
f"{index_name}.document_id=='{document_id}' and {index_name}.chunk_id==0"
|
||||
)
|
||||
|
||||
params: dict[str, str | int] = {
|
||||
"selection": selection,
|
||||
"wantedDocumentCount": 1,
|
||||
"fieldSet": f"{index_name}:metadata",
|
||||
}
|
||||
|
||||
with get_vespa_http_client() as client:
|
||||
resp = client.get(url, params=params)
|
||||
resp.raise_for_status()
|
||||
|
||||
docs = resp.json().get("documents", [])
|
||||
if not docs:
|
||||
raise RuntimeError(f"No chunk-0 found for document {document_id}")
|
||||
|
||||
# for some reason, metadata is a string
|
||||
metadata = docs[0]["fields"]["metadata"]
|
||||
return json.loads(metadata)
|
||||
|
||||
|
||||
def _get_document_tags(document_id: str) -> list[tuple[int, str, str]]:
|
||||
bind = op.get_bind()
|
||||
result = bind.execute(
|
||||
sa.text(
|
||||
f"""
|
||||
SELECT tag.id, tag.tag_key, tag.tag_value
|
||||
FROM tag
|
||||
JOIN document__tag ON tag.id = document__tag.tag_id
|
||||
WHERE document__tag.document_id = '{document_id}'
|
||||
"""
|
||||
)
|
||||
).fetchall()
|
||||
return cast(list[tuple[int, str, str]], result)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"tag",
|
||||
sa.Column("is_list", sa.Boolean(), nullable=False, server_default="false"),
|
||||
)
|
||||
op.drop_constraint(
|
||||
constraint_name="_tag_key_value_source_uc",
|
||||
table_name="tag",
|
||||
type_="unique",
|
||||
)
|
||||
op.create_unique_constraint(
|
||||
constraint_name="_tag_key_value_source_list_uc",
|
||||
table_name="tag",
|
||||
columns=["tag_key", "tag_value", "source", "is_list"],
|
||||
)
|
||||
set_is_list_for_known_tags()
|
||||
|
||||
if SKIP_TAG_FIX:
|
||||
logger.warning(
|
||||
"Skipping removal of old tags. "
|
||||
"This can cause issues when using the knowledge graph, or "
|
||||
"when filtering for documents by tags."
|
||||
)
|
||||
log_list_tags()
|
||||
return
|
||||
|
||||
remove_old_tags()
|
||||
set_is_list_for_list_tags()
|
||||
|
||||
# debug
|
||||
log_list_tags()
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# the migration adds and populates the is_list column, and removes old bugged tags
|
||||
# there isn't a point in adding back the bugged tags, so we just drop the column
|
||||
op.drop_constraint(
|
||||
constraint_name="_tag_key_value_source_list_uc",
|
||||
table_name="tag",
|
||||
type_="unique",
|
||||
)
|
||||
op.create_unique_constraint(
|
||||
constraint_name="_tag_key_value_source_uc",
|
||||
table_name="tag",
|
||||
columns=["tag_key", "tag_value", "source"],
|
||||
)
|
||||
op.drop_column("tag", "is_list")
|
||||
@@ -0,0 +1,225 @@
|
||||
"""merge prompt into persona
|
||||
|
||||
Revision ID: abbfec3a5ac5
|
||||
Revises: 8818cf73fa1a
|
||||
Create Date: 2024-12-19 12:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "abbfec3a5ac5"
|
||||
down_revision = "8818cf73fa1a"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
MAX_PROMPT_LENGTH = 5_000_000
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""NOTE: Prompts without any Personas will just be lost."""
|
||||
# Step 1: Add new columns to persona table (only if they don't exist)
|
||||
|
||||
# Check if columns exist before adding them
|
||||
connection = op.get_bind()
|
||||
inspector = sa.inspect(connection)
|
||||
existing_columns = [col["name"] for col in inspector.get_columns("persona")]
|
||||
|
||||
if "system_prompt" not in existing_columns:
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column(
|
||||
"system_prompt", sa.String(length=MAX_PROMPT_LENGTH), nullable=True
|
||||
),
|
||||
)
|
||||
|
||||
if "task_prompt" not in existing_columns:
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column(
|
||||
"task_prompt", sa.String(length=MAX_PROMPT_LENGTH), nullable=True
|
||||
),
|
||||
)
|
||||
|
||||
if "datetime_aware" not in existing_columns:
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column(
|
||||
"datetime_aware", sa.Boolean(), nullable=False, server_default="true"
|
||||
),
|
||||
)
|
||||
|
||||
# Step 2: Migrate data from prompt table to persona table (only if tables exist)
|
||||
existing_tables = inspector.get_table_names()
|
||||
|
||||
if "prompt" in existing_tables and "persona__prompt" in existing_tables:
|
||||
# For personas that have associated prompts, copy the prompt data
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET
|
||||
system_prompt = p.system_prompt,
|
||||
task_prompt = p.task_prompt,
|
||||
datetime_aware = p.datetime_aware
|
||||
FROM (
|
||||
-- Get the first prompt for each persona (in case there are multiple)
|
||||
SELECT DISTINCT ON (pp.persona_id)
|
||||
pp.persona_id,
|
||||
pr.system_prompt,
|
||||
pr.task_prompt,
|
||||
pr.datetime_aware
|
||||
FROM persona__prompt pp
|
||||
JOIN prompt pr ON pp.prompt_id = pr.id
|
||||
) p
|
||||
WHERE persona.id = p.persona_id
|
||||
"""
|
||||
)
|
||||
|
||||
# Step 3: Update chat_message references
|
||||
# Since chat messages referenced prompt_id, we need to update them to use persona_id
|
||||
# This is complex as we need to map from prompt_id to persona_id
|
||||
|
||||
# Check if chat_message has prompt_id column
|
||||
chat_message_columns = [
|
||||
col["name"] for col in inspector.get_columns("chat_message")
|
||||
]
|
||||
if "prompt_id" in chat_message_columns:
|
||||
op.execute(
|
||||
"""
|
||||
ALTER TABLE chat_message
|
||||
DROP CONSTRAINT IF EXISTS chat_message__prompt_fk
|
||||
"""
|
||||
)
|
||||
op.drop_column("chat_message", "prompt_id")
|
||||
|
||||
# Step 4: Handle personas without prompts - set default values if needed (always run this)
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET
|
||||
system_prompt = COALESCE(system_prompt, ''),
|
||||
task_prompt = COALESCE(task_prompt, '')
|
||||
WHERE system_prompt IS NULL OR task_prompt IS NULL
|
||||
"""
|
||||
)
|
||||
|
||||
# Step 5: Drop the persona__prompt association table (if it exists)
|
||||
if "persona__prompt" in existing_tables:
|
||||
op.drop_table("persona__prompt")
|
||||
|
||||
# Step 6: Drop the prompt table (if it exists)
|
||||
if "prompt" in existing_tables:
|
||||
op.drop_table("prompt")
|
||||
|
||||
# Step 7: Make system_prompt and task_prompt non-nullable after migration (only if they exist)
|
||||
op.alter_column(
|
||||
"persona",
|
||||
"system_prompt",
|
||||
existing_type=sa.String(length=MAX_PROMPT_LENGTH),
|
||||
nullable=False,
|
||||
server_default=None,
|
||||
)
|
||||
|
||||
op.alter_column(
|
||||
"persona",
|
||||
"task_prompt",
|
||||
existing_type=sa.String(length=MAX_PROMPT_LENGTH),
|
||||
nullable=False,
|
||||
server_default=None,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Step 1: Recreate the prompt table
|
||||
op.create_table(
|
||||
"prompt",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=True),
|
||||
sa.Column("name", sa.String(), nullable=False),
|
||||
sa.Column("description", sa.String(), nullable=False),
|
||||
sa.Column("system_prompt", sa.String(length=MAX_PROMPT_LENGTH), nullable=False),
|
||||
sa.Column("task_prompt", sa.String(length=MAX_PROMPT_LENGTH), nullable=False),
|
||||
sa.Column(
|
||||
"datetime_aware", sa.Boolean(), nullable=False, server_default="true"
|
||||
),
|
||||
sa.Column(
|
||||
"default_prompt", sa.Boolean(), nullable=False, server_default="false"
|
||||
),
|
||||
sa.Column("deleted", sa.Boolean(), nullable=False, server_default="false"),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
# Step 2: Recreate the persona__prompt association table
|
||||
op.create_table(
|
||||
"persona__prompt",
|
||||
sa.Column("persona_id", sa.Integer(), nullable=False),
|
||||
sa.Column("prompt_id", sa.Integer(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["persona_id"],
|
||||
["persona.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["prompt_id"],
|
||||
["prompt.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("persona_id", "prompt_id"),
|
||||
)
|
||||
|
||||
# Step 3: Migrate data back from persona to prompt table
|
||||
op.execute(
|
||||
"""
|
||||
INSERT INTO prompt (
|
||||
name,
|
||||
description,
|
||||
system_prompt,
|
||||
task_prompt,
|
||||
datetime_aware,
|
||||
default_prompt,
|
||||
deleted,
|
||||
user_id
|
||||
)
|
||||
SELECT
|
||||
CONCAT('Prompt for ', name),
|
||||
description,
|
||||
system_prompt,
|
||||
task_prompt,
|
||||
datetime_aware,
|
||||
is_default_persona,
|
||||
deleted,
|
||||
user_id
|
||||
FROM persona
|
||||
WHERE system_prompt IS NOT NULL AND system_prompt != ''
|
||||
RETURNING id, name
|
||||
"""
|
||||
)
|
||||
|
||||
# Step 4: Re-establish persona__prompt relationships
|
||||
op.execute(
|
||||
"""
|
||||
INSERT INTO persona__prompt (persona_id, prompt_id)
|
||||
SELECT
|
||||
p.id as persona_id,
|
||||
pr.id as prompt_id
|
||||
FROM persona p
|
||||
JOIN prompt pr ON pr.name = CONCAT('Prompt for ', p.name)
|
||||
WHERE p.system_prompt IS NOT NULL AND p.system_prompt != ''
|
||||
"""
|
||||
)
|
||||
|
||||
# Step 5: Add prompt_id column back to chat_message
|
||||
op.add_column("chat_message", sa.Column("prompt_id", sa.Integer(), nullable=True))
|
||||
|
||||
# Step 6: Re-establish foreign key constraint
|
||||
op.create_foreign_key(
|
||||
"chat_message__prompt_fk", "chat_message", "prompt", ["prompt_id"], ["id"]
|
||||
)
|
||||
|
||||
# Step 7: Remove columns from persona table
|
||||
op.drop_column("persona", "datetime_aware")
|
||||
op.drop_column("persona", "task_prompt")
|
||||
op.drop_column("persona", "system_prompt")
|
||||
@@ -0,0 +1,38 @@
|
||||
"""Adding assistant-specific user preferences
|
||||
|
||||
Revision ID: b329d00a9ea6
|
||||
Revises: f9b8c7d6e5a4
|
||||
Create Date: 2025-08-26 23:14:44.592985
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import fastapi_users_db_sqlalchemy
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "b329d00a9ea6"
|
||||
down_revision = "f9b8c7d6e5a4"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"assistant__user_specific_config",
|
||||
sa.Column("assistant_id", sa.Integer(), nullable=False),
|
||||
sa.Column(
|
||||
"user_id",
|
||||
fastapi_users_db_sqlalchemy.generics.GUID(),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("disabled_tool_ids", postgresql.ARRAY(sa.Integer()), nullable=False),
|
||||
sa.ForeignKeyConstraint(["assistant_id"], ["persona.id"], ondelete="CASCADE"),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("assistant_id", "user_id"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("assistant__user_specific_config")
|
||||
@@ -0,0 +1,33 @@
|
||||
"""Pause finished user file connectors
|
||||
|
||||
Revision ID: b558f51620b4
|
||||
Revises: 90e3b9af7da4
|
||||
Create Date: 2025-08-15 17:17:02.456704
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "b558f51620b4"
|
||||
down_revision = "90e3b9af7da4"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Set all user file connector credential pairs with ACTIVE status to PAUSED
|
||||
# This ensures user files don't continue to run indexing tasks after processing
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE connector_credential_pair
|
||||
SET status = 'PAUSED'
|
||||
WHERE is_user_file = true
|
||||
AND status = 'ACTIVE'
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
pass
|
||||
@@ -0,0 +1,43 @@
|
||||
"""adjust prompt length
|
||||
|
||||
Revision ID: b7ec9b5b505f
|
||||
Revises: abbfec3a5ac5
|
||||
Create Date: 2025-09-10 18:51:15.629197
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "b7ec9b5b505f"
|
||||
down_revision = "abbfec3a5ac5"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
MAX_PROMPT_LENGTH = 5_000_000
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# NOTE: need to run this since the previous migration PREVIOUSLY set the length to 8000
|
||||
op.alter_column(
|
||||
"persona",
|
||||
"system_prompt",
|
||||
existing_type=sa.String(length=8000),
|
||||
type_=sa.String(length=MAX_PROMPT_LENGTH),
|
||||
existing_nullable=False,
|
||||
)
|
||||
op.alter_column(
|
||||
"persona",
|
||||
"task_prompt",
|
||||
existing_type=sa.String(length=8000),
|
||||
type_=sa.String(length=MAX_PROMPT_LENGTH),
|
||||
existing_nullable=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Downgrade not necessary
|
||||
pass
|
||||
@@ -0,0 +1,147 @@
|
||||
"""migrate_agent_sub_questions_to_research_iterations
|
||||
|
||||
Revision ID: bd7c3bf8beba
|
||||
Revises: f8a9b2c3d4e5
|
||||
Create Date: 2025-08-18 11:33:27.098287
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "bd7c3bf8beba"
|
||||
down_revision = "f8a9b2c3d4e5"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Get connection to execute raw SQL
|
||||
connection = op.get_bind()
|
||||
|
||||
# First, insert data into research_agent_iteration table
|
||||
# This creates one iteration record per primary_question_id using the earliest time_created
|
||||
connection.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO research_agent_iteration (primary_question_id, created_at, iteration_nr, purpose, reasoning)
|
||||
SELECT
|
||||
primary_question_id,
|
||||
MIN(time_created) as created_at,
|
||||
1 as iteration_nr,
|
||||
'Generating and researching subquestions' as purpose,
|
||||
'(No previous reasoning)' as reasoning
|
||||
FROM agent__sub_question
|
||||
JOIN chat_message on agent__sub_question.primary_question_id = chat_message.id
|
||||
WHERE primary_question_id IS NOT NULL
|
||||
AND chat_message.is_agentic = true
|
||||
GROUP BY primary_question_id
|
||||
ON CONFLICT DO NOTHING;
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# Then, insert data into research_agent_iteration_sub_step table
|
||||
# This migrates each sub-question as a sub-step
|
||||
connection.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO research_agent_iteration_sub_step (
|
||||
primary_question_id,
|
||||
iteration_nr,
|
||||
iteration_sub_step_nr,
|
||||
created_at,
|
||||
sub_step_instructions,
|
||||
sub_step_tool_id,
|
||||
sub_answer,
|
||||
cited_doc_results
|
||||
)
|
||||
SELECT
|
||||
primary_question_id,
|
||||
1 as iteration_nr,
|
||||
level_question_num as iteration_sub_step_nr,
|
||||
time_created as created_at,
|
||||
sub_question as sub_step_instructions,
|
||||
1 as sub_step_tool_id,
|
||||
sub_answer,
|
||||
sub_question_doc_results as cited_doc_results
|
||||
FROM agent__sub_question
|
||||
JOIN chat_message on agent__sub_question.primary_question_id = chat_message.id
|
||||
WHERE chat_message.is_agentic = true
|
||||
AND primary_question_id IS NOT NULL
|
||||
ON CONFLICT DO NOTHING;
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# Update chat_message records: set legacy agentic type and answer purpose for existing agentic messages
|
||||
connection.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE chat_message
|
||||
SET research_answer_purpose = 'ANSWER'
|
||||
WHERE is_agentic = true
|
||||
AND research_type IS NULL and
|
||||
message_type = 'ASSISTANT';
|
||||
"""
|
||||
)
|
||||
)
|
||||
connection.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE chat_message
|
||||
SET research_type = 'LEGACY_AGENTIC'
|
||||
WHERE is_agentic = true
|
||||
AND research_type IS NULL;
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Get connection to execute raw SQL
|
||||
connection = op.get_bind()
|
||||
|
||||
# Note: This downgrade removes all research agent iteration data
|
||||
# There's no way to perfectly restore the original agent__sub_question data
|
||||
# if it was deleted after this migration
|
||||
|
||||
# Delete all research_agent_iteration_sub_step records that were migrated
|
||||
connection.execute(
|
||||
sa.text(
|
||||
"""
|
||||
DELETE FROM research_agent_iteration_sub_step
|
||||
USING chat_message
|
||||
WHERE research_agent_iteration_sub_step.primary_question_id = chat_message.id
|
||||
AND chat_message.research_type = 'LEGACY_AGENTIC';
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# Delete all research_agent_iteration records that were migrated
|
||||
connection.execute(
|
||||
sa.text(
|
||||
"""
|
||||
DELETE FROM research_agent_iteration
|
||||
USING chat_message
|
||||
WHERE research_agent_iteration.primary_question_id = chat_message.id
|
||||
AND chat_message.research_type = 'LEGACY_AGENTIC';
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# Revert chat_message updates: clear research fields for legacy agentic messages
|
||||
connection.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE chat_message
|
||||
SET research_type = NULL,
|
||||
research_answer_purpose = NULL
|
||||
WHERE is_agentic = true
|
||||
AND research_type = 'LEGACY_AGENTIC'
|
||||
AND message_type = 'ASSISTANT';
|
||||
"""
|
||||
)
|
||||
)
|
||||
315
backend/alembic/versions/c9e2cd766c29_add_s3_file_store_table.py
Normal file
315
backend/alembic/versions/c9e2cd766c29_add_s3_file_store_table.py
Normal file
@@ -0,0 +1,315 @@
|
||||
"""modify_file_store_for_external_storage
|
||||
|
||||
Revision ID: c9e2cd766c29
|
||||
Revises: 03bf8be6b53a
|
||||
Create Date: 2025-06-13 14:02:09.867679
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import text
|
||||
from typing import cast, Any
|
||||
|
||||
from botocore.exceptions import ClientError
|
||||
|
||||
from onyx.db._deprecated.pg_file_store import delete_lobj_by_id, read_lobj
|
||||
from onyx.file_store.file_store import get_s3_file_store
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "c9e2cd766c29"
|
||||
down_revision = "03bf8be6b53a"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
try:
|
||||
# Modify existing file_store table to support external storage
|
||||
op.rename_table("file_store", "file_record")
|
||||
|
||||
# Make lobj_oid nullable (for external storage files)
|
||||
op.alter_column("file_record", "lobj_oid", nullable=True)
|
||||
|
||||
# Add external storage columns with generic names
|
||||
op.add_column(
|
||||
"file_record", sa.Column("bucket_name", sa.String(), nullable=True)
|
||||
)
|
||||
op.add_column(
|
||||
"file_record", sa.Column("object_key", sa.String(), nullable=True)
|
||||
)
|
||||
|
||||
# Add timestamps for tracking
|
||||
op.add_column(
|
||||
"file_record",
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(),
|
||||
nullable=False,
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"file_record",
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(),
|
||||
nullable=False,
|
||||
),
|
||||
)
|
||||
|
||||
op.alter_column("file_record", "file_name", new_column_name="file_id")
|
||||
except Exception as e:
|
||||
if "does not exist" in str(e) or 'relation "file_store" does not exist' in str(
|
||||
e
|
||||
):
|
||||
print(
|
||||
f"Ran into error - {e}. Likely means we had a partial success in the past, continuing..."
|
||||
)
|
||||
else:
|
||||
raise
|
||||
|
||||
print(
|
||||
"External storage configured - migrating files from PostgreSQL to external storage..."
|
||||
)
|
||||
# if we fail midway through this, we'll have a partial success. Running the migration
|
||||
# again should allow us to continue.
|
||||
_migrate_files_to_external_storage()
|
||||
print("File migration completed successfully!")
|
||||
|
||||
# Remove lobj_oid column
|
||||
op.drop_column("file_record", "lobj_oid")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Revert schema changes and migrate files from external storage back to PostgreSQL large objects."""
|
||||
|
||||
print(
|
||||
"Reverting to PostgreSQL-backed file store – migrating files from external storage …"
|
||||
)
|
||||
|
||||
# 1. Ensure `lobj_oid` exists on the current `file_record` table (nullable for now).
|
||||
op.add_column("file_record", sa.Column("lobj_oid", sa.Integer(), nullable=True))
|
||||
|
||||
# 2. Move content from external storage back into PostgreSQL large objects (table is still
|
||||
# called `file_record` so application code continues to work during the copy).
|
||||
try:
|
||||
_migrate_files_to_postgres()
|
||||
except Exception:
|
||||
print("Error during downgrade migration, rolling back …")
|
||||
op.drop_column("file_record", "lobj_oid")
|
||||
raise
|
||||
|
||||
# 3. After migration every row should now have `lobj_oid` populated – mark NOT NULL.
|
||||
op.alter_column("file_record", "lobj_oid", nullable=False)
|
||||
|
||||
# 4. Remove columns that are only relevant to external storage.
|
||||
op.drop_column("file_record", "updated_at")
|
||||
op.drop_column("file_record", "created_at")
|
||||
op.drop_column("file_record", "object_key")
|
||||
op.drop_column("file_record", "bucket_name")
|
||||
|
||||
# 5. Rename `file_id` back to `file_name` (still on `file_record`).
|
||||
op.alter_column("file_record", "file_id", new_column_name="file_name")
|
||||
|
||||
# 6. Finally, rename the table back to its original name expected by the legacy codebase.
|
||||
op.rename_table("file_record", "file_store")
|
||||
|
||||
print(
|
||||
"Downgrade migration completed – files are now stored inside PostgreSQL again."
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Helper: migrate from external storage (S3/MinIO) back into PostgreSQL large objects
|
||||
|
||||
|
||||
def _migrate_files_to_postgres() -> None:
|
||||
"""Move any files whose content lives in external S3-compatible storage back into PostgreSQL.
|
||||
|
||||
The logic mirrors *inverse* of `_migrate_files_to_external_storage` used on upgrade.
|
||||
"""
|
||||
|
||||
# Obtain DB session from Alembic context
|
||||
bind = op.get_bind()
|
||||
session = Session(bind=bind)
|
||||
|
||||
# Fetch rows that have external storage pointers (bucket/object_key not NULL)
|
||||
result = session.execute(
|
||||
text(
|
||||
"SELECT file_id, bucket_name, object_key FROM file_record "
|
||||
"WHERE bucket_name IS NOT NULL AND object_key IS NOT NULL"
|
||||
)
|
||||
)
|
||||
|
||||
files_to_migrate = [row[0] for row in result.fetchall()]
|
||||
total_files = len(files_to_migrate)
|
||||
|
||||
if total_files == 0:
|
||||
print("No files found in external storage to migrate back to PostgreSQL.")
|
||||
return
|
||||
|
||||
print(f"Found {total_files} files to migrate back to PostgreSQL large objects.")
|
||||
|
||||
_set_tenant_contextvar(session)
|
||||
migrated_count = 0
|
||||
|
||||
# only create external store if we have files to migrate. This line
|
||||
# makes it so we need to have S3/MinIO configured to run this migration.
|
||||
external_store = get_s3_file_store()
|
||||
|
||||
for i, file_id in enumerate(files_to_migrate, 1):
|
||||
print(f"Migrating file {i}/{total_files}: {file_id}")
|
||||
|
||||
# Read file content from external storage (always binary)
|
||||
try:
|
||||
file_io = external_store.read_file(
|
||||
file_id=file_id, mode="b", use_tempfile=True
|
||||
)
|
||||
file_io.seek(0)
|
||||
|
||||
# Import lazily to avoid circular deps at Alembic runtime
|
||||
from onyx.db._deprecated.pg_file_store import (
|
||||
create_populate_lobj,
|
||||
) # noqa: E402
|
||||
|
||||
# Create new Postgres large object and populate it
|
||||
lobj_oid = create_populate_lobj(content=file_io, db_session=session)
|
||||
|
||||
# Update DB row: set lobj_oid, clear bucket/object_key
|
||||
session.execute(
|
||||
text(
|
||||
"UPDATE file_record SET lobj_oid = :lobj_oid, bucket_name = NULL, "
|
||||
"object_key = NULL WHERE file_id = :file_id"
|
||||
),
|
||||
{"lobj_oid": lobj_oid, "file_id": file_id},
|
||||
)
|
||||
except ClientError as e:
|
||||
if "NoSuchKey" in str(e):
|
||||
print(
|
||||
f"File {file_id} not found in external storage. Deleting from database."
|
||||
)
|
||||
session.execute(
|
||||
text("DELETE FROM file_record WHERE file_id = :file_id"),
|
||||
{"file_id": file_id},
|
||||
)
|
||||
else:
|
||||
raise
|
||||
|
||||
migrated_count += 1
|
||||
print(f"✓ Successfully migrated file {i}/{total_files}: {file_id}")
|
||||
|
||||
# Flush the SQLAlchemy session so statements are sent to the DB, but **do not**
|
||||
# commit the transaction. The surrounding Alembic migration will commit once
|
||||
# the *entire* downgrade succeeds. This keeps the whole downgrade atomic and
|
||||
# avoids leaving the database in a partially-migrated state if a later schema
|
||||
# operation fails.
|
||||
session.flush()
|
||||
|
||||
print(
|
||||
f"Migration back to PostgreSQL completed: {migrated_count} files staged for commit."
|
||||
)
|
||||
|
||||
|
||||
def _migrate_files_to_external_storage() -> None:
|
||||
"""Migrate files from PostgreSQL large objects to external storage"""
|
||||
# Get database session
|
||||
bind = op.get_bind()
|
||||
session = Session(bind=bind)
|
||||
external_store = get_s3_file_store()
|
||||
|
||||
# Find all files currently stored in PostgreSQL (lobj_oid is not null)
|
||||
result = session.execute(
|
||||
text(
|
||||
"SELECT file_id FROM file_record WHERE lobj_oid IS NOT NULL "
|
||||
"AND bucket_name IS NULL AND object_key IS NULL"
|
||||
)
|
||||
)
|
||||
|
||||
files_to_migrate = [row[0] for row in result.fetchall()]
|
||||
total_files = len(files_to_migrate)
|
||||
|
||||
if total_files == 0:
|
||||
print("No files found in PostgreSQL storage to migrate.")
|
||||
return
|
||||
|
||||
# might need to move this above the if statement when creating a new multi-tenant
|
||||
# system. VERY extreme edge case.
|
||||
external_store.initialize()
|
||||
print(f"Found {total_files} files to migrate from PostgreSQL to external storage.")
|
||||
|
||||
_set_tenant_contextvar(session)
|
||||
migrated_count = 0
|
||||
|
||||
for i, file_id in enumerate(files_to_migrate, 1):
|
||||
print(f"Migrating file {i}/{total_files}: {file_id}")
|
||||
|
||||
# Read file record to get metadata
|
||||
file_record = session.execute(
|
||||
text("SELECT * FROM file_record WHERE file_id = :file_id"),
|
||||
{"file_id": file_id},
|
||||
).fetchone()
|
||||
|
||||
if file_record is None:
|
||||
print(f"File {file_id} not found in PostgreSQL storage.")
|
||||
continue
|
||||
|
||||
lobj_id = cast(int, file_record.lobj_oid) # type: ignore
|
||||
file_metadata = cast(Any, file_record.file_metadata) # type: ignore
|
||||
|
||||
# Read file content from PostgreSQL
|
||||
try:
|
||||
file_content = read_lobj(
|
||||
lobj_id, db_session=session, mode="b", use_tempfile=True
|
||||
)
|
||||
except Exception as e:
|
||||
if "large object" in str(e) and "does not exist" in str(e):
|
||||
print(f"File {file_id} not found in PostgreSQL storage.")
|
||||
continue
|
||||
else:
|
||||
raise
|
||||
|
||||
# Handle file_metadata type conversion
|
||||
file_metadata = None
|
||||
if file_metadata is not None:
|
||||
if isinstance(file_metadata, dict):
|
||||
file_metadata = file_metadata
|
||||
else:
|
||||
# Convert other types to dict if possible, otherwise None
|
||||
try:
|
||||
file_metadata = dict(file_record.file_metadata) # type: ignore
|
||||
except (TypeError, ValueError):
|
||||
file_metadata = None
|
||||
|
||||
# Save to external storage (this will handle the database record update and cleanup)
|
||||
# NOTE: this WILL .commit() the transaction.
|
||||
external_store.save_file(
|
||||
file_id=file_id,
|
||||
content=file_content,
|
||||
display_name=file_record.display_name,
|
||||
file_origin=file_record.file_origin,
|
||||
file_type=file_record.file_type,
|
||||
file_metadata=file_metadata,
|
||||
)
|
||||
delete_lobj_by_id(lobj_id, db_session=session)
|
||||
|
||||
migrated_count += 1
|
||||
print(f"✓ Successfully migrated file {i}/{total_files}: {file_id}")
|
||||
|
||||
# See note above – flush but do **not** commit so the outer Alembic transaction
|
||||
# controls atomicity.
|
||||
session.flush()
|
||||
|
||||
print(
|
||||
f"Migration completed: {migrated_count} files staged for commit to external storage."
|
||||
)
|
||||
|
||||
|
||||
def _set_tenant_contextvar(session: Session) -> None:
|
||||
"""Set the tenant contextvar to the default schema"""
|
||||
current_tenant = session.execute(text("SELECT current_schema()")).scalar()
|
||||
print(f"Migrating files for tenant: {current_tenant}")
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(current_tenant)
|
||||
@@ -0,0 +1,128 @@
|
||||
"""add_cascade_deletes_to_agent_tables
|
||||
|
||||
Revision ID: ca04500b9ee8
|
||||
Revises: 238b84885828
|
||||
Create Date: 2025-05-30 16:03:51.112263
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "ca04500b9ee8"
|
||||
down_revision = "238b84885828"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Drop existing foreign key constraints
|
||||
op.drop_constraint(
|
||||
"agent__sub_question_primary_question_id_fkey",
|
||||
"agent__sub_question",
|
||||
type_="foreignkey",
|
||||
)
|
||||
op.drop_constraint(
|
||||
"agent__sub_query_parent_question_id_fkey",
|
||||
"agent__sub_query",
|
||||
type_="foreignkey",
|
||||
)
|
||||
op.drop_constraint(
|
||||
"chat_message__standard_answer_chat_message_id_fkey",
|
||||
"chat_message__standard_answer",
|
||||
type_="foreignkey",
|
||||
)
|
||||
op.drop_constraint(
|
||||
"agent__sub_query__search_doc_sub_query_id_fkey",
|
||||
"agent__sub_query__search_doc",
|
||||
type_="foreignkey",
|
||||
)
|
||||
|
||||
# Recreate foreign key constraints with CASCADE delete
|
||||
op.create_foreign_key(
|
||||
"agent__sub_question_primary_question_id_fkey",
|
||||
"agent__sub_question",
|
||||
"chat_message",
|
||||
["primary_question_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"agent__sub_query_parent_question_id_fkey",
|
||||
"agent__sub_query",
|
||||
"agent__sub_question",
|
||||
["parent_question_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"chat_message__standard_answer_chat_message_id_fkey",
|
||||
"chat_message__standard_answer",
|
||||
"chat_message",
|
||||
["chat_message_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"agent__sub_query__search_doc_sub_query_id_fkey",
|
||||
"agent__sub_query__search_doc",
|
||||
"agent__sub_query",
|
||||
["sub_query_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop CASCADE foreign key constraints
|
||||
op.drop_constraint(
|
||||
"agent__sub_question_primary_question_id_fkey",
|
||||
"agent__sub_question",
|
||||
type_="foreignkey",
|
||||
)
|
||||
op.drop_constraint(
|
||||
"agent__sub_query_parent_question_id_fkey",
|
||||
"agent__sub_query",
|
||||
type_="foreignkey",
|
||||
)
|
||||
op.drop_constraint(
|
||||
"chat_message__standard_answer_chat_message_id_fkey",
|
||||
"chat_message__standard_answer",
|
||||
type_="foreignkey",
|
||||
)
|
||||
op.drop_constraint(
|
||||
"agent__sub_query__search_doc_sub_query_id_fkey",
|
||||
"agent__sub_query__search_doc",
|
||||
type_="foreignkey",
|
||||
)
|
||||
|
||||
# Recreate foreign key constraints without CASCADE delete
|
||||
op.create_foreign_key(
|
||||
"agent__sub_question_primary_question_id_fkey",
|
||||
"agent__sub_question",
|
||||
"chat_message",
|
||||
["primary_question_id"],
|
||||
["id"],
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"agent__sub_query_parent_question_id_fkey",
|
||||
"agent__sub_query",
|
||||
"agent__sub_question",
|
||||
["parent_question_id"],
|
||||
["id"],
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"chat_message__standard_answer_chat_message_id_fkey",
|
||||
"chat_message__standard_answer",
|
||||
"chat_message",
|
||||
["chat_message_id"],
|
||||
["id"],
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"agent__sub_query__search_doc_sub_query_id_fkey",
|
||||
"agent__sub_query__search_doc",
|
||||
"agent__sub_query",
|
||||
["sub_query_id"],
|
||||
["id"],
|
||||
)
|
||||
29
backend/alembic/versions/cec7ec36c505_kgentity_parent.py
Normal file
29
backend/alembic/versions/cec7ec36c505_kgentity_parent.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""kgentity_parent
|
||||
|
||||
Revision ID: cec7ec36c505
|
||||
Revises: 495cb26ce93e
|
||||
Create Date: 2025-06-07 20:07:46.400770
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "cec7ec36c505"
|
||||
down_revision = "495cb26ce93e"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"kg_entity",
|
||||
sa.Column("parent_key", sa.String(), nullable=True, index=True),
|
||||
)
|
||||
# NOTE: you will have to reindex the KG after this migration as the parent_key will be null
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("kg_entity", "parent_key")
|
||||
125
backend/alembic/versions/d09fc20a3c66_seed_builtin_tools.py
Normal file
125
backend/alembic/versions/d09fc20a3c66_seed_builtin_tools.py
Normal file
@@ -0,0 +1,125 @@
|
||||
"""seed_builtin_tools
|
||||
|
||||
Revision ID: d09fc20a3c66
|
||||
Revises: b7ec9b5b505f
|
||||
Create Date: 2025-09-09 19:32:16.824373
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "d09fc20a3c66"
|
||||
down_revision = "b7ec9b5b505f"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
# Tool definitions - core tools that should always be seeded
|
||||
# Names/in_code_tool_id are the same as the class names in the tool_implementations package
|
||||
BUILT_IN_TOOLS = [
|
||||
{
|
||||
"name": "SearchTool",
|
||||
"display_name": "Internal Search",
|
||||
"description": "The Search Action allows the Assistant to search through connected knowledge to help build an answer.",
|
||||
"in_code_tool_id": "SearchTool",
|
||||
},
|
||||
{
|
||||
"name": "ImageGenerationTool",
|
||||
"display_name": "Image Generation",
|
||||
"description": (
|
||||
"The Image Generation Action allows the assistant to use DALL-E 3 or GPT-IMAGE-1 to generate images. "
|
||||
"The action will be used when the user asks the assistant to generate an image."
|
||||
),
|
||||
"in_code_tool_id": "ImageGenerationTool",
|
||||
},
|
||||
{
|
||||
"name": "WebSearchTool",
|
||||
"display_name": "Web Search",
|
||||
"description": (
|
||||
"The Web Search Action allows the assistant "
|
||||
"to perform internet searches for up-to-date information."
|
||||
),
|
||||
"in_code_tool_id": "WebSearchTool",
|
||||
},
|
||||
{
|
||||
"name": "KnowledgeGraphTool",
|
||||
"display_name": "Knowledge Graph Search",
|
||||
"description": (
|
||||
"The Knowledge Graph Search Action allows the assistant to search the "
|
||||
"Knowledge Graph for information. This tool can (for now) only be active in the KG Beta Assistant, "
|
||||
"and it requires the Knowledge Graph to be enabled."
|
||||
),
|
||||
"in_code_tool_id": "KnowledgeGraphTool",
|
||||
},
|
||||
{
|
||||
"name": "OktaProfileTool",
|
||||
"display_name": "Okta Profile",
|
||||
"description": (
|
||||
"The Okta Profile Action allows the assistant to fetch the current user's information from Okta. "
|
||||
"This may include the user's name, email, phone number, address, and other details such as their "
|
||||
"manager and direct reports."
|
||||
),
|
||||
"in_code_tool_id": "OktaProfileTool",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# Start transaction
|
||||
conn.execute(sa.text("BEGIN"))
|
||||
|
||||
try:
|
||||
# Get existing tools to check what already exists
|
||||
existing_tools = conn.execute(
|
||||
sa.text(
|
||||
"SELECT in_code_tool_id FROM tool WHERE in_code_tool_id IS NOT NULL"
|
||||
)
|
||||
).fetchall()
|
||||
existing_tool_ids = {row[0] for row in existing_tools}
|
||||
|
||||
# Insert or update built-in tools
|
||||
for tool in BUILT_IN_TOOLS:
|
||||
if tool["in_code_tool_id"] in existing_tool_ids:
|
||||
# Update existing tool
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE tool
|
||||
SET name = :name,
|
||||
display_name = :display_name,
|
||||
description = :description
|
||||
WHERE in_code_tool_id = :in_code_tool_id
|
||||
"""
|
||||
),
|
||||
tool,
|
||||
)
|
||||
else:
|
||||
# Insert new tool
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO tool (name, display_name, description, in_code_tool_id)
|
||||
VALUES (:name, :display_name, :description, :in_code_tool_id)
|
||||
"""
|
||||
),
|
||||
tool,
|
||||
)
|
||||
|
||||
# Commit transaction
|
||||
conn.execute(sa.text("COMMIT"))
|
||||
|
||||
except Exception as e:
|
||||
# Rollback on error
|
||||
conn.execute(sa.text("ROLLBACK"))
|
||||
raise e
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# We don't remove the tools on downgrade since it's totally fine to just
|
||||
# have them around. If we upgrade again, it will be a no-op.
|
||||
pass
|
||||
@@ -11,7 +11,7 @@ import sqlalchemy as sa
|
||||
import json
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.onyx_jira.utils import extract_jira_project
|
||||
from onyx.connectors.jira.utils import extract_jira_project
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
|
||||
@@ -10,12 +10,19 @@ from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import table, column, String, Integer, Boolean
|
||||
|
||||
from onyx.db.search_settings import (
|
||||
get_new_default_embedding_model,
|
||||
get_old_default_embedding_model,
|
||||
user_has_overridden_embedding_model,
|
||||
)
|
||||
from onyx.configs.model_configs import ASYM_PASSAGE_PREFIX
|
||||
from onyx.configs.model_configs import ASYM_QUERY_PREFIX
|
||||
from onyx.configs.model_configs import DOC_EMBEDDING_DIM
|
||||
from onyx.configs.model_configs import DOCUMENT_ENCODER_MODEL
|
||||
from onyx.configs.model_configs import NORMALIZE_EMBEDDINGS
|
||||
from onyx.configs.model_configs import OLD_DEFAULT_DOCUMENT_ENCODER_MODEL
|
||||
from onyx.configs.model_configs import OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
|
||||
from onyx.configs.model_configs import OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
|
||||
from onyx.db.enums import EmbeddingPrecision
|
||||
from onyx.db.models import IndexModelStatus
|
||||
from onyx.db.search_settings import user_has_overridden_embedding_model
|
||||
from onyx.indexing.models import IndexingSetting
|
||||
from onyx.natural_language_processing.search_nlp_models import clean_model_name
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "dbaa756c2ccf"
|
||||
@@ -24,6 +31,47 @@ branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def _get_old_default_embedding_model() -> IndexingSetting:
|
||||
is_overridden = user_has_overridden_embedding_model()
|
||||
return IndexingSetting(
|
||||
model_name=(
|
||||
DOCUMENT_ENCODER_MODEL
|
||||
if is_overridden
|
||||
else OLD_DEFAULT_DOCUMENT_ENCODER_MODEL
|
||||
),
|
||||
model_dim=(
|
||||
DOC_EMBEDDING_DIM if is_overridden else OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
|
||||
),
|
||||
embedding_precision=(EmbeddingPrecision.FLOAT),
|
||||
normalize=(
|
||||
NORMALIZE_EMBEDDINGS
|
||||
if is_overridden
|
||||
else OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
|
||||
),
|
||||
query_prefix=(ASYM_QUERY_PREFIX if is_overridden else ""),
|
||||
passage_prefix=(ASYM_PASSAGE_PREFIX if is_overridden else ""),
|
||||
index_name="danswer_chunk",
|
||||
multipass_indexing=False,
|
||||
enable_contextual_rag=False,
|
||||
api_url=None,
|
||||
)
|
||||
|
||||
|
||||
def _get_new_default_embedding_model() -> IndexingSetting:
|
||||
return IndexingSetting(
|
||||
model_name=DOCUMENT_ENCODER_MODEL,
|
||||
model_dim=DOC_EMBEDDING_DIM,
|
||||
embedding_precision=(EmbeddingPrecision.BFLOAT16),
|
||||
normalize=NORMALIZE_EMBEDDINGS,
|
||||
query_prefix=ASYM_QUERY_PREFIX,
|
||||
passage_prefix=ASYM_PASSAGE_PREFIX,
|
||||
index_name=f"danswer_chunk_{clean_model_name(DOCUMENT_ENCODER_MODEL)}",
|
||||
multipass_indexing=False,
|
||||
enable_contextual_rag=False,
|
||||
api_url=None,
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"embedding_model",
|
||||
@@ -61,7 +109,7 @@ def upgrade() -> None:
|
||||
# the user selected via env variables before this change. This is needed since
|
||||
# all index_attempts must be associated with an embedding model, so without this
|
||||
# we will run into violations of non-null contraints
|
||||
old_embedding_model = get_old_default_embedding_model()
|
||||
old_embedding_model = _get_old_default_embedding_model()
|
||||
op.bulk_insert(
|
||||
EmbeddingModel,
|
||||
[
|
||||
@@ -79,7 +127,7 @@ def upgrade() -> None:
|
||||
# if the user has not overridden the default embedding model via env variables,
|
||||
# insert the new default model into the database to auto-upgrade them
|
||||
if not user_has_overridden_embedding_model():
|
||||
new_embedding_model = get_new_default_embedding_model()
|
||||
new_embedding_model = _get_new_default_embedding_model()
|
||||
op.bulk_insert(
|
||||
EmbeddingModel,
|
||||
[
|
||||
|
||||
@@ -18,11 +18,13 @@ depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute("DROP TABLE IF EXISTS document CASCADE")
|
||||
op.create_table(
|
||||
"document",
|
||||
sa.Column("id", sa.String(), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.execute("DROP TABLE IF EXISTS chunk CASCADE")
|
||||
op.create_table(
|
||||
"chunk",
|
||||
sa.Column("id", sa.String(), nullable=False),
|
||||
@@ -43,6 +45,7 @@ def upgrade() -> None:
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id", "document_store_type"),
|
||||
)
|
||||
op.execute("DROP TABLE IF EXISTS deletion_attempt CASCADE")
|
||||
op.create_table(
|
||||
"deletion_attempt",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
@@ -84,6 +87,7 @@ def upgrade() -> None:
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.execute("DROP TABLE IF EXISTS document_by_connector_credential_pair CASCADE")
|
||||
op.create_table(
|
||||
"document_by_connector_credential_pair",
|
||||
sa.Column("id", sa.String(), nullable=False),
|
||||
@@ -106,7 +110,10 @@ def upgrade() -> None:
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# upstream tables first
|
||||
op.drop_table("document_by_connector_credential_pair")
|
||||
op.drop_table("deletion_attempt")
|
||||
op.drop_table("chunk")
|
||||
op.drop_table("document")
|
||||
|
||||
# Alembic op.drop_table() has no "cascade" flag – issue raw SQL
|
||||
op.execute("DROP TABLE IF EXISTS document CASCADE")
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
"""add research_answer_purpose to chat_message
|
||||
|
||||
Revision ID: f8a9b2c3d4e5
|
||||
Revises: 5ae8240accb3
|
||||
Create Date: 2025-01-27 12:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "f8a9b2c3d4e5"
|
||||
down_revision = "5ae8240accb3"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add research_answer_purpose column to chat_message table
|
||||
op.add_column(
|
||||
"chat_message",
|
||||
sa.Column("research_answer_purpose", sa.String(), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove research_answer_purpose column from chat_message table
|
||||
op.drop_column("chat_message", "research_answer_purpose")
|
||||
@@ -0,0 +1,69 @@
|
||||
"""remove foreign key constraints from research_agent_iteration_sub_step
|
||||
|
||||
Revision ID: f9b8c7d6e5a4
|
||||
Revises: bd7c3bf8beba
|
||||
Create Date: 2025-01-27 12:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "f9b8c7d6e5a4"
|
||||
down_revision = "bd7c3bf8beba"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Drop the existing foreign key constraint for parent_question_id
|
||||
op.drop_constraint(
|
||||
"research_agent_iteration_sub_step_parent_question_id_fkey",
|
||||
"research_agent_iteration_sub_step",
|
||||
type_="foreignkey",
|
||||
)
|
||||
|
||||
# Drop the parent_question_id column entirely
|
||||
op.drop_column("research_agent_iteration_sub_step", "parent_question_id")
|
||||
|
||||
# Drop the foreign key constraint for primary_question_id to chat_message.id
|
||||
# (keep the column as it's needed for the composite foreign key)
|
||||
op.drop_constraint(
|
||||
"research_agent_iteration_sub_step_primary_question_id_fkey",
|
||||
"research_agent_iteration_sub_step",
|
||||
type_="foreignkey",
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Restore the foreign key constraint for primary_question_id to chat_message.id
|
||||
op.create_foreign_key(
|
||||
"research_agent_iteration_sub_step_primary_question_id_fkey",
|
||||
"research_agent_iteration_sub_step",
|
||||
"chat_message",
|
||||
["primary_question_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
# Add back the parent_question_id column
|
||||
op.add_column(
|
||||
"research_agent_iteration_sub_step",
|
||||
sa.Column(
|
||||
"parent_question_id",
|
||||
sa.Integer(),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
|
||||
# Restore the foreign key constraint pointing to research_agent_iteration_sub_step.id
|
||||
op.create_foreign_key(
|
||||
"research_agent_iteration_sub_step_parent_question_id_fkey",
|
||||
"research_agent_iteration_sub_step",
|
||||
"research_agent_iteration_sub_step",
|
||||
["parent_question_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
@@ -8,7 +8,7 @@ from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy.schema import SchemaItem
|
||||
|
||||
from alembic import context
|
||||
from onyx.db.engine import build_connection_string
|
||||
from onyx.db.engine.sql_engine import build_connection_string
|
||||
from onyx.db.models import PublicBase
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
|
||||
@@ -0,0 +1,80 @@
|
||||
"""add_db_readonly_user
|
||||
|
||||
Revision ID: 3b9f09038764
|
||||
Revises: 3b45e0018bf1
|
||||
Create Date: 2025-05-11 11:05:11.436977
|
||||
|
||||
"""
|
||||
|
||||
from sqlalchemy import text
|
||||
|
||||
from alembic import op
|
||||
from onyx.configs.app_configs import DB_READONLY_PASSWORD
|
||||
from onyx.configs.app_configs import DB_READONLY_USER
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "3b9f09038764"
|
||||
down_revision = "3b45e0018bf1"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
if MULTI_TENANT:
|
||||
|
||||
# Enable pg_trgm extension if not already enabled
|
||||
op.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm")
|
||||
|
||||
# Create read-only db user here only in multi-tenant mode. For single-tenant mode,
|
||||
# the user is created in the standard migration.
|
||||
if not (DB_READONLY_USER and DB_READONLY_PASSWORD):
|
||||
raise Exception("DB_READONLY_USER or DB_READONLY_PASSWORD is not set")
|
||||
|
||||
op.execute(
|
||||
text(
|
||||
f"""
|
||||
DO $$
|
||||
BEGIN
|
||||
-- Check if the read-only user already exists
|
||||
IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
|
||||
-- Create the read-only user with the specified password
|
||||
EXECUTE format('CREATE USER %I WITH PASSWORD %L', '{DB_READONLY_USER}', '{DB_READONLY_PASSWORD}');
|
||||
-- First revoke all privileges to ensure a clean slate
|
||||
EXECUTE format('REVOKE ALL ON DATABASE %I FROM %I', current_database(), '{DB_READONLY_USER}');
|
||||
-- Grant only the CONNECT privilege to allow the user to connect to the database
|
||||
-- but not perform any operations without additional specific grants
|
||||
EXECUTE format('GRANT CONNECT ON DATABASE %I TO %I', current_database(), '{DB_READONLY_USER}');
|
||||
END IF;
|
||||
END
|
||||
$$;
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
if MULTI_TENANT:
|
||||
# Drop read-only db user here only in single tenant mode. For multi-tenant mode,
|
||||
# the user is dropped in the alembic_tenants migration.
|
||||
|
||||
op.execute(
|
||||
text(
|
||||
f"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
|
||||
-- First revoke all privileges from the database
|
||||
EXECUTE format('REVOKE ALL ON DATABASE %I FROM %I', current_database(), '{DB_READONLY_USER}');
|
||||
-- Then revoke all privileges from the public schema
|
||||
EXECUTE format('REVOKE ALL ON SCHEMA public FROM %I', '{DB_READONLY_USER}');
|
||||
-- Then drop the user
|
||||
EXECUTE format('DROP USER %I', '{DB_READONLY_USER}');
|
||||
END IF;
|
||||
END
|
||||
$$;
|
||||
"""
|
||||
)
|
||||
)
|
||||
op.execute(text("DROP EXTENSION IF EXISTS pg_trgm"))
|
||||
@@ -4,10 +4,7 @@ from ee.onyx.db.external_perm import fetch_external_groups_for_user
|
||||
from ee.onyx.db.external_perm import fetch_public_external_group_ids
|
||||
from ee.onyx.db.user_group import fetch_user_groups_for_documents
|
||||
from ee.onyx.db.user_group import fetch_user_groups_for_user
|
||||
from ee.onyx.external_permissions.post_query_censoring import (
|
||||
DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION,
|
||||
)
|
||||
from ee.onyx.external_permissions.sync_params import DOC_PERMISSIONS_FUNC_MAP
|
||||
from ee.onyx.external_permissions.sync_params import get_source_perm_sync_config
|
||||
from onyx.access.access import (
|
||||
_get_access_for_documents as get_access_for_documents_without_groups,
|
||||
)
|
||||
@@ -18,6 +15,10 @@ from onyx.access.utils import prefix_user_group
|
||||
from onyx.db.document import get_document_sources
|
||||
from onyx.db.document import get_documents_by_ids
|
||||
from onyx.db.models import User
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _get_access_for_document(
|
||||
@@ -70,9 +71,15 @@ def _get_access_for_documents(
|
||||
for document_id, non_ee_access in non_ee_access_dict.items():
|
||||
document = doc_id_map[document_id]
|
||||
source = doc_id_to_source_map.get(document_id)
|
||||
if source is None:
|
||||
logger.error(f"Document {document_id} has no source")
|
||||
continue
|
||||
|
||||
perm_sync_config = get_source_perm_sync_config(source)
|
||||
is_only_censored = (
|
||||
source in DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION
|
||||
and source not in DOC_PERMISSIONS_FUNC_MAP
|
||||
perm_sync_config
|
||||
and perm_sync_config.censoring_config is not None
|
||||
and perm_sync_config.doc_sync_config is None
|
||||
)
|
||||
|
||||
ext_u_emails = (
|
||||
|
||||
@@ -1,17 +1,14 @@
|
||||
import csv
|
||||
import io
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
|
||||
from ee.onyx.background.task_name_builders import query_history_task_name
|
||||
from ee.onyx.server.query_history.api import fetch_and_process_chat_session_history
|
||||
from ee.onyx.server.query_history.api import ONYX_ANONYMIZED_EMAIL
|
||||
from ee.onyx.server.query_history.models import ChatSessionSnapshot
|
||||
from ee.onyx.server.query_history.models import QuestionAnswerPairSnapshot
|
||||
from onyx.background.celery.apps.primary import celery_app
|
||||
from onyx.background.celery.apps.heavy import celery_app
|
||||
from onyx.background.task_utils import construct_query_history_report_name
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
|
||||
@@ -19,11 +16,10 @@ from onyx.configs.constants import FileOrigin
|
||||
from onyx.configs.constants import FileType
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import QueryHistoryType
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import TaskStatus
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.tasks import delete_task_with_id
|
||||
from onyx.db.tasks import mark_task_as_finished_with_id
|
||||
from onyx.db.tasks import register_task
|
||||
from onyx.db.tasks import mark_task_as_started_with_id
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -38,32 +34,50 @@ logger = setup_logger()
|
||||
bind=True,
|
||||
trail=False,
|
||||
)
|
||||
def export_query_history_task(self: Task, *, start: datetime, end: datetime) -> None:
|
||||
def export_query_history_task(
|
||||
self: Task,
|
||||
*,
|
||||
start: datetime,
|
||||
end: datetime,
|
||||
start_time: datetime,
|
||||
# Need to include the tenant_id since the TenantAwareTask needs this
|
||||
tenant_id: str,
|
||||
) -> None:
|
||||
if not self.request.id:
|
||||
raise RuntimeError("No task id defined for this task; cannot identify it")
|
||||
|
||||
task_id = self.request.id
|
||||
start_time = datetime.now(tz=timezone.utc)
|
||||
stream = io.StringIO()
|
||||
writer = csv.DictWriter(
|
||||
stream,
|
||||
fieldnames=list(QuestionAnswerPairSnapshot.model_fields.keys()),
|
||||
)
|
||||
writer.writeheader()
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
try:
|
||||
register_task(
|
||||
mark_task_as_started_with_id(
|
||||
db_session=db_session,
|
||||
task_name=query_history_task_name(start=start, end=end),
|
||||
task_id=task_id,
|
||||
status=TaskStatus.STARTED,
|
||||
start_time=start_time,
|
||||
)
|
||||
|
||||
complete_chat_session_history: list[ChatSessionSnapshot] = (
|
||||
fetch_and_process_chat_session_history(
|
||||
db_session=db_session,
|
||||
start=start,
|
||||
end=end,
|
||||
feedback_type=None,
|
||||
limit=None,
|
||||
)
|
||||
snapshot_generator = fetch_and_process_chat_session_history(
|
||||
db_session=db_session,
|
||||
start=start,
|
||||
end=end,
|
||||
)
|
||||
|
||||
for snapshot in snapshot_generator:
|
||||
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED:
|
||||
snapshot.user_email = ONYX_ANONYMIZED_EMAIL
|
||||
|
||||
writer.writerows(
|
||||
qa_pair.to_json()
|
||||
for qa_pair in QuestionAnswerPairSnapshot.from_chat_session_snapshot(
|
||||
snapshot
|
||||
)
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logger.exception(f"Failed to export query history with {task_id=}")
|
||||
mark_task_as_finished_with_id(
|
||||
@@ -73,37 +87,11 @@ def export_query_history_task(self: Task, *, start: datetime, end: datetime) ->
|
||||
)
|
||||
raise
|
||||
|
||||
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED:
|
||||
complete_chat_session_history = [
|
||||
ChatSessionSnapshot(
|
||||
**chat_session_snapshot.model_dump(), user_email=ONYX_ANONYMIZED_EMAIL
|
||||
)
|
||||
for chat_session_snapshot in complete_chat_session_history
|
||||
]
|
||||
|
||||
qa_pairs: list[QuestionAnswerPairSnapshot] = [
|
||||
qa_pair
|
||||
for chat_session_snapshot in complete_chat_session_history
|
||||
for qa_pair in QuestionAnswerPairSnapshot.from_chat_session_snapshot(
|
||||
chat_session_snapshot
|
||||
)
|
||||
]
|
||||
|
||||
stream = io.StringIO()
|
||||
writer = csv.DictWriter(
|
||||
stream,
|
||||
fieldnames=list(QuestionAnswerPairSnapshot.model_fields.keys()),
|
||||
)
|
||||
writer.writeheader()
|
||||
for row in qa_pairs:
|
||||
writer.writerow(row.to_json())
|
||||
|
||||
report_name = construct_query_history_report_name(task_id)
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
try:
|
||||
stream.seek(0)
|
||||
get_default_file_store(db_session).save_file(
|
||||
file_name=report_name,
|
||||
get_default_file_store().save_file(
|
||||
content=stream,
|
||||
display_name=report_name,
|
||||
file_origin=FileOrigin.QUERY_HISTORY_CSV,
|
||||
@@ -113,6 +101,7 @@ def export_query_history_task(self: Task, *, start: datetime, end: datetime) ->
|
||||
"end": end.isoformat(),
|
||||
"start_time": start_time.isoformat(),
|
||||
},
|
||||
file_id=report_name,
|
||||
)
|
||||
|
||||
delete_task_with_id(
|
||||
@@ -133,6 +122,8 @@ def export_query_history_task(self: Task, *, start: datetime, end: datetime) ->
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
[
|
||||
"ee.onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"ee.onyx.background.celery.tasks.external_group_syncing",
|
||||
"ee.onyx.background.celery.tasks.cleanup",
|
||||
]
|
||||
)
|
||||
|
||||
8
backend/ee/onyx/background/celery/apps/light.py
Normal file
8
backend/ee/onyx/background/celery/apps/light.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from onyx.background.celery.apps.light import celery_app
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
[
|
||||
"ee.onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"ee.onyx.background.celery.tasks.external_group_syncing",
|
||||
]
|
||||
)
|
||||
7
backend/ee/onyx/background/celery/apps/monitoring.py
Normal file
7
backend/ee/onyx/background/celery/apps/monitoring.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from onyx.background.celery.apps.monitoring import celery_app
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
[
|
||||
"ee.onyx.background.celery.tasks.tenant_provisioning",
|
||||
]
|
||||
)
|
||||
@@ -1,130 +1,12 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from uuid import UUID
|
||||
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
|
||||
from ee.onyx.background.celery_utils import should_perform_chat_ttl_check
|
||||
from ee.onyx.background.task_name_builders import name_chat_ttl_task
|
||||
from ee.onyx.server.reporting.usage_export_generation import create_new_usage_report
|
||||
from onyx.background.celery.apps.primary import celery_app
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.db.chat import delete_chat_session
|
||||
from onyx.db.chat import get_chat_sessions_older_than
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import TaskStatus
|
||||
from onyx.db.tasks import mark_task_as_finished_with_id
|
||||
from onyx.db.tasks import register_task
|
||||
from onyx.server.settings.store import load_settings
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# mark as EE for all tasks in this file
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.PERFORM_TTL_MANAGEMENT_TASK,
|
||||
ignore_result=True,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
trail=False,
|
||||
celery_app.autodiscover_tasks(
|
||||
[
|
||||
"ee.onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"ee.onyx.background.celery.tasks.external_group_syncing",
|
||||
"ee.onyx.background.celery.tasks.cloud",
|
||||
"ee.onyx.background.celery.tasks.ttl_management",
|
||||
"ee.onyx.background.celery.tasks.usage_reporting",
|
||||
]
|
||||
)
|
||||
def perform_ttl_management_task(
|
||||
self: Task, retention_limit_days: int, *, tenant_id: str
|
||||
) -> None:
|
||||
task_id = self.request.id
|
||||
if not task_id:
|
||||
raise RuntimeError("No task id defined for this task; cannot identify it")
|
||||
|
||||
start_time = datetime.now(tz=timezone.utc)
|
||||
|
||||
user_id: UUID | None = None
|
||||
session_id: UUID | None = None
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
# we generally want to move off this, but keeping for now
|
||||
register_task(
|
||||
db_session=db_session,
|
||||
task_name=name_chat_ttl_task(retention_limit_days, tenant_id),
|
||||
task_id=task_id,
|
||||
status=TaskStatus.STARTED,
|
||||
start_time=start_time,
|
||||
)
|
||||
|
||||
old_chat_sessions = get_chat_sessions_older_than(
|
||||
retention_limit_days, db_session
|
||||
)
|
||||
|
||||
for user_id, session_id in old_chat_sessions:
|
||||
# one session per delete so that we don't blow up if a deletion fails.
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
delete_chat_session(
|
||||
user_id,
|
||||
session_id,
|
||||
db_session,
|
||||
include_deleted=True,
|
||||
hard_delete=True,
|
||||
)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
mark_task_as_finished_with_id(
|
||||
db_session=db_session,
|
||||
task_id=task_id,
|
||||
success=True,
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"delete_chat_session exceptioned. "
|
||||
f"user_id={user_id} session_id={session_id}"
|
||||
)
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
mark_task_as_finished_with_id(
|
||||
db_session=db_session,
|
||||
task_id=task_id,
|
||||
success=False,
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
#####
|
||||
# Periodic Tasks
|
||||
#####
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
name=OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
|
||||
ignore_result=True,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
)
|
||||
def check_ttl_management_task(*, tenant_id: str) -> None:
|
||||
"""Runs periodically to check if any ttl tasks should be run and adds them
|
||||
to the queue"""
|
||||
|
||||
settings = load_settings()
|
||||
retention_limit_days = settings.maximum_chat_retention_days
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
if should_perform_chat_ttl_check(retention_limit_days, db_session):
|
||||
perform_ttl_management_task.apply_async(
|
||||
kwargs=dict(
|
||||
retention_limit_days=retention_limit_days, tenant_id=tenant_id
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
name=OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
|
||||
ignore_result=True,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
)
|
||||
def autogenerate_usage_report_task(*, tenant_id: str) -> None:
|
||||
"""This generates usage report under the /admin/generate-usage/report endpoint"""
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
create_new_usage_report(
|
||||
db_session=db_session,
|
||||
user_id=None,
|
||||
period=None,
|
||||
)
|
||||
|
||||
@@ -20,39 +20,36 @@ from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
ee_beat_system_tasks: list[dict] = []
|
||||
|
||||
ee_beat_task_templates: list[dict] = []
|
||||
ee_beat_task_templates.extend(
|
||||
[
|
||||
{
|
||||
"name": "autogenerate-usage-report",
|
||||
"task": OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
|
||||
"schedule": timedelta(days=30),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
ee_beat_task_templates: list[dict] = [
|
||||
{
|
||||
"name": "autogenerate-usage-report",
|
||||
"task": OnyxCeleryTask.GENERATE_USAGE_REPORT_TASK,
|
||||
"schedule": timedelta(days=30),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
{
|
||||
"name": "check-ttl-management",
|
||||
"task": OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
|
||||
"schedule": timedelta(hours=CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-ttl-management",
|
||||
"task": OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
|
||||
"schedule": timedelta(hours=CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
{
|
||||
"name": "export-query-history-cleanup-task",
|
||||
"task": OnyxCeleryTask.EXPORT_QUERY_HISTORY_CLEANUP_TASK,
|
||||
"schedule": timedelta(hours=1),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
"queue": OnyxCeleryQueues.CSV_GENERATION,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "export-query-history-cleanup-task",
|
||||
"task": OnyxCeleryTask.EXPORT_QUERY_HISTORY_CLEANUP_TASK,
|
||||
"schedule": timedelta(hours=1),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
"queue": OnyxCeleryQueues.CSV_GENERATION,
|
||||
},
|
||||
]
|
||||
)
|
||||
},
|
||||
]
|
||||
|
||||
ee_tasks_to_schedule: list[dict] = []
|
||||
|
||||
@@ -60,7 +57,7 @@ if not MULTI_TENANT:
|
||||
ee_tasks_to_schedule = [
|
||||
{
|
||||
"name": "autogenerate-usage-report",
|
||||
"task": OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
|
||||
"task": OnyxCeleryTask.GENERATE_USAGE_REPORT_TASK,
|
||||
"schedule": timedelta(days=30), # TODO: change this to config flag
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
|
||||
@@ -3,12 +3,12 @@ from datetime import timedelta
|
||||
|
||||
from celery import shared_task
|
||||
|
||||
from ee.onyx.db.query_history import get_all_query_history_export_tasks
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
from onyx.db.enums import TaskStatus
|
||||
from onyx.db.tasks import delete_task_with_id
|
||||
from onyx.db.tasks import get_all_query_history_export_tasks
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
|
||||
104
backend/ee/onyx/background/celery/tasks/cloud/tasks.py
Normal file
104
backend/ee/onyx/background/celery/tasks/cloud/tasks.py
Normal file
@@ -0,0 +1,104 @@
|
||||
import time
|
||||
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from redis.lock import Lock as RedisLock
|
||||
|
||||
from ee.onyx.server.tenants.product_gating import get_gated_tenants
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.tasks.beat_schedule import BEAT_EXPIRES_DEFAULT
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.db.engine.tenant_utils import get_all_tenant_ids
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import redis_lock_dump
|
||||
from shared_configs.configs import IGNORED_SYNCING_TENANT_LIST
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
|
||||
ignore_result=True,
|
||||
trail=False,
|
||||
bind=True,
|
||||
)
|
||||
def cloud_beat_task_generator(
|
||||
self: Task,
|
||||
task_name: str,
|
||||
queue: str = OnyxCeleryTask.DEFAULT,
|
||||
priority: int = OnyxCeleryPriority.MEDIUM,
|
||||
expires: int = BEAT_EXPIRES_DEFAULT,
|
||||
) -> bool | None:
|
||||
"""a lightweight task used to kick off individual beat tasks per tenant."""
|
||||
time_start = time.monotonic()
|
||||
|
||||
redis_client = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
|
||||
|
||||
lock_beat: RedisLock = redis_client.lock(
|
||||
f"{OnyxRedisLocks.CLOUD_BEAT_TASK_GENERATOR_LOCK}:{task_name}",
|
||||
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return None
|
||||
|
||||
last_lock_time = time.monotonic()
|
||||
tenant_ids: list[str] = []
|
||||
num_processed_tenants = 0
|
||||
|
||||
try:
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
gated_tenants = get_gated_tenants()
|
||||
for tenant_id in tenant_ids:
|
||||
if tenant_id in gated_tenants:
|
||||
continue
|
||||
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4):
|
||||
lock_beat.reacquire()
|
||||
last_lock_time = current_time
|
||||
|
||||
# needed in the cloud
|
||||
if IGNORED_SYNCING_TENANT_LIST and tenant_id in IGNORED_SYNCING_TENANT_LIST:
|
||||
continue
|
||||
|
||||
self.app.send_task(
|
||||
task_name,
|
||||
kwargs=dict(
|
||||
tenant_id=tenant_id,
|
||||
),
|
||||
queue=queue,
|
||||
priority=priority,
|
||||
expires=expires,
|
||||
ignore_result=True,
|
||||
)
|
||||
|
||||
num_processed_tenants += 1
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception("Unexpected exception during cloud_beat_task_generator")
|
||||
finally:
|
||||
if not lock_beat.owned():
|
||||
task_logger.error(
|
||||
"cloud_beat_task_generator - Lock not owned on completion"
|
||||
)
|
||||
redis_lock_dump(lock_beat, redis_client)
|
||||
else:
|
||||
lock_beat.release()
|
||||
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
task_logger.info(
|
||||
f"cloud_beat_task_generator finished: "
|
||||
f"task={task_name} "
|
||||
f"num_processed_tenants={num_processed_tenants} "
|
||||
f"num_tenants={len(tenant_ids)} "
|
||||
f"elapsed={time_elapsed:.2f}"
|
||||
)
|
||||
return True
|
||||
@@ -16,22 +16,21 @@ from redis import Redis
|
||||
from redis.exceptions import LockError
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
from tenacity import retry
|
||||
from tenacity import retry_if_exception
|
||||
from tenacity import stop_after_delay
|
||||
from tenacity import wait_random_exponential
|
||||
|
||||
from ee.onyx.configs.app_configs import DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY
|
||||
from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs
|
||||
from ee.onyx.db.document import upsert_document_external_perms
|
||||
from ee.onyx.external_permissions.sync_params import DOC_PERMISSION_SYNC_PERIODS
|
||||
from ee.onyx.external_permissions.sync_params import DOC_PERMISSIONS_FUNC_MAP
|
||||
from ee.onyx.external_permissions.sync_params import (
|
||||
DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION,
|
||||
)
|
||||
from ee.onyx.external_permissions.sync_params import get_source_perm_sync_config
|
||||
from onyx.access.models import DocExternalAccess
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_find_task
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.background.celery.tasks.shared.tasks import OnyxCeleryTaskCompletionStatus
|
||||
from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT
|
||||
@@ -48,8 +47,10 @@ from onyx.connectors.factory import validate_ccpair_for_user
|
||||
from onyx.db.connector import mark_cc_pair_as_permissions_synced
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.document import get_document_ids_for_connector_credential_pair
|
||||
from onyx.db.document import get_documents_for_connector_credential_pair_limited_columns
|
||||
from onyx.db.document import upsert_document_by_connector_credential_pair
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import SyncStatus
|
||||
@@ -58,6 +59,9 @@ from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.sync_record import insert_sync_record
|
||||
from onyx.db.sync_record import update_sync_record_status
|
||||
from onyx.db.users import batch_add_ext_perm_user_if_not_exists
|
||||
from onyx.db.utils import DocumentRow
|
||||
from onyx.db.utils import is_retryable_sqlalchemy_error
|
||||
from onyx.db.utils import SortOrder
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
|
||||
@@ -73,12 +77,14 @@ from onyx.utils.logger import LoggerContextVars
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.telemetry import optional_telemetry
|
||||
from onyx.utils.telemetry import RecordType
|
||||
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
DOCUMENT_PERMISSIONS_UPDATE_MAX_RETRIES = 3
|
||||
DOCUMENT_PERMISSIONS_UPDATE_STOP_AFTER = 10 * 60
|
||||
DOCUMENT_PERMISSIONS_UPDATE_MAX_WAIT = 60
|
||||
|
||||
|
||||
# 5 seconds more than RetryDocumentIndex STOP_AFTER+MAX_WAIT
|
||||
@@ -86,6 +92,24 @@ LIGHT_SOFT_TIME_LIMIT = 105
|
||||
LIGHT_TIME_LIMIT = LIGHT_SOFT_TIME_LIMIT + 15
|
||||
|
||||
|
||||
def _get_fence_validation_block_expiration() -> int:
|
||||
"""
|
||||
Compute the expiration time for the fence validation block signal.
|
||||
Base expiration is 300 seconds, multiplied by the beat multiplier only in MULTI_TENANT mode.
|
||||
"""
|
||||
base_expiration = 300 # seconds
|
||||
|
||||
if not MULTI_TENANT:
|
||||
return base_expiration
|
||||
|
||||
try:
|
||||
beat_multiplier = OnyxRuntime.get_beat_multiplier()
|
||||
except Exception:
|
||||
beat_multiplier = CLOUD_BEAT_MULTIPLIER_DEFAULT
|
||||
|
||||
return int(base_expiration * beat_multiplier)
|
||||
|
||||
|
||||
"""Jobs / utils for kicking off doc permissions sync tasks."""
|
||||
|
||||
|
||||
@@ -99,16 +123,29 @@ def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> b
|
||||
if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
|
||||
return False
|
||||
|
||||
sync_config = get_source_perm_sync_config(cc_pair.connector.source)
|
||||
if sync_config is None:
|
||||
logger.error(f"No sync config found for {cc_pair.connector.source}")
|
||||
return False
|
||||
|
||||
if sync_config.doc_sync_config is None:
|
||||
logger.error(f"No doc sync config found for {cc_pair.connector.source}")
|
||||
return False
|
||||
|
||||
# if indexing also does perm sync, don't start running doc_sync until at
|
||||
# least one indexing is done
|
||||
if (
|
||||
sync_config.doc_sync_config.initial_index_should_sync
|
||||
and cc_pair.last_successful_index_time is None
|
||||
):
|
||||
return False
|
||||
|
||||
# If the last sync is None, it has never been run so we run the sync
|
||||
last_perm_sync = cc_pair.last_time_perm_sync
|
||||
if last_perm_sync is None:
|
||||
return True
|
||||
|
||||
source_sync_period = DOC_PERMISSION_SYNC_PERIODS.get(cc_pair.connector.source)
|
||||
|
||||
if not source_sync_period:
|
||||
source_sync_period = DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY
|
||||
|
||||
source_sync_period = sync_config.doc_sync_config.doc_sync_frequency
|
||||
source_sync_period *= int(OnyxRuntime.get_doc_permission_sync_multiplier())
|
||||
|
||||
# If the last sync is greater than the full fetch period, we run the sync
|
||||
@@ -180,7 +217,11 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str) -> bool | None
|
||||
"Exception while validating permission sync fences"
|
||||
)
|
||||
|
||||
r.set(OnyxRedisSignals.BLOCK_VALIDATE_PERMISSION_SYNC_FENCES, 1, ex=300)
|
||||
r.set(
|
||||
OnyxRedisSignals.BLOCK_VALIDATE_PERMISSION_SYNC_FENCES,
|
||||
1,
|
||||
ex=_get_fence_validation_block_expiration(),
|
||||
)
|
||||
|
||||
# use a lookup table to find active fences. We still have to verify the fence
|
||||
# exists since it is an optimization and not the source of truth.
|
||||
@@ -384,7 +425,7 @@ def connector_permission_sync_generator_task(
|
||||
|
||||
lock: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX
|
||||
+ f"_{redis_connector.id}",
|
||||
+ f"_{redis_connector.cc_pair_id}",
|
||||
timeout=CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT,
|
||||
thread_local=False,
|
||||
)
|
||||
@@ -411,6 +452,7 @@ def connector_permission_sync_generator_task(
|
||||
created = validate_ccpair_for_user(
|
||||
cc_pair.connector.id,
|
||||
cc_pair.credential.id,
|
||||
cc_pair.access_type,
|
||||
db_session,
|
||||
enforce_creation=False,
|
||||
)
|
||||
@@ -426,11 +468,15 @@ def connector_permission_sync_generator_task(
|
||||
raise
|
||||
|
||||
source_type = cc_pair.connector.source
|
||||
sync_config = get_source_perm_sync_config(source_type)
|
||||
if sync_config is None:
|
||||
logger.error(f"No sync config found for {source_type}")
|
||||
return None
|
||||
|
||||
doc_sync_func = DOC_PERMISSIONS_FUNC_MAP.get(source_type)
|
||||
if doc_sync_func is None:
|
||||
if source_type in DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION:
|
||||
if sync_config.doc_sync_config is None:
|
||||
if sync_config.censoring_config:
|
||||
return None
|
||||
|
||||
raise ValueError(
|
||||
f"No doc sync func found for {source_type} with cc_pair={cc_pair_id}"
|
||||
)
|
||||
@@ -455,15 +501,31 @@ def connector_permission_sync_generator_task(
|
||||
# this is can be used to determine documents that are "missing" and thus
|
||||
# should no longer be accessible. The decision as to whether we should find
|
||||
# every document during the doc sync process is connector-specific.
|
||||
def fetch_all_existing_docs_fn() -> list[str]:
|
||||
return get_document_ids_for_connector_credential_pair(
|
||||
def fetch_all_existing_docs_fn(
|
||||
sort_order: SortOrder | None = None,
|
||||
) -> list[DocumentRow]:
|
||||
result = get_documents_for_connector_credential_pair_limited_columns(
|
||||
db_session=db_session,
|
||||
connector_id=cc_pair.connector.id,
|
||||
credential_id=cc_pair.credential.id,
|
||||
sort_order=sort_order,
|
||||
)
|
||||
return list(result)
|
||||
|
||||
def fetch_all_existing_docs_ids_fn() -> list[str]:
|
||||
result = get_document_ids_for_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=cc_pair.connector.id,
|
||||
credential_id=cc_pair.credential.id,
|
||||
)
|
||||
return result
|
||||
|
||||
doc_sync_func = sync_config.doc_sync_config.doc_sync_func
|
||||
document_external_accesses = doc_sync_func(
|
||||
cc_pair, fetch_all_existing_docs_fn, callback
|
||||
cc_pair,
|
||||
fetch_all_existing_docs_fn,
|
||||
fetch_all_existing_docs_ids_fn,
|
||||
callback,
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
@@ -472,13 +534,13 @@ def connector_permission_sync_generator_task(
|
||||
|
||||
tasks_generated = 0
|
||||
for doc_external_access in document_external_accesses:
|
||||
redis_connector.permissions.generate_tasks(
|
||||
celery_app=self.app,
|
||||
redis_connector.permissions.update_db(
|
||||
lock=lock,
|
||||
new_permissions=[doc_external_access],
|
||||
source_string=source_type,
|
||||
connector_id=cc_pair.connector.id,
|
||||
credential_id=cc_pair.credential.id,
|
||||
task_logger=task_logger,
|
||||
)
|
||||
tasks_generated += 1
|
||||
|
||||
@@ -491,6 +553,7 @@ def connector_permission_sync_generator_task(
|
||||
|
||||
except Exception as e:
|
||||
error_msg = format_error_for_logging(e)
|
||||
|
||||
task_logger.warning(
|
||||
f"Permission sync exceptioned: cc_pair={cc_pair_id} payload_id={payload_id} {error_msg}"
|
||||
)
|
||||
@@ -511,33 +574,28 @@ def connector_permission_sync_generator_task(
|
||||
)
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.UPDATE_EXTERNAL_DOCUMENT_PERMISSIONS_TASK,
|
||||
soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
|
||||
time_limit=LIGHT_TIME_LIMIT,
|
||||
max_retries=DOCUMENT_PERMISSIONS_UPDATE_MAX_RETRIES,
|
||||
bind=True,
|
||||
# NOTE(rkuo): this should probably move to the db layer
|
||||
@retry(
|
||||
retry=retry_if_exception(is_retryable_sqlalchemy_error),
|
||||
wait=wait_random_exponential(
|
||||
multiplier=1, max=DOCUMENT_PERMISSIONS_UPDATE_MAX_WAIT
|
||||
),
|
||||
stop=stop_after_delay(DOCUMENT_PERMISSIONS_UPDATE_STOP_AFTER),
|
||||
)
|
||||
def update_external_document_permissions_task(
|
||||
self: Task,
|
||||
def document_update_permissions(
|
||||
tenant_id: str,
|
||||
serialized_doc_external_access: dict,
|
||||
source_string: str,
|
||||
permissions: DocExternalAccess,
|
||||
source_type_str: str,
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
) -> bool:
|
||||
start = time.monotonic()
|
||||
|
||||
completion_status = OnyxCeleryTaskCompletionStatus.UNDEFINED
|
||||
|
||||
document_external_access = DocExternalAccess.from_dict(
|
||||
serialized_doc_external_access
|
||||
)
|
||||
doc_id = document_external_access.doc_id
|
||||
external_access = document_external_access.external_access
|
||||
doc_id = permissions.doc_id
|
||||
external_access = permissions.external_access
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
# Add the users to the DB if they don't exist
|
||||
batch_add_ext_perm_user_if_not_exists(
|
||||
db_session=db_session,
|
||||
@@ -549,7 +607,7 @@ def update_external_document_permissions_task(
|
||||
db_session=db_session,
|
||||
doc_id=doc_id,
|
||||
external_access=external_access,
|
||||
source_type=DocumentSource(source_string),
|
||||
source_type=DocumentSource(source_type_str),
|
||||
)
|
||||
|
||||
if created_new_doc:
|
||||
@@ -568,29 +626,17 @@ def update_external_document_permissions_task(
|
||||
f"action=update_permissions "
|
||||
f"elapsed={elapsed:.2f}"
|
||||
)
|
||||
|
||||
completion_status = OnyxCeleryTaskCompletionStatus.SUCCEEDED
|
||||
except Exception as e:
|
||||
error_msg = format_error_for_logging(e)
|
||||
task_logger.warning(
|
||||
f"Exception in update_external_document_permissions_task: connector_id={connector_id} doc_id={doc_id} {error_msg}"
|
||||
)
|
||||
task_logger.exception(
|
||||
f"update_external_document_permissions_task exceptioned: "
|
||||
f"document_update_permissions exceptioned: "
|
||||
f"connector_id={connector_id} doc_id={doc_id}"
|
||||
)
|
||||
completion_status = OnyxCeleryTaskCompletionStatus.NON_RETRYABLE_EXCEPTION
|
||||
raise e
|
||||
finally:
|
||||
task_logger.info(
|
||||
f"update_external_document_permissions_task completed: status={completion_status.value} doc={doc_id}"
|
||||
f"document_update_permissions completed: connector_id={connector_id} doc={doc_id}"
|
||||
)
|
||||
|
||||
if completion_status != OnyxCeleryTaskCompletionStatus.SUCCEEDED:
|
||||
return False
|
||||
|
||||
task_logger.info(
|
||||
f"update_external_document_permissions_task finished: connector_id={connector_id} doc_id={doc_id}"
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
@@ -14,21 +14,23 @@ from pydantic import ValidationError
|
||||
from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
|
||||
from ee.onyx.background.celery.tasks.external_group_syncing.group_sync_utils import (
|
||||
mark_all_relevant_cc_pairs_as_external_group_synced,
|
||||
)
|
||||
from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs
|
||||
from ee.onyx.db.connector_credential_pair import get_cc_pairs_by_source
|
||||
from ee.onyx.db.external_perm import ExternalUserGroup
|
||||
from ee.onyx.db.external_perm import replace_user__ext_group_for_cc_pair
|
||||
from ee.onyx.external_permissions.sync_params import EXTERNAL_GROUP_SYNC_PERIODS
|
||||
from ee.onyx.external_permissions.sync_params import GROUP_PERMISSIONS_FUNC_MAP
|
||||
from ee.onyx.db.external_perm import mark_old_external_groups_as_stale
|
||||
from ee.onyx.db.external_perm import remove_stale_external_groups
|
||||
from ee.onyx.db.external_perm import upsert_external_groups
|
||||
from ee.onyx.external_permissions.sync_params import (
|
||||
GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC,
|
||||
get_all_cc_pair_agnostic_group_sync_sources,
|
||||
)
|
||||
from ee.onyx.external_permissions.sync_params import get_source_perm_sync_config
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_find_task
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.background.celery.tasks.external_group_syncing.group_sync_utils import (
|
||||
mark_all_relevant_cc_pairs_as_external_group_synced,
|
||||
)
|
||||
from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT
|
||||
from onyx.background.error_logging import emit_background_error
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT
|
||||
@@ -40,9 +42,8 @@ from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisConstants
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.configs.constants import OnyxRedisSignals
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import SyncStatus
|
||||
@@ -57,19 +58,34 @@ from onyx.redis.redis_connector_ext_group_sync import (
|
||||
)
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import get_redis_replica_client
|
||||
from onyx.server.runtime.onyx_runtime import OnyxRuntime
|
||||
from onyx.server.utils import make_short_id
|
||||
from onyx.utils.logger import format_error_for_logging
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
EXTERNAL_GROUPS_UPDATE_MAX_RETRIES = 3
|
||||
_EXTERNAL_GROUP_BATCH_SIZE = 100
|
||||
|
||||
|
||||
# 5 seconds more than RetryDocumentIndex STOP_AFTER+MAX_WAIT
|
||||
LIGHT_SOFT_TIME_LIMIT = 105
|
||||
LIGHT_TIME_LIMIT = LIGHT_SOFT_TIME_LIMIT + 15
|
||||
def _get_fence_validation_block_expiration() -> int:
|
||||
"""
|
||||
Compute the expiration time for the fence validation block signal.
|
||||
Base expiration is 300 seconds, multiplied by the beat multiplier only in MULTI_TENANT mode.
|
||||
"""
|
||||
base_expiration = 300 # seconds
|
||||
|
||||
if not MULTI_TENANT:
|
||||
return base_expiration
|
||||
|
||||
try:
|
||||
beat_multiplier = OnyxRuntime.get_beat_multiplier()
|
||||
except Exception:
|
||||
beat_multiplier = CLOUD_BEAT_MULTIPLIER_DEFAULT
|
||||
|
||||
return int(base_expiration * beat_multiplier)
|
||||
|
||||
|
||||
def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
@@ -89,12 +105,20 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
)
|
||||
return False
|
||||
|
||||
# If there is not group sync function for the connector, we don't run the sync
|
||||
# This is fine because all sources dont necessarily have a concept of groups
|
||||
if not GROUP_PERMISSIONS_FUNC_MAP.get(cc_pair.connector.source):
|
||||
sync_config = get_source_perm_sync_config(cc_pair.connector.source)
|
||||
if sync_config is None:
|
||||
task_logger.debug(
|
||||
f"Skipping group sync for CC Pair {cc_pair.id} - "
|
||||
f"no group sync function for {cc_pair.connector.source}"
|
||||
f"no sync config found for {cc_pair.connector.source}"
|
||||
)
|
||||
return False
|
||||
|
||||
# If there is not group sync function for the connector, we don't run the sync
|
||||
# This is fine because all sources dont necessarily have a concept of groups
|
||||
if sync_config.group_sync_config is None:
|
||||
task_logger.debug(
|
||||
f"Skipping group sync for CC Pair {cc_pair.id} - "
|
||||
f"no group sync config found for {cc_pair.connector.source}"
|
||||
)
|
||||
return False
|
||||
|
||||
@@ -103,11 +127,7 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
if last_ext_group_sync is None:
|
||||
return True
|
||||
|
||||
source_sync_period = EXTERNAL_GROUP_SYNC_PERIODS.get(cc_pair.connector.source)
|
||||
|
||||
# If EXTERNAL_GROUP_SYNC_PERIODS is None, we always run the sync.
|
||||
if not source_sync_period:
|
||||
return True
|
||||
source_sync_period = sync_config.group_sync_config.group_sync_frequency
|
||||
|
||||
# If the last sync is greater than the full fetch period, we run the sync
|
||||
next_sync = last_ext_group_sync + timedelta(seconds=source_sync_period)
|
||||
@@ -147,9 +167,8 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str) -> bool | None:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
|
||||
|
||||
# We only want to sync one cc_pair per source type in
|
||||
# GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC
|
||||
for source in GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC:
|
||||
# For some sources, we only want to sync one cc_pair per source type
|
||||
for source in get_all_cc_pair_agnostic_group_sync_sources():
|
||||
# These are ordered by cc_pair id so the first one is the one we want
|
||||
cc_pairs_to_dedupe = get_cc_pairs_by_source(
|
||||
db_session,
|
||||
@@ -157,8 +176,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str) -> bool | None:
|
||||
access_type=AccessType.SYNC,
|
||||
status=ConnectorCredentialPairStatus.ACTIVE,
|
||||
)
|
||||
# We only want to sync one cc_pair per source type
|
||||
# in GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC so we dedupe here
|
||||
# dedupe cc_pairs to only keep the first one
|
||||
for cc_pair_to_remove in cc_pairs_to_dedupe[1:]:
|
||||
cc_pairs = [
|
||||
cc_pair
|
||||
@@ -197,7 +215,11 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str) -> bool | None:
|
||||
"Exception while validating external group sync fences"
|
||||
)
|
||||
|
||||
r.set(OnyxRedisSignals.BLOCK_VALIDATE_EXTERNAL_GROUP_SYNC_FENCES, 1, ex=300)
|
||||
r.set(
|
||||
OnyxRedisSignals.BLOCK_VALIDATE_EXTERNAL_GROUP_SYNC_FENCES,
|
||||
1,
|
||||
ex=_get_fence_validation_block_expiration(),
|
||||
)
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
@@ -361,7 +383,7 @@ def connector_external_group_sync_generator_task(
|
||||
|
||||
lock: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CONNECTOR_EXTERNAL_GROUP_SYNC_LOCK_PREFIX
|
||||
+ f"_{redis_connector.id}",
|
||||
+ f"_{redis_connector.cc_pair_id}",
|
||||
timeout=CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
@@ -376,55 +398,12 @@ def connector_external_group_sync_generator_task(
|
||||
payload.started = datetime.now(timezone.utc)
|
||||
redis_connector.external_group_sync.set_fence(payload)
|
||||
|
||||
_perform_external_group_sync(
|
||||
cc_pair_id=cc_pair_id,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
eager_load_credential=True,
|
||||
)
|
||||
if cc_pair is None:
|
||||
raise ValueError(
|
||||
f"No connector credential pair found for id: {cc_pair_id}"
|
||||
)
|
||||
|
||||
source_type = cc_pair.connector.source
|
||||
|
||||
ext_group_sync_func = GROUP_PERMISSIONS_FUNC_MAP.get(source_type)
|
||||
if ext_group_sync_func is None:
|
||||
msg = f"No external group sync func found for {source_type} for cc_pair: {cc_pair_id}"
|
||||
emit_background_error(msg, cc_pair_id=cc_pair_id)
|
||||
raise ValueError(msg)
|
||||
|
||||
logger.info(
|
||||
f"Syncing external groups for {source_type} for cc_pair: {cc_pair_id}"
|
||||
)
|
||||
external_user_groups: list[ExternalUserGroup] = []
|
||||
try:
|
||||
external_user_groups = ext_group_sync_func(tenant_id, cc_pair)
|
||||
except ConnectorValidationError as e:
|
||||
# TODO: add some notification to the admins here
|
||||
logger.exception(
|
||||
f"Error syncing external groups for {source_type} for cc_pair: {cc_pair_id} {e}"
|
||||
)
|
||||
raise e
|
||||
|
||||
logger.info(
|
||||
f"Syncing {len(external_user_groups)} external user groups for {source_type}"
|
||||
)
|
||||
logger.debug(f"New external user groups: {external_user_groups}")
|
||||
|
||||
replace_user__ext_group_for_cc_pair(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair.id,
|
||||
group_defs=external_user_groups,
|
||||
source=cc_pair.connector.source,
|
||||
)
|
||||
logger.info(
|
||||
f"Synced {len(external_user_groups)} external user groups for {source_type}"
|
||||
)
|
||||
|
||||
mark_all_relevant_cc_pairs_as_external_group_synced(db_session, cc_pair)
|
||||
|
||||
update_sync_record_status(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
@@ -466,6 +445,81 @@ def connector_external_group_sync_generator_task(
|
||||
)
|
||||
|
||||
|
||||
def _perform_external_group_sync(
|
||||
cc_pair_id: int,
|
||||
tenant_id: str,
|
||||
) -> None:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
eager_load_credential=True,
|
||||
)
|
||||
if cc_pair is None:
|
||||
raise ValueError(f"No connector credential pair found for id: {cc_pair_id}")
|
||||
|
||||
source_type = cc_pair.connector.source
|
||||
sync_config = get_source_perm_sync_config(source_type)
|
||||
if sync_config is None:
|
||||
msg = f"No sync config found for {source_type} for cc_pair: {cc_pair_id}"
|
||||
emit_background_error(msg, cc_pair_id=cc_pair_id)
|
||||
raise ValueError(msg)
|
||||
|
||||
if sync_config.group_sync_config is None:
|
||||
msg = f"No group sync config found for {source_type} for cc_pair: {cc_pair_id}"
|
||||
emit_background_error(msg, cc_pair_id=cc_pair_id)
|
||||
raise ValueError(msg)
|
||||
|
||||
ext_group_sync_func = sync_config.group_sync_config.group_sync_func
|
||||
|
||||
logger.info(
|
||||
f"Marking old external groups as stale for {source_type} for cc_pair: {cc_pair_id}"
|
||||
)
|
||||
mark_old_external_groups_as_stale(db_session, cc_pair_id)
|
||||
|
||||
logger.info(
|
||||
f"Syncing external groups for {source_type} for cc_pair: {cc_pair_id}"
|
||||
)
|
||||
external_user_group_batch: list[ExternalUserGroup] = []
|
||||
try:
|
||||
external_user_group_generator = ext_group_sync_func(tenant_id, cc_pair)
|
||||
for external_user_group in external_user_group_generator:
|
||||
external_user_group_batch.append(external_user_group)
|
||||
if len(external_user_group_batch) >= _EXTERNAL_GROUP_BATCH_SIZE:
|
||||
logger.debug(
|
||||
f"New external user groups: {external_user_group_batch}"
|
||||
)
|
||||
upsert_external_groups(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
external_groups=external_user_group_batch,
|
||||
source=cc_pair.connector.source,
|
||||
)
|
||||
external_user_group_batch = []
|
||||
|
||||
if external_user_group_batch:
|
||||
logger.debug(f"New external user groups: {external_user_group_batch}")
|
||||
upsert_external_groups(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
external_groups=external_user_group_batch,
|
||||
source=cc_pair.connector.source,
|
||||
)
|
||||
except Exception as e:
|
||||
# TODO: add some notification to the admins here
|
||||
logger.exception(
|
||||
f"Error syncing external groups for {source_type} for cc_pair: {cc_pair_id} {e}"
|
||||
)
|
||||
raise e
|
||||
|
||||
logger.info(
|
||||
f"Removing stale external groups for {source_type} for cc_pair: {cc_pair_id}"
|
||||
)
|
||||
remove_stale_external_groups(db_session, cc_pair_id)
|
||||
|
||||
mark_all_relevant_cc_pairs_as_external_group_synced(db_session, cc_pair)
|
||||
|
||||
|
||||
def validate_external_group_sync_fences(
|
||||
tenant_id: str,
|
||||
celery_app: Celery,
|
||||
@@ -19,7 +19,7 @@ from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.db.engine import get_session_with_shared_schema
|
||||
from onyx.db.engine.sql_engine import get_session_with_shared_schema
|
||||
from onyx.db.models import AvailableTenant
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
106
backend/ee/onyx/background/celery/tasks/ttl_management/tasks.py
Normal file
106
backend/ee/onyx/background/celery/tasks/ttl_management/tasks.py
Normal file
@@ -0,0 +1,106 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from uuid import UUID
|
||||
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
|
||||
from ee.onyx.background.celery_utils import should_perform_chat_ttl_check
|
||||
from ee.onyx.background.task_name_builders import name_chat_ttl_task
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.db.chat import delete_chat_session
|
||||
from onyx.db.chat import get_chat_sessions_older_than
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import TaskStatus
|
||||
from onyx.db.tasks import mark_task_as_finished_with_id
|
||||
from onyx.db.tasks import register_task
|
||||
from onyx.server.settings.store import load_settings
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.PERFORM_TTL_MANAGEMENT_TASK,
|
||||
ignore_result=True,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
trail=False,
|
||||
)
|
||||
def perform_ttl_management_task(
|
||||
self: Task, retention_limit_days: int, *, tenant_id: str
|
||||
) -> None:
|
||||
task_id = self.request.id
|
||||
if not task_id:
|
||||
raise RuntimeError("No task id defined for this task; cannot identify it")
|
||||
|
||||
start_time = datetime.now(tz=timezone.utc)
|
||||
|
||||
user_id: UUID | None = None
|
||||
session_id: UUID | None = None
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
# we generally want to move off this, but keeping for now
|
||||
register_task(
|
||||
db_session=db_session,
|
||||
task_name=name_chat_ttl_task(retention_limit_days, tenant_id),
|
||||
task_id=task_id,
|
||||
status=TaskStatus.STARTED,
|
||||
start_time=start_time,
|
||||
)
|
||||
|
||||
old_chat_sessions = get_chat_sessions_older_than(
|
||||
retention_limit_days, db_session
|
||||
)
|
||||
|
||||
for user_id, session_id in old_chat_sessions:
|
||||
# one session per delete so that we don't blow up if a deletion fails.
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
delete_chat_session(
|
||||
user_id,
|
||||
session_id,
|
||||
db_session,
|
||||
include_deleted=True,
|
||||
hard_delete=True,
|
||||
)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
mark_task_as_finished_with_id(
|
||||
db_session=db_session,
|
||||
task_id=task_id,
|
||||
success=True,
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"delete_chat_session exceptioned. "
|
||||
f"user_id={user_id} session_id={session_id}"
|
||||
)
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
mark_task_as_finished_with_id(
|
||||
db_session=db_session,
|
||||
task_id=task_id,
|
||||
success=False,
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
|
||||
ignore_result=True,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
)
|
||||
def check_ttl_management_task(*, tenant_id: str) -> None:
|
||||
"""Runs periodically to check if any ttl tasks should be run and adds them
|
||||
to the queue"""
|
||||
|
||||
settings = load_settings()
|
||||
retention_limit_days = settings.maximum_chat_retention_days
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
if should_perform_chat_ttl_check(retention_limit_days, db_session):
|
||||
perform_ttl_management_task.apply_async(
|
||||
kwargs=dict(
|
||||
retention_limit_days=retention_limit_days, tenant_id=tenant_id
|
||||
),
|
||||
)
|
||||
@@ -0,0 +1,46 @@
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
|
||||
from ee.onyx.server.reporting.usage_export_generation import create_new_usage_report
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.GENERATE_USAGE_REPORT_TASK,
|
||||
ignore_result=True,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
trail=False,
|
||||
)
|
||||
def generate_usage_report_task(
|
||||
self: Task,
|
||||
*,
|
||||
tenant_id: str,
|
||||
user_id: str | None = None,
|
||||
period_from: str | None = None,
|
||||
period_to: str | None = None,
|
||||
) -> None:
|
||||
"""User-initiated usage report generation task"""
|
||||
# Parse period if provided
|
||||
period = None
|
||||
if period_from and period_to:
|
||||
period = (
|
||||
datetime.fromisoformat(period_from),
|
||||
datetime.fromisoformat(period_to),
|
||||
)
|
||||
|
||||
# Generate the report
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
create_new_usage_report(
|
||||
db_session=db_session,
|
||||
user_id=UUID(user_id) if user_id else None,
|
||||
period=period,
|
||||
)
|
||||
@@ -1,38 +0,0 @@
|
||||
from ee.onyx.server.query_and_chat.models import OneShotQAResponse
|
||||
from onyx.chat.models import AllCitations
|
||||
from onyx.chat.models import LLMRelevanceFilterResponse
|
||||
from onyx.chat.models import OnyxAnswerPiece
|
||||
from onyx.chat.models import QADocsResponse
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.chat.process_message import ChatPacketStream
|
||||
from onyx.server.query_and_chat.models import ChatMessageDetail
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
|
||||
@log_function_time()
|
||||
def gather_stream_for_answer_api(
|
||||
packets: ChatPacketStream,
|
||||
) -> OneShotQAResponse:
|
||||
response = OneShotQAResponse()
|
||||
|
||||
answer = ""
|
||||
for packet in packets:
|
||||
if isinstance(packet, OnyxAnswerPiece) and packet.answer_piece:
|
||||
answer += packet.answer_piece
|
||||
elif isinstance(packet, QADocsResponse):
|
||||
response.docs = packet
|
||||
# Extraneous, provided for backwards compatibility
|
||||
response.rephrase = packet.rephrased_query
|
||||
elif isinstance(packet, StreamingError):
|
||||
response.error_msg = packet.error
|
||||
elif isinstance(packet, ChatMessageDetail):
|
||||
response.chat_message_id = packet.message_id
|
||||
elif isinstance(packet, LLMRelevanceFilterResponse):
|
||||
response.llm_selected_doc_indices = packet.llm_selected_doc_indices
|
||||
elif isinstance(packet, AllCitations):
|
||||
response.citations = packet.citations
|
||||
|
||||
if answer:
|
||||
response.answer = answer
|
||||
|
||||
return response
|
||||
@@ -53,6 +53,16 @@ CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC = (
|
||||
)
|
||||
|
||||
|
||||
#####
|
||||
# JIRA
|
||||
#####
|
||||
|
||||
# In seconds, default is 30 minutes
|
||||
JIRA_PERMISSION_DOC_SYNC_FREQUENCY = int(
|
||||
os.environ.get("JIRA_PERMISSION_DOC_SYNC_FREQUENCY") or 30 * 60
|
||||
)
|
||||
|
||||
|
||||
#####
|
||||
# Google Drive
|
||||
#####
|
||||
@@ -61,6 +71,19 @@ GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY = int(
|
||||
)
|
||||
|
||||
|
||||
#####
|
||||
# GitHub
|
||||
#####
|
||||
# In seconds, default is 5 minutes
|
||||
GITHUB_PERMISSION_DOC_SYNC_FREQUENCY = int(
|
||||
os.environ.get("GITHUB_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
|
||||
)
|
||||
# In seconds, default is 5 minutes
|
||||
GITHUB_PERMISSION_GROUP_SYNC_FREQUENCY = int(
|
||||
os.environ.get("GITHUB_PERMISSION_GROUP_SYNC_FREQUENCY") or 5 * 60
|
||||
)
|
||||
|
||||
|
||||
#####
|
||||
# Slack
|
||||
#####
|
||||
@@ -71,6 +94,28 @@ SLACK_PERMISSION_DOC_SYNC_FREQUENCY = int(
|
||||
NUM_PERMISSION_WORKERS = int(os.environ.get("NUM_PERMISSION_WORKERS") or 2)
|
||||
|
||||
|
||||
#####
|
||||
# Teams
|
||||
#####
|
||||
# In seconds, default is 5 minutes
|
||||
TEAMS_PERMISSION_DOC_SYNC_FREQUENCY = int(
|
||||
os.environ.get("TEAMS_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
|
||||
)
|
||||
|
||||
#####
|
||||
# SharePoint
|
||||
#####
|
||||
# In seconds, default is 30 minutes
|
||||
SHAREPOINT_PERMISSION_DOC_SYNC_FREQUENCY = int(
|
||||
os.environ.get("SHAREPOINT_PERMISSION_DOC_SYNC_FREQUENCY") or 30 * 60
|
||||
)
|
||||
|
||||
# In seconds, default is 5 minutes
|
||||
SHAREPOINT_PERMISSION_GROUP_SYNC_FREQUENCY = int(
|
||||
os.environ.get("SHAREPOINT_PERMISSION_GROUP_SYNC_FREQUENCY") or 5 * 60
|
||||
)
|
||||
|
||||
|
||||
####
|
||||
# Celery Job Frequency
|
||||
####
|
||||
@@ -94,21 +139,6 @@ JWT_PUBLIC_KEY_URL: str | None = os.getenv("JWT_PUBLIC_KEY_URL", None)
|
||||
SUPER_USERS = json.loads(os.environ.get("SUPER_USERS", "[]"))
|
||||
SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key")
|
||||
|
||||
OAUTH_SLACK_CLIENT_ID = os.environ.get("OAUTH_SLACK_CLIENT_ID", "")
|
||||
OAUTH_SLACK_CLIENT_SECRET = os.environ.get("OAUTH_SLACK_CLIENT_SECRET", "")
|
||||
OAUTH_CONFLUENCE_CLOUD_CLIENT_ID = os.environ.get(
|
||||
"OAUTH_CONFLUENCE_CLOUD_CLIENT_ID", ""
|
||||
)
|
||||
OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET = os.environ.get(
|
||||
"OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET", ""
|
||||
)
|
||||
OAUTH_JIRA_CLOUD_CLIENT_ID = os.environ.get("OAUTH_JIRA_CLOUD_CLIENT_ID", "")
|
||||
OAUTH_JIRA_CLOUD_CLIENT_SECRET = os.environ.get("OAUTH_JIRA_CLOUD_CLIENT_SECRET", "")
|
||||
OAUTH_GOOGLE_DRIVE_CLIENT_ID = os.environ.get("OAUTH_GOOGLE_DRIVE_CLIENT_ID", "")
|
||||
OAUTH_GOOGLE_DRIVE_CLIENT_SECRET = os.environ.get(
|
||||
"OAUTH_GOOGLE_DRIVE_CLIENT_SECRET", ""
|
||||
)
|
||||
|
||||
# The posthog client does not accept empty API keys or hosts however it fails silently
|
||||
# when the capture is called. These defaults prevent Posthog issues from breaking the Onyx app
|
||||
POSTHOG_API_KEY = os.environ.get("POSTHOG_API_KEY") or "FooBar"
|
||||
@@ -116,6 +146,4 @@ POSTHOG_HOST = os.environ.get("POSTHOG_HOST") or "https://us.i.posthog.com"
|
||||
|
||||
HUBSPOT_TRACKING_URL = os.environ.get("HUBSPOT_TRACKING_URL")
|
||||
|
||||
ANONYMOUS_USER_COOKIE_NAME = "onyx_anonymous_user"
|
||||
|
||||
GATED_TENANTS_KEY = "gated_tenants"
|
||||
|
||||
28
backend/ee/onyx/connectors/perm_sync_valid.py
Normal file
28
backend/ee/onyx/connectors/perm_sync_valid.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from onyx.connectors.confluence.connector import ConfluenceConnector
|
||||
from onyx.connectors.google_drive.connector import GoogleDriveConnector
|
||||
from onyx.connectors.interfaces import BaseConnector
|
||||
|
||||
|
||||
def validate_confluence_perm_sync(connector: ConfluenceConnector) -> None:
|
||||
"""
|
||||
Validate that the connector is configured correctly for permissions syncing.
|
||||
"""
|
||||
|
||||
|
||||
def validate_drive_perm_sync(connector: GoogleDriveConnector) -> None:
|
||||
"""
|
||||
Validate that the connector is configured correctly for permissions syncing.
|
||||
"""
|
||||
|
||||
|
||||
def validate_perm_sync(connector: BaseConnector) -> None:
|
||||
"""
|
||||
Override this if your connector needs to validate permissions syncing.
|
||||
Raise an exception if invalid, otherwise do nothing.
|
||||
|
||||
Default is a no-op (always successful).
|
||||
"""
|
||||
if isinstance(connector, ConfluenceConnector):
|
||||
validate_confluence_perm_sync(connector)
|
||||
elif isinstance(connector, GoogleDriveConnector):
|
||||
validate_drive_perm_sync(connector)
|
||||
@@ -4,6 +4,7 @@ from uuid import UUID
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.access.utils import build_ext_group_name_for_onyx
|
||||
@@ -62,20 +63,41 @@ def delete_public_external_group_for_cc_pair__no_commit(
|
||||
)
|
||||
|
||||
|
||||
def replace_user__ext_group_for_cc_pair(
|
||||
def mark_old_external_groups_as_stale(
|
||||
db_session: Session,
|
||||
cc_pair_id: int,
|
||||
group_defs: list[ExternalUserGroup],
|
||||
) -> None:
|
||||
db_session.execute(
|
||||
update(User__ExternalUserGroupId)
|
||||
.where(User__ExternalUserGroupId.cc_pair_id == cc_pair_id)
|
||||
.values(stale=True)
|
||||
)
|
||||
db_session.execute(
|
||||
update(PublicExternalUserGroup)
|
||||
.where(PublicExternalUserGroup.cc_pair_id == cc_pair_id)
|
||||
.values(stale=True)
|
||||
)
|
||||
|
||||
|
||||
def upsert_external_groups(
|
||||
db_session: Session,
|
||||
cc_pair_id: int,
|
||||
external_groups: list[ExternalUserGroup],
|
||||
source: DocumentSource,
|
||||
) -> None:
|
||||
"""
|
||||
This function clears all existing external user group relations for a given cc_pair_id
|
||||
and replaces them with the new group definitions and commits the changes.
|
||||
Performs a true upsert operation for external user groups:
|
||||
- For existing groups (same user_id, external_user_group_id, cc_pair_id), updates the stale flag to False
|
||||
- For new groups, inserts them with stale=False
|
||||
- For public groups, uses upsert logic as well
|
||||
"""
|
||||
# If there are no groups to add, return early
|
||||
if not external_groups:
|
||||
return
|
||||
|
||||
# collect all emails from all groups to batch add all users at once for efficiency
|
||||
all_group_member_emails = set()
|
||||
for external_group in group_defs:
|
||||
for external_group in external_groups:
|
||||
for user_email in external_group.user_emails:
|
||||
all_group_member_emails.add(user_email)
|
||||
|
||||
@@ -86,26 +108,17 @@ def replace_user__ext_group_for_cc_pair(
|
||||
emails=list(all_group_member_emails),
|
||||
)
|
||||
|
||||
delete_user__ext_group_for_cc_pair__no_commit(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
delete_public_external_group_for_cc_pair__no_commit(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
|
||||
# map emails to ids
|
||||
email_id_map = {user.email: user.id for user in all_group_members}
|
||||
email_id_map = {user.email.lower(): user.id for user in all_group_members}
|
||||
|
||||
# use these ids to create new external user group relations relating group_id to user_ids
|
||||
new_external_permissions: list[User__ExternalUserGroupId] = []
|
||||
new_public_external_groups: list[PublicExternalUserGroup] = []
|
||||
for external_group in group_defs:
|
||||
# Process each external group
|
||||
for external_group in external_groups:
|
||||
external_group_id = build_ext_group_name_for_onyx(
|
||||
ext_group_name=external_group.id,
|
||||
source=source,
|
||||
)
|
||||
|
||||
# Handle user-group mappings
|
||||
for user_email in external_group.user_emails:
|
||||
user_id = email_id_map.get(user_email.lower())
|
||||
if user_id is None:
|
||||
@@ -114,24 +127,71 @@ def replace_user__ext_group_for_cc_pair(
|
||||
f" with email {user_email} not found"
|
||||
)
|
||||
continue
|
||||
new_external_permissions.append(
|
||||
User__ExternalUserGroupId(
|
||||
|
||||
# Check if the user-group mapping already exists
|
||||
existing_user_group = db_session.scalar(
|
||||
select(User__ExternalUserGroupId).where(
|
||||
User__ExternalUserGroupId.user_id == user_id,
|
||||
User__ExternalUserGroupId.external_user_group_id
|
||||
== external_group_id,
|
||||
User__ExternalUserGroupId.cc_pair_id == cc_pair_id,
|
||||
)
|
||||
)
|
||||
|
||||
if existing_user_group:
|
||||
# Update existing record
|
||||
existing_user_group.stale = False
|
||||
else:
|
||||
# Insert new record
|
||||
new_user_group = User__ExternalUserGroupId(
|
||||
user_id=user_id,
|
||||
external_user_group_id=external_group_id,
|
||||
cc_pair_id=cc_pair_id,
|
||||
stale=False,
|
||||
)
|
||||
db_session.add(new_user_group)
|
||||
|
||||
# Handle public group if needed
|
||||
if external_group.gives_anyone_access:
|
||||
# Check if the public group already exists
|
||||
existing_public_group = db_session.scalar(
|
||||
select(PublicExternalUserGroup).where(
|
||||
PublicExternalUserGroup.external_user_group_id == external_group_id,
|
||||
PublicExternalUserGroup.cc_pair_id == cc_pair_id,
|
||||
)
|
||||
)
|
||||
|
||||
if external_group.gives_anyone_access:
|
||||
new_public_external_groups.append(
|
||||
PublicExternalUserGroup(
|
||||
if existing_public_group:
|
||||
# Update existing record
|
||||
existing_public_group.stale = False
|
||||
else:
|
||||
# Insert new record
|
||||
new_public_group = PublicExternalUserGroup(
|
||||
external_user_group_id=external_group_id,
|
||||
cc_pair_id=cc_pair_id,
|
||||
stale=False,
|
||||
)
|
||||
)
|
||||
db_session.add(new_public_group)
|
||||
|
||||
db_session.add_all(new_external_permissions)
|
||||
db_session.add_all(new_public_external_groups)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def remove_stale_external_groups(
|
||||
db_session: Session,
|
||||
cc_pair_id: int,
|
||||
) -> None:
|
||||
db_session.execute(
|
||||
delete(User__ExternalUserGroupId).where(
|
||||
User__ExternalUserGroupId.cc_pair_id == cc_pair_id,
|
||||
User__ExternalUserGroupId.stale.is_(True),
|
||||
)
|
||||
)
|
||||
db_session.execute(
|
||||
delete(PublicExternalUserGroup).where(
|
||||
PublicExternalUserGroup.cc_pair_id == cc_pair_id,
|
||||
PublicExternalUserGroup.stale.is_(True),
|
||||
)
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
|
||||
@@ -15,10 +15,13 @@ from sqlalchemy.sql import select
|
||||
from sqlalchemy.sql.expression import literal
|
||||
from sqlalchemy.sql.expression import UnaryExpression
|
||||
|
||||
from ee.onyx.background.task_name_builders import QUERY_HISTORY_TASK_NAME_PREFIX
|
||||
from onyx.configs.constants import QAFeedbackType
|
||||
from onyx.db.models import ChatMessage
|
||||
from onyx.db.models import ChatMessageFeedback
|
||||
from onyx.db.models import ChatSession
|
||||
from onyx.db.models import TaskQueueState
|
||||
from onyx.db.tasks import get_all_tasks_with_prefix
|
||||
|
||||
|
||||
def _build_filter_conditions(
|
||||
@@ -171,3 +174,9 @@ def fetch_chat_sessions_eagerly_by_time(
|
||||
chat_sessions = query.all()
|
||||
|
||||
return chat_sessions
|
||||
|
||||
|
||||
def get_all_query_history_export_tasks(
|
||||
db_session: Session,
|
||||
) -> list[TaskQueueState]:
|
||||
return get_all_tasks_with_prefix(db_session, QUERY_HISTORY_TASK_NAME_PREFIX)
|
||||
|
||||
@@ -5,6 +5,8 @@ from typing import IO
|
||||
from typing import Optional
|
||||
|
||||
from fastapi_users_db_sqlalchemy import UUID_ID
|
||||
from sqlalchemy import cast
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.db.query_history import fetch_chat_sessions_eagerly_by_time
|
||||
@@ -13,6 +15,7 @@ from ee.onyx.server.reporting.usage_export_models import FlowType
|
||||
from ee.onyx.server.reporting.usage_export_models import UsageReportMetadata
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.db.models import UsageReport
|
||||
from onyx.db.models import User
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
|
||||
|
||||
@@ -86,25 +89,49 @@ def get_all_empty_chat_message_entries(
|
||||
|
||||
|
||||
def get_all_usage_reports(db_session: Session) -> list[UsageReportMetadata]:
|
||||
# Get the user emails
|
||||
usage_reports = db_session.query(UsageReport).all()
|
||||
user_ids = {r.requestor_user_id for r in usage_reports if r.requestor_user_id}
|
||||
user_emails = {
|
||||
user.id: user.email
|
||||
for user in db_session.query(User)
|
||||
.filter(cast(User.id, UUID).in_(user_ids))
|
||||
.all()
|
||||
}
|
||||
|
||||
return [
|
||||
UsageReportMetadata(
|
||||
report_name=r.report_name,
|
||||
requestor=str(r.requestor_user_id) if r.requestor_user_id else None,
|
||||
requestor=(
|
||||
user_emails.get(r.requestor_user_id) if r.requestor_user_id else None
|
||||
),
|
||||
time_created=r.time_created,
|
||||
period_from=r.period_from,
|
||||
period_to=r.period_to,
|
||||
)
|
||||
for r in db_session.query(UsageReport).all()
|
||||
for r in usage_reports
|
||||
]
|
||||
|
||||
|
||||
def get_usage_report_data(
|
||||
db_session: Session,
|
||||
report_name: str,
|
||||
report_display_name: str,
|
||||
) -> IO:
|
||||
file_store = get_default_file_store(db_session)
|
||||
"""
|
||||
Get the usage report data from the file store.
|
||||
|
||||
Args:
|
||||
db_session: The database session.
|
||||
report_display_name: The display name of the usage report. Also assumes
|
||||
that the file is stored with this as the ID in the file store.
|
||||
|
||||
Returns:
|
||||
The usage report data.
|
||||
"""
|
||||
file_store = get_default_file_store()
|
||||
# usage report may be very large, so don't load it all into memory
|
||||
return file_store.read_file(file_name=report_name, mode="b", use_tempfile=True)
|
||||
return file_store.read_file(
|
||||
file_id=report_display_name, mode="b", use_tempfile=True
|
||||
)
|
||||
|
||||
|
||||
def write_usage_report(
|
||||
|
||||
@@ -128,11 +128,14 @@ def validate_object_creation_for_user(
|
||||
target_group_ids: list[int] | None = None,
|
||||
object_is_public: bool | None = None,
|
||||
object_is_perm_sync: bool | None = None,
|
||||
object_is_owned_by_user: bool = False,
|
||||
object_is_new: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
All users can create/edit permission synced objects if they don't specify a group
|
||||
All admin actions are allowed.
|
||||
Prevents non-admins from creating/editing:
|
||||
Curators and global curators can create public objects.
|
||||
Prevents other non-admins from creating/editing:
|
||||
- public objects
|
||||
- objects with no groups
|
||||
- objects that belong to a group they don't curate
|
||||
@@ -143,13 +146,23 @@ def validate_object_creation_for_user(
|
||||
if not user or user.role == UserRole.ADMIN:
|
||||
return
|
||||
|
||||
if object_is_public:
|
||||
detail = "User does not have permission to create public credentials"
|
||||
# Allow curators and global curators to create public objects
|
||||
# w/o associated groups IF the object is new/owned by them
|
||||
if (
|
||||
object_is_public
|
||||
and user.role in [UserRole.CURATOR, UserRole.GLOBAL_CURATOR]
|
||||
and (object_is_new or object_is_owned_by_user)
|
||||
):
|
||||
return
|
||||
|
||||
if object_is_public and user.role == UserRole.BASIC:
|
||||
detail = "User does not have permission to create public objects"
|
||||
logger.error(detail)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=detail,
|
||||
)
|
||||
|
||||
if not target_group_ids:
|
||||
detail = "Curators must specify 1+ groups"
|
||||
logger.error(detail)
|
||||
|
||||
@@ -0,0 +1,48 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<services version="1.0">
|
||||
<container id="default" version="1.0">
|
||||
<document-api />
|
||||
<search />
|
||||
<http>
|
||||
<server id="default" port="4080" />
|
||||
</http>
|
||||
<nodes count="[2, 4]">
|
||||
<resources vcpu="4.0" memory="16Gb" architecture="arm64" storage-type="remote"
|
||||
disk="48Gb" />
|
||||
</nodes>
|
||||
|
||||
|
||||
</container>
|
||||
<content id="danswer_index" version="1.0">
|
||||
<documents>
|
||||
<!-- <document type="danswer_chunk" mode="index" /> -->
|
||||
{{ document_elements }}
|
||||
</documents>
|
||||
<nodes count="60">
|
||||
<resources vcpu="8.0" memory="128.0Gb" architecture="arm64" storage-type="local"
|
||||
disk="475.0Gb" />
|
||||
</nodes>
|
||||
<engine>
|
||||
<proton>
|
||||
<tuning>
|
||||
<searchnode>
|
||||
<requestthreads>
|
||||
<persearch>2</persearch>
|
||||
</requestthreads>
|
||||
</searchnode>
|
||||
</tuning>
|
||||
</proton>
|
||||
</engine>
|
||||
|
||||
<config name="vespa.config.search.summary.juniperrc">
|
||||
<max_matches>3</max_matches>
|
||||
<length>750</length>
|
||||
<surround_max>350</surround_max>
|
||||
<min_length>300</min_length>
|
||||
</config>
|
||||
|
||||
|
||||
<min-redundancy>2</min-redundancy>
|
||||
|
||||
</content>
|
||||
</services>
|
||||
@@ -2,3 +2,6 @@
|
||||
# Instead of setting a page to public, we just add this group so that the page
|
||||
# is only accessible to users who have confluence accounts.
|
||||
ALL_CONF_EMAILS_GROUP_NAME = "All_Confluence_Users_Found_By_Onyx"
|
||||
|
||||
VIEWSPACE_PERMISSION_TYPE = "VIEWSPACE"
|
||||
REQUEST_PAGINATION_LIMIT = 5000
|
||||
|
||||
@@ -4,20 +4,14 @@ https://confluence.atlassian.com/conf85/check-who-can-view-a-page-1283360557.htm
|
||||
"""
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from ee.onyx.configs.app_configs import CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC
|
||||
from ee.onyx.external_permissions.confluence.constants import ALL_CONF_EMAILS_GROUP_NAME
|
||||
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
|
||||
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction
|
||||
from ee.onyx.external_permissions.utils import generic_doc_sync
|
||||
from onyx.access.models import DocExternalAccess
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.confluence.connector import ConfluenceConnector
|
||||
from onyx.connectors.confluence.onyx_confluence import (
|
||||
get_user_email_from_username__server,
|
||||
)
|
||||
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
|
||||
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -25,374 +19,14 @@ from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_VIEWSPACE_PERMISSION_TYPE = "VIEWSPACE"
|
||||
_REQUEST_PAGINATION_LIMIT = 5000
|
||||
|
||||
|
||||
def _get_server_space_permissions(
|
||||
confluence_client: OnyxConfluence, space_key: str
|
||||
) -> ExternalAccess:
|
||||
space_permissions = confluence_client.get_all_space_permissions_server(
|
||||
space_key=space_key
|
||||
)
|
||||
|
||||
viewspace_permissions = []
|
||||
for permission_category in space_permissions:
|
||||
if permission_category.get("type") == _VIEWSPACE_PERMISSION_TYPE:
|
||||
viewspace_permissions.extend(
|
||||
permission_category.get("spacePermissions", [])
|
||||
)
|
||||
|
||||
is_public = False
|
||||
user_names = set()
|
||||
group_names = set()
|
||||
for permission in viewspace_permissions:
|
||||
user_name = permission.get("userName")
|
||||
if user_name:
|
||||
user_names.add(user_name)
|
||||
group_name = permission.get("groupName")
|
||||
if group_name:
|
||||
group_names.add(group_name)
|
||||
|
||||
# It seems that if anonymous access is turned on for the site and space,
|
||||
# then the space is publicly accessible.
|
||||
# For confluence server, we make a group that contains all users
|
||||
# that exist in confluence and then just add that group to the space permissions
|
||||
# if anonymous access is turned on for the site and space or we set is_public = True
|
||||
# if they set the env variable CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC to True so
|
||||
# that we can support confluence server deployments that want anonymous access
|
||||
# to be public (we cant test this because its paywalled)
|
||||
if user_name is None and group_name is None:
|
||||
# Defaults to False
|
||||
if CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC:
|
||||
is_public = True
|
||||
else:
|
||||
group_names.add(ALL_CONF_EMAILS_GROUP_NAME)
|
||||
|
||||
user_emails = set()
|
||||
for user_name in user_names:
|
||||
user_email = get_user_email_from_username__server(confluence_client, user_name)
|
||||
if user_email:
|
||||
user_emails.add(user_email)
|
||||
else:
|
||||
logger.warning(f"Email for user {user_name} not found in Confluence")
|
||||
|
||||
if not user_emails and not group_names:
|
||||
logger.warning(
|
||||
"No user emails or group names found in Confluence space permissions"
|
||||
f"\nSpace key: {space_key}"
|
||||
f"\nSpace permissions: {space_permissions}"
|
||||
)
|
||||
|
||||
return ExternalAccess(
|
||||
external_user_emails=user_emails,
|
||||
external_user_group_ids=group_names,
|
||||
is_public=is_public,
|
||||
)
|
||||
|
||||
|
||||
def _get_cloud_space_permissions(
|
||||
confluence_client: OnyxConfluence, space_key: str
|
||||
) -> ExternalAccess:
|
||||
space_permissions_result = confluence_client.get_space(
|
||||
space_key=space_key, expand="permissions"
|
||||
)
|
||||
space_permissions = space_permissions_result.get("permissions", [])
|
||||
|
||||
user_emails = set()
|
||||
group_names = set()
|
||||
is_externally_public = False
|
||||
for permission in space_permissions:
|
||||
subs = permission.get("subjects")
|
||||
if subs:
|
||||
# If there are subjects, then there are explicit users or groups with access
|
||||
if email := subs.get("user", {}).get("results", [{}])[0].get("email"):
|
||||
user_emails.add(email)
|
||||
if group_name := subs.get("group", {}).get("results", [{}])[0].get("name"):
|
||||
group_names.add(group_name)
|
||||
else:
|
||||
# If there are no subjects, then the permission is for everyone
|
||||
if permission.get("operation", {}).get(
|
||||
"operation"
|
||||
) == "read" and permission.get("anonymousAccess", False):
|
||||
# If the permission specifies read access for anonymous users, then
|
||||
# the space is publicly accessible
|
||||
is_externally_public = True
|
||||
|
||||
return ExternalAccess(
|
||||
external_user_emails=user_emails,
|
||||
external_user_group_ids=group_names,
|
||||
is_public=is_externally_public,
|
||||
)
|
||||
|
||||
|
||||
def _get_space_permissions(
|
||||
confluence_client: OnyxConfluence,
|
||||
is_cloud: bool,
|
||||
) -> dict[str, ExternalAccess]:
|
||||
logger.debug("Getting space permissions")
|
||||
# Gets all the spaces in the Confluence instance
|
||||
all_space_keys = []
|
||||
start = 0
|
||||
while True:
|
||||
spaces_batch = confluence_client.get_all_spaces(
|
||||
start=start, limit=_REQUEST_PAGINATION_LIMIT
|
||||
)
|
||||
for space in spaces_batch.get("results", []):
|
||||
all_space_keys.append(space.get("key"))
|
||||
|
||||
if len(spaces_batch.get("results", [])) < _REQUEST_PAGINATION_LIMIT:
|
||||
break
|
||||
|
||||
start += len(spaces_batch.get("results", []))
|
||||
|
||||
# Gets the permissions for each space
|
||||
logger.debug(f"Got {len(all_space_keys)} spaces from confluence")
|
||||
space_permissions_by_space_key: dict[str, ExternalAccess] = {}
|
||||
for space_key in all_space_keys:
|
||||
if is_cloud:
|
||||
space_permissions = _get_cloud_space_permissions(
|
||||
confluence_client=confluence_client, space_key=space_key
|
||||
)
|
||||
else:
|
||||
space_permissions = _get_server_space_permissions(
|
||||
confluence_client=confluence_client, space_key=space_key
|
||||
)
|
||||
|
||||
# Stores the permissions for each space
|
||||
space_permissions_by_space_key[space_key] = space_permissions
|
||||
if (
|
||||
not space_permissions.is_public
|
||||
and not space_permissions.external_user_emails
|
||||
and not space_permissions.external_user_group_ids
|
||||
):
|
||||
logger.warning(
|
||||
f"No permissions found for space '{space_key}'. This is very unlikely"
|
||||
"to be correct and is more likely caused by an access token with"
|
||||
"insufficient permissions. Make sure that the access token has Admin"
|
||||
f"permissions for space '{space_key}'"
|
||||
)
|
||||
|
||||
return space_permissions_by_space_key
|
||||
|
||||
|
||||
def _extract_read_access_restrictions(
|
||||
confluence_client: OnyxConfluence, restrictions: dict[str, Any]
|
||||
) -> tuple[set[str], set[str], bool]:
|
||||
"""
|
||||
Converts a page's restrictions dict into an ExternalAccess object.
|
||||
If there are no restrictions, then return None
|
||||
"""
|
||||
read_access = restrictions.get("read", {})
|
||||
read_access_restrictions = read_access.get("restrictions", {})
|
||||
|
||||
# Extract the users with read access
|
||||
read_access_user = read_access_restrictions.get("user", {})
|
||||
read_access_user_jsons = read_access_user.get("results", [])
|
||||
# any items found means that there is a restriction
|
||||
found_any_restriction = bool(read_access_user_jsons)
|
||||
|
||||
read_access_user_emails = []
|
||||
for user in read_access_user_jsons:
|
||||
# If the user has an email, then add it to the list
|
||||
if user.get("email"):
|
||||
read_access_user_emails.append(user["email"])
|
||||
# If the user has a username and not an email, then get the email from Confluence
|
||||
elif user.get("username"):
|
||||
email = get_user_email_from_username__server(
|
||||
confluence_client=confluence_client, user_name=user["username"]
|
||||
)
|
||||
if email:
|
||||
read_access_user_emails.append(email)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Email for user {user['username']} not found in Confluence"
|
||||
)
|
||||
else:
|
||||
if user.get("email") is not None:
|
||||
logger.warning(f"Cant find email for user {user.get('displayName')}")
|
||||
logger.warning(
|
||||
"This user needs to make their email accessible in Confluence Settings"
|
||||
)
|
||||
|
||||
logger.warning(f"no user email or username for {user}")
|
||||
|
||||
# Extract the groups with read access
|
||||
read_access_group = read_access_restrictions.get("group", {})
|
||||
read_access_group_jsons = read_access_group.get("results", [])
|
||||
# any items found means that there is a restriction
|
||||
found_any_restriction |= bool(read_access_group_jsons)
|
||||
read_access_group_names = [
|
||||
group["name"] for group in read_access_group_jsons if group.get("name")
|
||||
]
|
||||
|
||||
return (
|
||||
set(read_access_user_emails),
|
||||
set(read_access_group_names),
|
||||
found_any_restriction,
|
||||
)
|
||||
|
||||
|
||||
def _get_all_page_restrictions(
|
||||
confluence_client: OnyxConfluence,
|
||||
perm_sync_data: dict[str, Any],
|
||||
) -> ExternalAccess | None:
|
||||
"""
|
||||
This function gets the restrictions for a page. In Confluence, a child can have
|
||||
at MOST the same level accessibility as its immediate parent.
|
||||
|
||||
If no restrictions are found anywhere, then return None, indicating that the page
|
||||
should inherit the space's restrictions.
|
||||
"""
|
||||
found_user_emails: set[str] = set()
|
||||
found_group_names: set[str] = set()
|
||||
|
||||
# NOTE: need the found_any_restriction, since we can find restrictions
|
||||
# but not be able to extract any user emails or group names
|
||||
# in this case, we should just give no access
|
||||
found_user_emails, found_group_names, found_any_page_level_restriction = (
|
||||
_extract_read_access_restrictions(
|
||||
confluence_client=confluence_client,
|
||||
restrictions=perm_sync_data.get("restrictions", {}),
|
||||
)
|
||||
)
|
||||
# if there are individual page-level restrictions, then this is the accurate
|
||||
# restriction for the page. You cannot both have page-level restrictions AND
|
||||
# inherit restrictions from the parent.
|
||||
if found_any_page_level_restriction:
|
||||
return ExternalAccess(
|
||||
external_user_emails=found_user_emails,
|
||||
external_user_group_ids=found_group_names,
|
||||
is_public=False,
|
||||
)
|
||||
|
||||
ancestors: list[dict[str, Any]] = perm_sync_data.get("ancestors", [])
|
||||
# ancestors seem to be in order from root to immediate parent
|
||||
# https://community.atlassian.com/forums/Confluence-questions/Order-of-ancestors-in-REST-API-response-Confluence-Server-amp/qaq-p/2385981
|
||||
# we want the restrictions from the immediate parent to take precedence, so we should
|
||||
# reverse the list
|
||||
for ancestor in reversed(ancestors):
|
||||
(
|
||||
ancestor_user_emails,
|
||||
ancestor_group_names,
|
||||
found_any_restrictions_in_ancestor,
|
||||
) = _extract_read_access_restrictions(
|
||||
confluence_client=confluence_client,
|
||||
restrictions=ancestor.get("restrictions", {}),
|
||||
)
|
||||
if found_any_restrictions_in_ancestor:
|
||||
# if inheriting restrictions from the parent, then the first one we run into
|
||||
# should be applied (the reason why we'd traverse more than one ancestor is if
|
||||
# the ancestor also is in "inherit" mode.)
|
||||
logger.info(
|
||||
f"Found user restrictions {ancestor_user_emails} and group restrictions {ancestor_group_names}"
|
||||
f"for document {perm_sync_data.get('id')} based on ancestor {ancestor}"
|
||||
)
|
||||
return ExternalAccess(
|
||||
external_user_emails=ancestor_user_emails,
|
||||
external_user_group_ids=ancestor_group_names,
|
||||
is_public=False,
|
||||
)
|
||||
|
||||
# we didn't find any restrictions, so the page inherits the space's restrictions
|
||||
return None
|
||||
|
||||
|
||||
def _fetch_all_page_restrictions(
|
||||
confluence_client: OnyxConfluence,
|
||||
slim_docs: list[SlimDocument],
|
||||
space_permissions_by_space_key: dict[str, ExternalAccess],
|
||||
is_cloud: bool,
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
) -> Generator[DocExternalAccess, None, None]:
|
||||
"""
|
||||
For all pages, if a page has restrictions, then use those restrictions.
|
||||
Otherwise, use the space's restrictions.
|
||||
"""
|
||||
for slim_doc in slim_docs:
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
raise RuntimeError("confluence_doc_sync: Stop signal detected")
|
||||
|
||||
callback.progress("confluence_doc_sync:fetch_all_page_restrictions", 1)
|
||||
|
||||
if slim_doc.perm_sync_data is None:
|
||||
raise ValueError(
|
||||
f"No permission sync data found for document {slim_doc.id}"
|
||||
)
|
||||
|
||||
if restrictions := _get_all_page_restrictions(
|
||||
confluence_client=confluence_client,
|
||||
perm_sync_data=slim_doc.perm_sync_data,
|
||||
):
|
||||
logger.info(f"Found restrictions {restrictions} for document {slim_doc.id}")
|
||||
yield DocExternalAccess(
|
||||
doc_id=slim_doc.id,
|
||||
external_access=restrictions,
|
||||
)
|
||||
# If there are restrictions, then we don't need to use the space's restrictions
|
||||
continue
|
||||
|
||||
space_key = slim_doc.perm_sync_data.get("space_key")
|
||||
if not (space_permissions := space_permissions_by_space_key.get(space_key)):
|
||||
logger.warning(
|
||||
f"Individually fetching space permissions for space {space_key}. This is "
|
||||
"unexpected. It means the permissions were not able to fetched initially."
|
||||
)
|
||||
try:
|
||||
# If the space permissions are not in the cache, then fetch them
|
||||
if is_cloud:
|
||||
retrieved_space_permissions = _get_cloud_space_permissions(
|
||||
confluence_client=confluence_client, space_key=space_key
|
||||
)
|
||||
else:
|
||||
retrieved_space_permissions = _get_server_space_permissions(
|
||||
confluence_client=confluence_client, space_key=space_key
|
||||
)
|
||||
space_permissions_by_space_key[space_key] = retrieved_space_permissions
|
||||
space_permissions = retrieved_space_permissions
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Error fetching space permissions for space {space_key}: {e}"
|
||||
)
|
||||
|
||||
if not space_permissions:
|
||||
logger.warning(
|
||||
f"No permissions found for document {slim_doc.id} in space {space_key}"
|
||||
)
|
||||
# be safe, if we can't get the permissions then make the document inaccessible
|
||||
yield DocExternalAccess(
|
||||
doc_id=slim_doc.id,
|
||||
external_access=ExternalAccess(
|
||||
external_user_emails=set(),
|
||||
external_user_group_ids=set(),
|
||||
is_public=False,
|
||||
),
|
||||
)
|
||||
continue
|
||||
|
||||
# If there are no restrictions, then use the space's restrictions
|
||||
yield DocExternalAccess(
|
||||
doc_id=slim_doc.id,
|
||||
external_access=space_permissions,
|
||||
)
|
||||
if (
|
||||
not space_permissions.is_public
|
||||
and not space_permissions.external_user_emails
|
||||
and not space_permissions.external_user_group_ids
|
||||
):
|
||||
logger.warning(
|
||||
f"Permissions are empty for document: {slim_doc.id}\n"
|
||||
"This means space permissions may be wrong for"
|
||||
f" Space key: {space_key}"
|
||||
)
|
||||
|
||||
logger.info("Finished fetching all page restrictions")
|
||||
CONFLUENCE_DOC_SYNC_LABEL = "confluence_doc_sync"
|
||||
|
||||
|
||||
def confluence_doc_sync(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
|
||||
fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction,
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
) -> Generator[DocExternalAccess, None, None]:
|
||||
"""
|
||||
@@ -400,7 +34,6 @@ def confluence_doc_sync(
|
||||
Compares fetched documents against existing documents in the DB for the connector.
|
||||
If a document exists in the DB but not in the Confluence fetch, it's marked as restricted.
|
||||
"""
|
||||
logger.info(f"Starting confluence doc sync for CC Pair ID: {cc_pair.id}")
|
||||
confluence_connector = ConfluenceConnector(
|
||||
**cc_pair.connector.connector_specific_config
|
||||
)
|
||||
@@ -410,62 +43,11 @@ def confluence_doc_sync(
|
||||
)
|
||||
confluence_connector.set_credentials_provider(provider)
|
||||
|
||||
is_cloud = cc_pair.connector.connector_specific_config.get("is_cloud", False)
|
||||
|
||||
space_permissions_by_space_key = _get_space_permissions(
|
||||
confluence_client=confluence_connector.confluence_client,
|
||||
is_cloud=is_cloud,
|
||||
)
|
||||
logger.info("Space permissions by space key:")
|
||||
for space_key, space_permissions in space_permissions_by_space_key.items():
|
||||
logger.info(f"Space key: {space_key}, Permissions: {space_permissions}")
|
||||
|
||||
slim_docs: list[SlimDocument] = []
|
||||
logger.info("Fetching all slim documents from confluence")
|
||||
for doc_batch in confluence_connector.retrieve_all_slim_documents(
|
||||
callback=callback
|
||||
):
|
||||
logger.info(f"Got {len(doc_batch)} slim documents from confluence")
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
raise RuntimeError("confluence_doc_sync: Stop signal detected")
|
||||
|
||||
callback.progress("confluence_doc_sync", 1)
|
||||
|
||||
slim_docs.extend(doc_batch)
|
||||
|
||||
# Find documents that are no longer accessible in Confluence
|
||||
logger.info(f"Querying existing document IDs for CC Pair ID: {cc_pair.id}")
|
||||
existing_doc_ids = fetch_all_existing_docs_fn()
|
||||
|
||||
# Find missing doc IDs
|
||||
fetched_doc_ids = {doc.id for doc in slim_docs}
|
||||
missing_doc_ids = set(existing_doc_ids) - fetched_doc_ids
|
||||
|
||||
# Yield access removal for missing docs. Better to be safe.
|
||||
if missing_doc_ids:
|
||||
logger.warning(
|
||||
f"Found {len(missing_doc_ids)} documents that are in the DB but "
|
||||
"not present in Confluence fetch. Making them inaccessible."
|
||||
)
|
||||
for missing_id in missing_doc_ids:
|
||||
logger.warning(f"Removing access for document ID: {missing_id}")
|
||||
yield DocExternalAccess(
|
||||
doc_id=missing_id,
|
||||
external_access=ExternalAccess(
|
||||
external_user_emails=set(),
|
||||
external_user_group_ids=set(),
|
||||
is_public=False,
|
||||
),
|
||||
)
|
||||
|
||||
logger.info("Fetching all page restrictions for fetched documents")
|
||||
yield from _fetch_all_page_restrictions(
|
||||
confluence_client=confluence_connector.confluence_client,
|
||||
slim_docs=slim_docs,
|
||||
space_permissions_by_space_key=space_permissions_by_space_key,
|
||||
is_cloud=is_cloud,
|
||||
yield from generic_doc_sync(
|
||||
cc_pair=cc_pair,
|
||||
fetch_all_existing_docs_ids_fn=fetch_all_existing_docs_ids_fn,
|
||||
callback=callback,
|
||||
doc_source=DocumentSource.CONFLUENCE,
|
||||
slim_connector=confluence_connector,
|
||||
label=CONFLUENCE_DOC_SYNC_LABEL,
|
||||
)
|
||||
|
||||
logger.info("Finished confluence doc sync")
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
from ee.onyx.db.external_perm import ExternalUserGroup
|
||||
from ee.onyx.external_permissions.confluence.constants import ALL_CONF_EMAILS_GROUP_NAME
|
||||
from onyx.background.error_logging import emit_background_error
|
||||
@@ -65,7 +67,7 @@ def _build_group_member_email_map(
|
||||
def confluence_group_sync(
|
||||
tenant_id: str,
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> list[ExternalUserGroup]:
|
||||
) -> Generator[ExternalUserGroup, None, None]:
|
||||
provider = OnyxDBCredentialsProvider(tenant_id, "confluence", cc_pair.credential_id)
|
||||
is_cloud = cc_pair.connector.connector_specific_config.get("is_cloud", False)
|
||||
wiki_base: str = cc_pair.connector.connector_specific_config["wiki_base"]
|
||||
@@ -89,10 +91,10 @@ def confluence_group_sync(
|
||||
confluence_client=confluence_client,
|
||||
cc_pair_id=cc_pair.id,
|
||||
)
|
||||
onyx_groups: list[ExternalUserGroup] = []
|
||||
|
||||
all_found_emails = set()
|
||||
for group_id, group_member_emails in group_member_email_map.items():
|
||||
onyx_groups.append(
|
||||
yield (
|
||||
ExternalUserGroup(
|
||||
id=group_id,
|
||||
user_emails=list(group_member_emails),
|
||||
@@ -107,6 +109,4 @@ def confluence_group_sync(
|
||||
id=ALL_CONF_EMAILS_GROUP_NAME,
|
||||
user_emails=list(all_found_emails),
|
||||
)
|
||||
onyx_groups.append(all_found_group)
|
||||
|
||||
return onyx_groups
|
||||
yield all_found_group
|
||||
|
||||
133
backend/ee/onyx/external_permissions/confluence/page_access.py
Normal file
133
backend/ee/onyx/external_permissions/confluence/page_access.py
Normal file
@@ -0,0 +1,133 @@
|
||||
from typing import Any
|
||||
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.connectors.confluence.onyx_confluence import (
|
||||
get_user_email_from_username__server,
|
||||
)
|
||||
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _extract_read_access_restrictions(
|
||||
confluence_client: OnyxConfluence, restrictions: dict[str, Any]
|
||||
) -> tuple[set[str], set[str], bool]:
|
||||
"""
|
||||
Converts a page's restrictions dict into an ExternalAccess object.
|
||||
If there are no restrictions, then return None
|
||||
"""
|
||||
read_access = restrictions.get("read", {})
|
||||
read_access_restrictions = read_access.get("restrictions", {})
|
||||
|
||||
# Extract the users with read access
|
||||
read_access_user = read_access_restrictions.get("user", {})
|
||||
read_access_user_jsons = read_access_user.get("results", [])
|
||||
# any items found means that there is a restriction
|
||||
found_any_restriction = bool(read_access_user_jsons)
|
||||
|
||||
read_access_user_emails = []
|
||||
for user in read_access_user_jsons:
|
||||
# If the user has an email, then add it to the list
|
||||
if user.get("email"):
|
||||
read_access_user_emails.append(user["email"])
|
||||
# If the user has a username and not an email, then get the email from Confluence
|
||||
elif user.get("username"):
|
||||
email = get_user_email_from_username__server(
|
||||
confluence_client=confluence_client, user_name=user["username"]
|
||||
)
|
||||
if email:
|
||||
read_access_user_emails.append(email)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Email for user {user['username']} not found in Confluence"
|
||||
)
|
||||
else:
|
||||
if user.get("email") is not None:
|
||||
logger.warning(f"Cant find email for user {user.get('displayName')}")
|
||||
logger.warning(
|
||||
"This user needs to make their email accessible in Confluence Settings"
|
||||
)
|
||||
|
||||
logger.warning(f"no user email or username for {user}")
|
||||
|
||||
# Extract the groups with read access
|
||||
read_access_group = read_access_restrictions.get("group", {})
|
||||
read_access_group_jsons = read_access_group.get("results", [])
|
||||
# any items found means that there is a restriction
|
||||
found_any_restriction |= bool(read_access_group_jsons)
|
||||
read_access_group_names = [
|
||||
group["name"] for group in read_access_group_jsons if group.get("name")
|
||||
]
|
||||
|
||||
return (
|
||||
set(read_access_user_emails),
|
||||
set(read_access_group_names),
|
||||
found_any_restriction,
|
||||
)
|
||||
|
||||
|
||||
def get_page_restrictions(
|
||||
confluence_client: OnyxConfluence,
|
||||
page_id: str,
|
||||
page_restrictions: dict[str, Any],
|
||||
ancestors: list[dict[str, Any]],
|
||||
) -> ExternalAccess | None:
|
||||
"""
|
||||
This function gets the restrictions for a page. In Confluence, a child can have
|
||||
at MOST the same level accessibility as its immediate parent.
|
||||
|
||||
If no restrictions are found anywhere, then return None, indicating that the page
|
||||
should inherit the space's restrictions.
|
||||
"""
|
||||
found_user_emails: set[str] = set()
|
||||
found_group_names: set[str] = set()
|
||||
|
||||
# NOTE: need the found_any_restriction, since we can find restrictions
|
||||
# but not be able to extract any user emails or group names
|
||||
# in this case, we should just give no access
|
||||
found_user_emails, found_group_names, found_any_page_level_restriction = (
|
||||
_extract_read_access_restrictions(
|
||||
confluence_client=confluence_client,
|
||||
restrictions=page_restrictions,
|
||||
)
|
||||
)
|
||||
# if there are individual page-level restrictions, then this is the accurate
|
||||
# restriction for the page. You cannot both have page-level restrictions AND
|
||||
# inherit restrictions from the parent.
|
||||
if found_any_page_level_restriction:
|
||||
return ExternalAccess(
|
||||
external_user_emails=found_user_emails,
|
||||
external_user_group_ids=found_group_names,
|
||||
is_public=False,
|
||||
)
|
||||
|
||||
# ancestors seem to be in order from root to immediate parent
|
||||
# https://community.atlassian.com/forums/Confluence-questions/Order-of-ancestors-in-REST-API-response-Confluence-Server-amp/qaq-p/2385981
|
||||
# we want the restrictions from the immediate parent to take precedence, so we should
|
||||
# reverse the list
|
||||
for ancestor in reversed(ancestors):
|
||||
(
|
||||
ancestor_user_emails,
|
||||
ancestor_group_names,
|
||||
found_any_restrictions_in_ancestor,
|
||||
) = _extract_read_access_restrictions(
|
||||
confluence_client=confluence_client,
|
||||
restrictions=ancestor.get("restrictions", {}),
|
||||
)
|
||||
if found_any_restrictions_in_ancestor:
|
||||
# if inheriting restrictions from the parent, then the first one we run into
|
||||
# should be applied (the reason why we'd traverse more than one ancestor is if
|
||||
# the ancestor also is in "inherit" mode.)
|
||||
logger.debug(
|
||||
f"Found user restrictions {ancestor_user_emails} and group restrictions {ancestor_group_names}"
|
||||
f"for document {page_id} based on ancestor {ancestor}"
|
||||
)
|
||||
return ExternalAccess(
|
||||
external_user_emails=ancestor_user_emails,
|
||||
external_user_group_ids=ancestor_group_names,
|
||||
is_public=False,
|
||||
)
|
||||
|
||||
# we didn't find any restrictions, so the page inherits the space's restrictions
|
||||
return None
|
||||
165
backend/ee/onyx/external_permissions/confluence/space_access.py
Normal file
165
backend/ee/onyx/external_permissions/confluence/space_access.py
Normal file
@@ -0,0 +1,165 @@
|
||||
from ee.onyx.configs.app_configs import CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC
|
||||
from ee.onyx.external_permissions.confluence.constants import ALL_CONF_EMAILS_GROUP_NAME
|
||||
from ee.onyx.external_permissions.confluence.constants import REQUEST_PAGINATION_LIMIT
|
||||
from ee.onyx.external_permissions.confluence.constants import VIEWSPACE_PERMISSION_TYPE
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.connectors.confluence.onyx_confluence import (
|
||||
get_user_email_from_username__server,
|
||||
)
|
||||
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _get_server_space_permissions(
|
||||
confluence_client: OnyxConfluence, space_key: str
|
||||
) -> ExternalAccess:
|
||||
space_permissions = confluence_client.get_all_space_permissions_server(
|
||||
space_key=space_key
|
||||
)
|
||||
|
||||
viewspace_permissions = []
|
||||
for permission_category in space_permissions:
|
||||
if permission_category.get("type") == VIEWSPACE_PERMISSION_TYPE:
|
||||
viewspace_permissions.extend(
|
||||
permission_category.get("spacePermissions", [])
|
||||
)
|
||||
|
||||
is_public = False
|
||||
user_names = set()
|
||||
group_names = set()
|
||||
for permission in viewspace_permissions:
|
||||
if user_name := permission.get("userName"):
|
||||
user_names.add(user_name)
|
||||
if group_name := permission.get("groupName"):
|
||||
group_names.add(group_name)
|
||||
|
||||
# It seems that if anonymous access is turned on for the site and space,
|
||||
# then the space is publicly accessible.
|
||||
# For confluence server, we make a group that contains all users
|
||||
# that exist in confluence and then just add that group to the space permissions
|
||||
# if anonymous access is turned on for the site and space or we set is_public = True
|
||||
# if they set the env variable CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC to True so
|
||||
# that we can support confluence server deployments that want anonymous access
|
||||
# to be public (we cant test this because its paywalled)
|
||||
if user_name is None and group_name is None:
|
||||
# Defaults to False
|
||||
if CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC:
|
||||
is_public = True
|
||||
else:
|
||||
group_names.add(ALL_CONF_EMAILS_GROUP_NAME)
|
||||
|
||||
user_emails = set()
|
||||
for user_name in user_names:
|
||||
user_email = get_user_email_from_username__server(confluence_client, user_name)
|
||||
if user_email:
|
||||
user_emails.add(user_email)
|
||||
else:
|
||||
logger.warning(f"Email for user {user_name} not found in Confluence")
|
||||
|
||||
if not user_emails and not group_names:
|
||||
logger.warning(
|
||||
"No user emails or group names found in Confluence space permissions"
|
||||
f"\nSpace key: {space_key}"
|
||||
f"\nSpace permissions: {space_permissions}"
|
||||
)
|
||||
|
||||
return ExternalAccess(
|
||||
external_user_emails=user_emails,
|
||||
external_user_group_ids=group_names,
|
||||
is_public=is_public,
|
||||
)
|
||||
|
||||
|
||||
def _get_cloud_space_permissions(
|
||||
confluence_client: OnyxConfluence, space_key: str
|
||||
) -> ExternalAccess:
|
||||
space_permissions_result = confluence_client.get_space(
|
||||
space_key=space_key, expand="permissions"
|
||||
)
|
||||
space_permissions = space_permissions_result.get("permissions", [])
|
||||
|
||||
user_emails = set()
|
||||
group_names = set()
|
||||
is_externally_public = False
|
||||
for permission in space_permissions:
|
||||
subs = permission.get("subjects")
|
||||
if subs:
|
||||
# If there are subjects, then there are explicit users or groups with access
|
||||
if email := subs.get("user", {}).get("results", [{}])[0].get("email"):
|
||||
user_emails.add(email)
|
||||
if group_name := subs.get("group", {}).get("results", [{}])[0].get("name"):
|
||||
group_names.add(group_name)
|
||||
else:
|
||||
# If there are no subjects, then the permission is for everyone
|
||||
if permission.get("operation", {}).get(
|
||||
"operation"
|
||||
) == "read" and permission.get("anonymousAccess", False):
|
||||
# If the permission specifies read access for anonymous users, then
|
||||
# the space is publicly accessible
|
||||
is_externally_public = True
|
||||
|
||||
return ExternalAccess(
|
||||
external_user_emails=user_emails,
|
||||
external_user_group_ids=group_names,
|
||||
is_public=is_externally_public,
|
||||
)
|
||||
|
||||
|
||||
def get_space_permission(
|
||||
confluence_client: OnyxConfluence,
|
||||
space_key: str,
|
||||
is_cloud: bool,
|
||||
) -> ExternalAccess:
|
||||
if is_cloud:
|
||||
space_permissions = _get_cloud_space_permissions(confluence_client, space_key)
|
||||
else:
|
||||
space_permissions = _get_server_space_permissions(confluence_client, space_key)
|
||||
|
||||
if (
|
||||
not space_permissions.is_public
|
||||
and not space_permissions.external_user_emails
|
||||
and not space_permissions.external_user_group_ids
|
||||
):
|
||||
logger.warning(
|
||||
f"No permissions found for space '{space_key}'. This is very unlikely"
|
||||
"to be correct and is more likely caused by an access token with"
|
||||
"insufficient permissions. Make sure that the access token has Admin"
|
||||
f"permissions for space '{space_key}'"
|
||||
)
|
||||
|
||||
return space_permissions
|
||||
|
||||
|
||||
def get_all_space_permissions(
|
||||
confluence_client: OnyxConfluence,
|
||||
is_cloud: bool,
|
||||
) -> dict[str, ExternalAccess]:
|
||||
logger.debug("Getting space permissions")
|
||||
# Gets all the spaces in the Confluence instance
|
||||
all_space_keys = []
|
||||
start = 0
|
||||
while True:
|
||||
spaces_batch = confluence_client.get_all_spaces(
|
||||
start=start, limit=REQUEST_PAGINATION_LIMIT
|
||||
)
|
||||
for space in spaces_batch.get("results", []):
|
||||
all_space_keys.append(space.get("key"))
|
||||
|
||||
if len(spaces_batch.get("results", [])) < REQUEST_PAGINATION_LIMIT:
|
||||
break
|
||||
|
||||
start += len(spaces_batch.get("results", []))
|
||||
|
||||
# Gets the permissions for each space
|
||||
logger.debug(f"Got {len(all_space_keys)} spaces from confluence")
|
||||
space_permissions_by_space_key: dict[str, ExternalAccess] = {}
|
||||
for space_key in all_space_keys:
|
||||
space_permissions = get_space_permission(confluence_client, space_key, is_cloud)
|
||||
|
||||
# Stores the permissions for each space
|
||||
space_permissions_by_space_key[space_key] = space_permissions
|
||||
|
||||
return space_permissions_by_space_key
|
||||
294
backend/ee/onyx/external_permissions/github/doc_sync.py
Normal file
294
backend/ee/onyx/external_permissions/github/doc_sync.py
Normal file
@@ -0,0 +1,294 @@
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
|
||||
from github import Github
|
||||
from github.Repository import Repository
|
||||
|
||||
from ee.onyx.external_permissions.github.utils import fetch_repository_team_slugs
|
||||
from ee.onyx.external_permissions.github.utils import form_collaborators_group_id
|
||||
from ee.onyx.external_permissions.github.utils import form_organization_group_id
|
||||
from ee.onyx.external_permissions.github.utils import (
|
||||
form_outside_collaborators_group_id,
|
||||
)
|
||||
from ee.onyx.external_permissions.github.utils import get_external_access_permission
|
||||
from ee.onyx.external_permissions.github.utils import get_repository_visibility
|
||||
from ee.onyx.external_permissions.github.utils import GitHubVisibility
|
||||
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
|
||||
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction
|
||||
from onyx.access.models import DocExternalAccess
|
||||
from onyx.access.utils import build_ext_group_name_for_onyx
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.github.connector import DocMetadata
|
||||
from onyx.connectors.github.connector import GithubConnector
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.utils import DocumentRow
|
||||
from onyx.db.utils import SortOrder
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
GITHUB_DOC_SYNC_LABEL = "github_doc_sync"
|
||||
|
||||
|
||||
def github_doc_sync(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
|
||||
fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> Generator[DocExternalAccess, None, None]:
|
||||
"""
|
||||
Sync GitHub documents with external access permissions.
|
||||
|
||||
This function checks each repository for visibility/team changes and updates
|
||||
document permissions accordingly without using checkpoints.
|
||||
"""
|
||||
logger.info(f"Starting GitHub document sync for CC pair ID: {cc_pair.id}")
|
||||
|
||||
# Initialize GitHub connector with credentials
|
||||
github_connector: GithubConnector = GithubConnector(
|
||||
**cc_pair.connector.connector_specific_config
|
||||
)
|
||||
|
||||
github_connector.load_credentials(cc_pair.credential.credential_json)
|
||||
logger.info("GitHub connector credentials loaded successfully")
|
||||
|
||||
if not github_connector.github_client:
|
||||
logger.error("GitHub client initialization failed")
|
||||
raise ValueError("github_client is required")
|
||||
|
||||
# Get all repositories from GitHub API
|
||||
logger.info("Fetching all repositories from GitHub API")
|
||||
try:
|
||||
repos = []
|
||||
if github_connector.repositories:
|
||||
if "," in github_connector.repositories:
|
||||
# Multiple repositories specified
|
||||
repos = github_connector.get_github_repos(
|
||||
github_connector.github_client
|
||||
)
|
||||
else:
|
||||
# Single repository
|
||||
repos = [
|
||||
github_connector.get_github_repo(github_connector.github_client)
|
||||
]
|
||||
else:
|
||||
# All repositories
|
||||
repos = github_connector.get_all_repos(github_connector.github_client)
|
||||
|
||||
logger.info(f"Found {len(repos)} repositories to check")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to fetch repositories: {e}")
|
||||
raise
|
||||
|
||||
repo_to_doc_list_map: dict[str, list[DocumentRow]] = {}
|
||||
# sort order is ascending because we want to get the oldest documents first
|
||||
existing_docs: list[DocumentRow] = fetch_all_existing_docs_fn(
|
||||
sort_order=SortOrder.ASC
|
||||
)
|
||||
logger.info(f"Found {len(existing_docs)} documents to check")
|
||||
for doc in existing_docs:
|
||||
try:
|
||||
doc_metadata = DocMetadata.model_validate_json(json.dumps(doc.doc_metadata))
|
||||
if doc_metadata.repo not in repo_to_doc_list_map:
|
||||
repo_to_doc_list_map[doc_metadata.repo] = []
|
||||
repo_to_doc_list_map[doc_metadata.repo].append(doc)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to parse doc metadata: {e} for doc {doc.id}")
|
||||
continue
|
||||
logger.info(f"Found {len(repo_to_doc_list_map)} documents to check")
|
||||
# Process each repository individually
|
||||
for repo in repos:
|
||||
try:
|
||||
logger.info(f"Processing repository: {repo.id} (name: {repo.name})")
|
||||
repo_doc_list: list[DocumentRow] = repo_to_doc_list_map.get(
|
||||
repo.full_name, []
|
||||
)
|
||||
if not repo_doc_list:
|
||||
logger.warning(
|
||||
f"No documents found for repository {repo.id} ({repo.name})"
|
||||
)
|
||||
continue
|
||||
|
||||
current_external_group_ids = repo_doc_list[0].external_user_group_ids or []
|
||||
# Check if repository has any permission changes
|
||||
has_changes = _check_repository_for_changes(
|
||||
repo=repo,
|
||||
github_client=github_connector.github_client,
|
||||
current_external_group_ids=current_external_group_ids,
|
||||
)
|
||||
|
||||
if has_changes:
|
||||
logger.info(
|
||||
f"Repository {repo.id} ({repo.name}) has changes, updating documents"
|
||||
)
|
||||
|
||||
# Get new external access permissions for this repository
|
||||
new_external_access = get_external_access_permission(
|
||||
repo, github_connector.github_client
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Found {len(repo_doc_list)} documents for repository {repo.full_name}"
|
||||
)
|
||||
|
||||
# Yield updated external access for each document
|
||||
for doc in repo_doc_list:
|
||||
if callback:
|
||||
callback.progress(GITHUB_DOC_SYNC_LABEL, 1)
|
||||
|
||||
yield DocExternalAccess(
|
||||
doc_id=doc.id,
|
||||
external_access=new_external_access,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"Repository {repo.id} ({repo.name}) has no changes, skipping"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing repository {repo.id} ({repo.name}): {e}")
|
||||
|
||||
logger.info(f"GitHub document sync completed for CC pair ID: {cc_pair.id}")
|
||||
|
||||
|
||||
def _check_repository_for_changes(
|
||||
repo: Repository,
|
||||
github_client: Github,
|
||||
current_external_group_ids: list[str],
|
||||
) -> bool:
|
||||
"""
|
||||
Check if repository has any permission changes (visibility or team updates).
|
||||
"""
|
||||
logger.info(f"Checking repository {repo.id} ({repo.name}) for changes")
|
||||
|
||||
# Check for repository visibility changes using the sample document data
|
||||
if _is_repo_visibility_changed_from_groups(
|
||||
repo=repo,
|
||||
current_external_group_ids=current_external_group_ids,
|
||||
):
|
||||
logger.info(f"Repository {repo.id} ({repo.name}) has visibility changes")
|
||||
return True
|
||||
|
||||
# Check for team membership changes if repository is private
|
||||
if get_repository_visibility(
|
||||
repo
|
||||
) == GitHubVisibility.PRIVATE and _teams_updated_from_groups(
|
||||
repo=repo,
|
||||
github_client=github_client,
|
||||
current_external_group_ids=current_external_group_ids,
|
||||
):
|
||||
logger.info(f"Repository {repo.id} ({repo.name}) has team changes")
|
||||
return True
|
||||
|
||||
logger.info(f"Repository {repo.id} ({repo.name}) has no changes")
|
||||
return False
|
||||
|
||||
|
||||
def _is_repo_visibility_changed_from_groups(
|
||||
repo: Repository,
|
||||
current_external_group_ids: list[str],
|
||||
) -> bool:
|
||||
"""
|
||||
Check if repository visibility has changed by analyzing existing external group IDs.
|
||||
|
||||
Args:
|
||||
repo: GitHub repository object
|
||||
current_external_group_ids: List of external group IDs from existing document
|
||||
|
||||
Returns:
|
||||
True if visibility has changed
|
||||
"""
|
||||
current_repo_visibility = get_repository_visibility(repo)
|
||||
logger.info(f"Current repository visibility: {current_repo_visibility.value}")
|
||||
|
||||
# Build expected group IDs for current visibility
|
||||
collaborators_group_id = build_ext_group_name_for_onyx(
|
||||
source=DocumentSource.GITHUB,
|
||||
ext_group_name=form_collaborators_group_id(repo.id),
|
||||
)
|
||||
|
||||
org_group_id = None
|
||||
if repo.organization:
|
||||
org_group_id = build_ext_group_name_for_onyx(
|
||||
source=DocumentSource.GITHUB,
|
||||
ext_group_name=form_organization_group_id(repo.organization.id),
|
||||
)
|
||||
|
||||
# Determine existing visibility from group IDs
|
||||
has_collaborators_group = collaborators_group_id in current_external_group_ids
|
||||
has_org_group = org_group_id and org_group_id in current_external_group_ids
|
||||
|
||||
if has_collaborators_group:
|
||||
existing_repo_visibility = GitHubVisibility.PRIVATE
|
||||
elif has_org_group:
|
||||
existing_repo_visibility = GitHubVisibility.INTERNAL
|
||||
else:
|
||||
existing_repo_visibility = GitHubVisibility.PUBLIC
|
||||
|
||||
logger.info(f"Inferred existing visibility: {existing_repo_visibility.value}")
|
||||
|
||||
visibility_changed = existing_repo_visibility != current_repo_visibility
|
||||
if visibility_changed:
|
||||
logger.info(
|
||||
f"Visibility changed for repo {repo.id} ({repo.name}): "
|
||||
f"{existing_repo_visibility.value} -> {current_repo_visibility.value}"
|
||||
)
|
||||
|
||||
return visibility_changed
|
||||
|
||||
|
||||
def _teams_updated_from_groups(
|
||||
repo: Repository,
|
||||
github_client: Github,
|
||||
current_external_group_ids: list[str],
|
||||
) -> bool:
|
||||
"""
|
||||
Check if repository team memberships have changed using existing group IDs.
|
||||
"""
|
||||
# Fetch current team slugs for the repository
|
||||
current_teams = fetch_repository_team_slugs(repo=repo, github_client=github_client)
|
||||
logger.info(
|
||||
f"Current teams for repository {repo.id} (name: {repo.name}): {current_teams}"
|
||||
)
|
||||
|
||||
# Build group IDs to exclude from team comparison (non-team groups)
|
||||
collaborators_group_id = build_ext_group_name_for_onyx(
|
||||
source=DocumentSource.GITHUB,
|
||||
ext_group_name=form_collaborators_group_id(repo.id),
|
||||
)
|
||||
outside_collaborators_group_id = build_ext_group_name_for_onyx(
|
||||
source=DocumentSource.GITHUB,
|
||||
ext_group_name=form_outside_collaborators_group_id(repo.id),
|
||||
)
|
||||
non_team_group_ids = {collaborators_group_id, outside_collaborators_group_id}
|
||||
|
||||
# Extract existing team IDs from current external group IDs
|
||||
existing_team_ids = set()
|
||||
for group_id in current_external_group_ids:
|
||||
# Skip all non-team groups, keep only team groups
|
||||
if group_id not in non_team_group_ids:
|
||||
existing_team_ids.add(group_id)
|
||||
|
||||
# Note: existing_team_ids from DB are already prefixed (e.g., "github__team-slug")
|
||||
# but current_teams from API are raw team slugs, so we need to add the prefix
|
||||
current_team_ids = set()
|
||||
for team_slug in current_teams:
|
||||
team_group_id = build_ext_group_name_for_onyx(
|
||||
source=DocumentSource.GITHUB,
|
||||
ext_group_name=team_slug,
|
||||
)
|
||||
current_team_ids.add(team_group_id)
|
||||
|
||||
logger.info(
|
||||
f"Existing team IDs: {existing_team_ids}, Current team IDs: {current_team_ids}"
|
||||
)
|
||||
|
||||
# Compare actual team IDs to detect changes
|
||||
teams_changed = current_team_ids != existing_team_ids
|
||||
if teams_changed:
|
||||
logger.info(
|
||||
f"Team changes detected for repo {repo.id} (name: {repo.name}): "
|
||||
f"existing={existing_team_ids}, current={current_team_ids}"
|
||||
)
|
||||
|
||||
return teams_changed
|
||||
46
backend/ee/onyx/external_permissions/github/group_sync.py
Normal file
46
backend/ee/onyx/external_permissions/github/group_sync.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
from github import Repository
|
||||
|
||||
from ee.onyx.db.external_perm import ExternalUserGroup
|
||||
from ee.onyx.external_permissions.github.utils import get_external_user_group
|
||||
from onyx.connectors.github.connector import GithubConnector
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def github_group_sync(
|
||||
tenant_id: str,
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> Generator[ExternalUserGroup, None, None]:
|
||||
github_connector: GithubConnector = GithubConnector(
|
||||
**cc_pair.connector.connector_specific_config
|
||||
)
|
||||
github_connector.load_credentials(cc_pair.credential.credential_json)
|
||||
if not github_connector.github_client:
|
||||
raise ValueError("github_client is required")
|
||||
|
||||
logger.info("Starting GitHub group sync...")
|
||||
repos: list[Repository.Repository] = []
|
||||
if github_connector.repositories:
|
||||
if "," in github_connector.repositories:
|
||||
# Multiple repositories specified
|
||||
repos = github_connector.get_github_repos(github_connector.github_client)
|
||||
else:
|
||||
# Single repository (backward compatibility)
|
||||
repos = [github_connector.get_github_repo(github_connector.github_client)]
|
||||
else:
|
||||
# All repositories
|
||||
repos = github_connector.get_all_repos(github_connector.github_client)
|
||||
|
||||
for repo in repos:
|
||||
try:
|
||||
for external_group in get_external_user_group(
|
||||
repo, github_connector.github_client
|
||||
):
|
||||
logger.info(f"External group: {external_group}")
|
||||
yield external_group
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing repository {repo.id} ({repo.name}): {e}")
|
||||
488
backend/ee/onyx/external_permissions/github/utils.py
Normal file
488
backend/ee/onyx/external_permissions/github/utils.py
Normal file
@@ -0,0 +1,488 @@
|
||||
from collections.abc import Callable
|
||||
from enum import Enum
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
from typing import TypeVar
|
||||
|
||||
from github import Github
|
||||
from github import RateLimitExceededException
|
||||
from github.GithubException import GithubException
|
||||
from github.NamedUser import NamedUser
|
||||
from github.Organization import Organization
|
||||
from github.PaginatedList import PaginatedList
|
||||
from github.Repository import Repository
|
||||
from github.Team import Team
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ee.onyx.db.external_perm import ExternalUserGroup
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.access.utils import build_ext_group_name_for_onyx
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.github.rate_limit_utils import sleep_after_rate_limit_exception
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class GitHubVisibility(Enum):
|
||||
"""GitHub repository visibility options."""
|
||||
|
||||
PUBLIC = "public"
|
||||
PRIVATE = "private"
|
||||
INTERNAL = "internal"
|
||||
|
||||
|
||||
MAX_RETRY_COUNT = 3
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
# Higher-order function to wrap GitHub operations with retry and exception handling
|
||||
|
||||
|
||||
def _run_with_retry(
|
||||
operation: Callable[[], T],
|
||||
description: str,
|
||||
github_client: Github,
|
||||
retry_count: int = 0,
|
||||
) -> Optional[T]:
|
||||
"""Execute a GitHub operation with retry on rate limit and exception handling."""
|
||||
logger.debug(f"Starting operation '{description}', attempt {retry_count + 1}")
|
||||
try:
|
||||
result = operation()
|
||||
logger.debug(f"Operation '{description}' completed successfully")
|
||||
return result
|
||||
except RateLimitExceededException:
|
||||
if retry_count < MAX_RETRY_COUNT:
|
||||
sleep_after_rate_limit_exception(github_client)
|
||||
logger.warning(
|
||||
f"Rate limit exceeded while {description}. Retrying... "
|
||||
f"(attempt {retry_count + 1}/{MAX_RETRY_COUNT})"
|
||||
)
|
||||
return _run_with_retry(
|
||||
operation, description, github_client, retry_count + 1
|
||||
)
|
||||
else:
|
||||
error_msg = f"Max retries exceeded for {description}"
|
||||
logger.exception(error_msg)
|
||||
raise RuntimeError(error_msg)
|
||||
except GithubException as e:
|
||||
logger.warning(f"GitHub API error during {description}: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.exception(f"Unexpected error during {description}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
class UserInfo(BaseModel):
|
||||
"""Represents a GitHub user with their basic information."""
|
||||
|
||||
login: str
|
||||
name: Optional[str] = None
|
||||
email: Optional[str] = None
|
||||
|
||||
|
||||
class TeamInfo(BaseModel):
|
||||
"""Represents a GitHub team with its members."""
|
||||
|
||||
name: str
|
||||
slug: str
|
||||
members: List[UserInfo]
|
||||
|
||||
|
||||
def _fetch_organization_members(
|
||||
github_client: Github, org_name: str, retry_count: int = 0
|
||||
) -> List[UserInfo]:
|
||||
"""Fetch all organization members including owners and regular members."""
|
||||
org_members: List[UserInfo] = []
|
||||
logger.info(f"Fetching organization members for {org_name}")
|
||||
|
||||
org = _run_with_retry(
|
||||
lambda: github_client.get_organization(org_name),
|
||||
f"get organization {org_name}",
|
||||
github_client,
|
||||
)
|
||||
if not org:
|
||||
logger.error(f"Failed to fetch organization {org_name}")
|
||||
raise RuntimeError(f"Failed to fetch organization {org_name}")
|
||||
|
||||
member_objs: PaginatedList[NamedUser] | list[NamedUser] = (
|
||||
_run_with_retry(
|
||||
lambda: org.get_members(filter_="all"),
|
||||
f"get members for organization {org_name}",
|
||||
github_client,
|
||||
)
|
||||
or []
|
||||
)
|
||||
|
||||
for member in member_objs:
|
||||
user_info = UserInfo(login=member.login, name=member.name, email=member.email)
|
||||
org_members.append(user_info)
|
||||
|
||||
logger.info(f"Fetched {len(org_members)} members for organization {org_name}")
|
||||
return org_members
|
||||
|
||||
|
||||
def _fetch_repository_teams_detailed(
|
||||
repo: Repository, github_client: Github, retry_count: int = 0
|
||||
) -> List[TeamInfo]:
|
||||
"""Fetch teams with access to the repository and their members."""
|
||||
teams_data: List[TeamInfo] = []
|
||||
logger.info(f"Fetching teams for repository {repo.full_name}")
|
||||
|
||||
team_objs: PaginatedList[Team] | list[Team] = (
|
||||
_run_with_retry(
|
||||
lambda: repo.get_teams(),
|
||||
f"get teams for repository {repo.full_name}",
|
||||
github_client,
|
||||
)
|
||||
or []
|
||||
)
|
||||
|
||||
for team in team_objs:
|
||||
logger.info(
|
||||
f"Processing team {team.name} (slug: {team.slug}) for repository {repo.full_name}"
|
||||
)
|
||||
|
||||
members: PaginatedList[NamedUser] | list[NamedUser] = (
|
||||
_run_with_retry(
|
||||
lambda: team.get_members(),
|
||||
f"get members for team {team.name}",
|
||||
github_client,
|
||||
)
|
||||
or []
|
||||
)
|
||||
|
||||
team_members = []
|
||||
for m in members:
|
||||
user_info = UserInfo(login=m.login, name=m.name, email=m.email)
|
||||
team_members.append(user_info)
|
||||
|
||||
team_info = TeamInfo(name=team.name, slug=team.slug, members=team_members)
|
||||
teams_data.append(team_info)
|
||||
logger.info(f"Team {team.name} has {len(team_members)} members")
|
||||
|
||||
logger.info(f"Fetched {len(teams_data)} teams for repository {repo.full_name}")
|
||||
return teams_data
|
||||
|
||||
|
||||
def fetch_repository_team_slugs(
|
||||
repo: Repository, github_client: Github, retry_count: int = 0
|
||||
) -> List[str]:
|
||||
"""Fetch team slugs with access to the repository."""
|
||||
logger.info(f"Fetching team slugs for repository {repo.full_name}")
|
||||
teams_data: List[str] = []
|
||||
|
||||
team_objs: PaginatedList[Team] | list[Team] = (
|
||||
_run_with_retry(
|
||||
lambda: repo.get_teams(),
|
||||
f"get teams for repository {repo.full_name}",
|
||||
github_client,
|
||||
)
|
||||
or []
|
||||
)
|
||||
|
||||
for team in team_objs:
|
||||
teams_data.append(team.slug)
|
||||
|
||||
logger.info(f"Fetched {len(teams_data)} team slugs for repository {repo.full_name}")
|
||||
return teams_data
|
||||
|
||||
|
||||
def _get_collaborators_and_outside_collaborators(
|
||||
github_client: Github,
|
||||
repo: Repository,
|
||||
) -> Tuple[List[UserInfo], List[UserInfo]]:
|
||||
"""Fetch and categorize collaborators into regular and outside collaborators."""
|
||||
collaborators: List[UserInfo] = []
|
||||
outside_collaborators: List[UserInfo] = []
|
||||
logger.info(f"Fetching collaborators for repository {repo.full_name}")
|
||||
|
||||
repo_collaborators: PaginatedList[NamedUser] | list[NamedUser] = (
|
||||
_run_with_retry(
|
||||
lambda: repo.get_collaborators(),
|
||||
f"get collaborators for repository {repo.full_name}",
|
||||
github_client,
|
||||
)
|
||||
or []
|
||||
)
|
||||
|
||||
for collaborator in repo_collaborators:
|
||||
is_outside = False
|
||||
|
||||
# Check if collaborator is outside the organization
|
||||
if repo.organization:
|
||||
org: Organization | None = _run_with_retry(
|
||||
lambda: github_client.get_organization(repo.organization.login),
|
||||
f"get organization {repo.organization.login}",
|
||||
github_client,
|
||||
)
|
||||
|
||||
if org is not None:
|
||||
org_obj = org
|
||||
membership = _run_with_retry(
|
||||
lambda: org_obj.has_in_members(collaborator),
|
||||
f"check membership for {collaborator.login} in org {org_obj.login}",
|
||||
github_client,
|
||||
)
|
||||
is_outside = membership is not None and not membership
|
||||
|
||||
info = UserInfo(
|
||||
login=collaborator.login, name=collaborator.name, email=collaborator.email
|
||||
)
|
||||
if repo.organization and is_outside:
|
||||
outside_collaborators.append(info)
|
||||
else:
|
||||
collaborators.append(info)
|
||||
|
||||
logger.info(
|
||||
f"Categorized {len(collaborators)} regular and {len(outside_collaborators)} outside collaborators for {repo.full_name}"
|
||||
)
|
||||
return collaborators, outside_collaborators
|
||||
|
||||
|
||||
def form_collaborators_group_id(repository_id: int) -> str:
|
||||
"""Generate group ID for repository collaborators."""
|
||||
if not repository_id:
|
||||
logger.exception("Repository ID is required to generate collaborators group ID")
|
||||
raise ValueError("Repository ID must be set to generate group ID.")
|
||||
group_id = f"{repository_id}_collaborators"
|
||||
return group_id
|
||||
|
||||
|
||||
def form_organization_group_id(organization_id: int) -> str:
|
||||
"""Generate group ID for organization using organization ID."""
|
||||
if not organization_id:
|
||||
logger.exception(
|
||||
"Organization ID is required to generate organization group ID"
|
||||
)
|
||||
raise ValueError("Organization ID must be set to generate group ID.")
|
||||
group_id = f"{organization_id}_organization"
|
||||
return group_id
|
||||
|
||||
|
||||
def form_outside_collaborators_group_id(repository_id: int) -> str:
|
||||
"""Generate group ID for outside collaborators."""
|
||||
if not repository_id:
|
||||
logger.exception(
|
||||
"Repository ID is required to generate outside collaborators group ID"
|
||||
)
|
||||
raise ValueError("Repository ID must be set to generate group ID.")
|
||||
group_id = f"{repository_id}_outside_collaborators"
|
||||
return group_id
|
||||
|
||||
|
||||
def get_repository_visibility(repo: Repository) -> GitHubVisibility:
|
||||
"""
|
||||
Get the visibility of a repository.
|
||||
Returns GitHubVisibility enum member.
|
||||
"""
|
||||
if hasattr(repo, "visibility"):
|
||||
visibility = repo.visibility
|
||||
logger.info(
|
||||
f"Repository {repo.full_name} visibility from attribute: {visibility}"
|
||||
)
|
||||
try:
|
||||
return GitHubVisibility(visibility)
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
f"Unknown visibility '{visibility}' for repo {repo.full_name}, defaulting to private"
|
||||
)
|
||||
return GitHubVisibility.PRIVATE
|
||||
|
||||
logger.info(f"Repository {repo.full_name} is private")
|
||||
return GitHubVisibility.PRIVATE
|
||||
|
||||
|
||||
def get_external_access_permission(
|
||||
repo: Repository, github_client: Github, add_prefix: bool = False
|
||||
) -> ExternalAccess:
|
||||
"""
|
||||
Get the external access permission for a repository.
|
||||
Uses group-based permissions for efficiency and scalability.
|
||||
|
||||
add_prefix: When this method is called during the initial permission sync via the connector,
|
||||
the group ID isn't prefixed with the source while inserting the document record.
|
||||
So in that case, set add_prefix to True, allowing the method itself to handle
|
||||
prefixing. However, when the same method is invoked from doc_sync, our system
|
||||
already adds the prefix to the group ID while processing the ExternalAccess object.
|
||||
"""
|
||||
# We maintain collaborators, and outside collaborators as two separate groups
|
||||
# instead of adding individual user emails to ExternalAccess.external_user_emails for two reasons:
|
||||
# 1. Changes in repo collaborators (additions/removals) would require updating all documents.
|
||||
# 2. Repo permissions can change without updating the repo's updated_at timestamp,
|
||||
# forcing full permission syncs for all documents every time, which is inefficient.
|
||||
|
||||
repo_visibility = get_repository_visibility(repo)
|
||||
logger.info(
|
||||
f"Generating ExternalAccess for {repo.full_name}: visibility={repo_visibility.value}"
|
||||
)
|
||||
|
||||
if repo_visibility == GitHubVisibility.PUBLIC:
|
||||
logger.info(
|
||||
f"Repository {repo.full_name} is public - allowing access to all users"
|
||||
)
|
||||
return ExternalAccess(
|
||||
external_user_emails=set(),
|
||||
external_user_group_ids=set(),
|
||||
is_public=True,
|
||||
)
|
||||
elif repo_visibility == GitHubVisibility.PRIVATE:
|
||||
logger.info(
|
||||
f"Repository {repo.full_name} is private - setting up restricted access"
|
||||
)
|
||||
|
||||
collaborators_group_id = form_collaborators_group_id(repo.id)
|
||||
outside_collaborators_group_id = form_outside_collaborators_group_id(repo.id)
|
||||
if add_prefix:
|
||||
collaborators_group_id = build_ext_group_name_for_onyx(
|
||||
source=DocumentSource.GITHUB,
|
||||
ext_group_name=collaborators_group_id,
|
||||
)
|
||||
outside_collaborators_group_id = build_ext_group_name_for_onyx(
|
||||
source=DocumentSource.GITHUB,
|
||||
ext_group_name=outside_collaborators_group_id,
|
||||
)
|
||||
group_ids = {collaborators_group_id, outside_collaborators_group_id}
|
||||
|
||||
team_slugs = fetch_repository_team_slugs(repo, github_client)
|
||||
if add_prefix:
|
||||
team_slugs = [
|
||||
build_ext_group_name_for_onyx(
|
||||
source=DocumentSource.GITHUB,
|
||||
ext_group_name=slug,
|
||||
)
|
||||
for slug in team_slugs
|
||||
]
|
||||
group_ids.update(team_slugs)
|
||||
|
||||
logger.info(f"ExternalAccess groups for {repo.full_name}: {group_ids}")
|
||||
return ExternalAccess(
|
||||
external_user_emails=set(),
|
||||
external_user_group_ids=group_ids,
|
||||
is_public=False,
|
||||
)
|
||||
else:
|
||||
# Internal repositories - accessible to organization members
|
||||
logger.info(
|
||||
f"Repository {repo.full_name} is internal - accessible to org members"
|
||||
)
|
||||
org_group_id = form_organization_group_id(repo.organization.id)
|
||||
if add_prefix:
|
||||
org_group_id = build_ext_group_name_for_onyx(
|
||||
source=DocumentSource.GITHUB,
|
||||
ext_group_name=org_group_id,
|
||||
)
|
||||
group_ids = {org_group_id}
|
||||
logger.info(f"ExternalAccess groups for {repo.full_name}: {group_ids}")
|
||||
return ExternalAccess(
|
||||
external_user_emails=set(),
|
||||
external_user_group_ids=group_ids,
|
||||
is_public=False,
|
||||
)
|
||||
|
||||
|
||||
def get_external_user_group(
|
||||
repo: Repository, github_client: Github
|
||||
) -> list[ExternalUserGroup]:
|
||||
"""
|
||||
Get the external user group for a repository.
|
||||
Creates ExternalUserGroup objects with actual user emails for each permission group.
|
||||
"""
|
||||
repo_visibility = get_repository_visibility(repo)
|
||||
logger.info(
|
||||
f"Generating ExternalUserGroups for {repo.full_name}: visibility={repo_visibility.value}"
|
||||
)
|
||||
|
||||
if repo_visibility == GitHubVisibility.PRIVATE:
|
||||
logger.info(f"Processing private repository {repo.full_name}")
|
||||
|
||||
collaborators, outside_collaborators = (
|
||||
_get_collaborators_and_outside_collaborators(github_client, repo)
|
||||
)
|
||||
teams = _fetch_repository_teams_detailed(repo, github_client)
|
||||
external_user_groups = []
|
||||
|
||||
user_emails = set()
|
||||
for collab in collaborators:
|
||||
if collab.email:
|
||||
user_emails.add(collab.email)
|
||||
else:
|
||||
logger.error(f"Collaborator {collab.login} has no email")
|
||||
|
||||
if user_emails:
|
||||
collaborators_group = ExternalUserGroup(
|
||||
id=form_collaborators_group_id(repo.id),
|
||||
user_emails=list(user_emails),
|
||||
)
|
||||
external_user_groups.append(collaborators_group)
|
||||
logger.info(f"Created collaborators group with {len(user_emails)} emails")
|
||||
|
||||
# Create group for outside collaborators
|
||||
user_emails = set()
|
||||
for collab in outside_collaborators:
|
||||
if collab.email:
|
||||
user_emails.add(collab.email)
|
||||
else:
|
||||
logger.error(f"Outside collaborator {collab.login} has no email")
|
||||
|
||||
if user_emails:
|
||||
outside_collaborators_group = ExternalUserGroup(
|
||||
id=form_outside_collaborators_group_id(repo.id),
|
||||
user_emails=list(user_emails),
|
||||
)
|
||||
external_user_groups.append(outside_collaborators_group)
|
||||
logger.info(
|
||||
f"Created outside collaborators group with {len(user_emails)} emails"
|
||||
)
|
||||
|
||||
# Create groups for teams
|
||||
for team in teams:
|
||||
user_emails = set()
|
||||
for member in team.members:
|
||||
if member.email:
|
||||
user_emails.add(member.email)
|
||||
else:
|
||||
logger.error(f"Team member {member.login} has no email")
|
||||
|
||||
if user_emails:
|
||||
team_group = ExternalUserGroup(
|
||||
id=team.slug,
|
||||
user_emails=list(user_emails),
|
||||
)
|
||||
external_user_groups.append(team_group)
|
||||
logger.info(
|
||||
f"Created team group {team.name} with {len(user_emails)} emails"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Created {len(external_user_groups)} ExternalUserGroups for private repository {repo.full_name}"
|
||||
)
|
||||
return external_user_groups
|
||||
|
||||
if repo_visibility == GitHubVisibility.INTERNAL:
|
||||
logger.info(f"Processing internal repository {repo.full_name}")
|
||||
|
||||
org_group_id = form_organization_group_id(repo.organization.id)
|
||||
org_members = _fetch_organization_members(
|
||||
github_client, repo.organization.login
|
||||
)
|
||||
|
||||
user_emails = set()
|
||||
for member in org_members:
|
||||
if member.email:
|
||||
user_emails.add(member.email)
|
||||
else:
|
||||
logger.error(f"Org member {member.login} has no email")
|
||||
|
||||
org_group = ExternalUserGroup(
|
||||
id=org_group_id,
|
||||
user_emails=list(user_emails),
|
||||
)
|
||||
logger.info(
|
||||
f"Created organization group with {len(user_emails)} emails for internal repository {repo.full_name}"
|
||||
)
|
||||
return [org_group]
|
||||
|
||||
logger.info(f"Repository {repo.full_name} is public - no user groups needed")
|
||||
return []
|
||||
@@ -3,8 +3,8 @@ from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
|
||||
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction
|
||||
from onyx.access.models import DocExternalAccess
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.connectors.gmail.connector import GmailConnector
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
@@ -36,6 +36,7 @@ def _get_slim_doc_generator(
|
||||
def gmail_doc_sync(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
|
||||
fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction,
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
) -> Generator[DocExternalAccess, None, None]:
|
||||
"""
|
||||
@@ -59,17 +60,11 @@ def gmail_doc_sync(
|
||||
|
||||
callback.progress("gmail_doc_sync", 1)
|
||||
|
||||
if slim_doc.perm_sync_data is None:
|
||||
if slim_doc.external_access is None:
|
||||
logger.warning(f"No permissions found for document {slim_doc.id}")
|
||||
continue
|
||||
|
||||
if user_email := slim_doc.perm_sync_data.get("user_email"):
|
||||
ext_access = ExternalAccess(
|
||||
external_user_emails=set([user_email]),
|
||||
external_user_group_ids=set(),
|
||||
is_public=False,
|
||||
)
|
||||
yield DocExternalAccess(
|
||||
doc_id=slim_doc.id,
|
||||
external_access=ext_access,
|
||||
)
|
||||
yield DocExternalAccess(
|
||||
doc_id=slim_doc.id,
|
||||
external_access=slim_doc.external_access,
|
||||
)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
|
||||
from ee.onyx.external_permissions.google_drive.models import GoogleDrivePermission
|
||||
from ee.onyx.external_permissions.google_drive.models import PermissionType
|
||||
@@ -9,12 +8,13 @@ from ee.onyx.external_permissions.google_drive.permission_retrieval import (
|
||||
get_permissions_by_ids,
|
||||
)
|
||||
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
|
||||
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction
|
||||
from onyx.access.models import DocExternalAccess
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.connectors.google_drive.connector import GoogleDriveConnector
|
||||
from onyx.connectors.google_utils.resources import get_drive_service
|
||||
from onyx.connectors.google_drive.models import GoogleDriveFileType
|
||||
from onyx.connectors.google_utils.resources import GoogleDriveService
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -41,57 +41,71 @@ def _get_slim_doc_generator(
|
||||
)
|
||||
|
||||
|
||||
def _fetch_permissions_for_permission_ids(
|
||||
google_drive_connector: GoogleDriveConnector,
|
||||
permission_info: dict[str, Any],
|
||||
def _merge_permissions_lists(
|
||||
permission_lists: list[list[GoogleDrivePermission]],
|
||||
) -> list[GoogleDrivePermission]:
|
||||
doc_id = permission_info.get("doc_id")
|
||||
if not permission_info or not doc_id:
|
||||
return []
|
||||
"""
|
||||
Merge a list of permission lists into a single list of permissions.
|
||||
"""
|
||||
seen_permission_ids: set[str] = set()
|
||||
merged_permissions: list[GoogleDrivePermission] = []
|
||||
for permission_list in permission_lists:
|
||||
for permission in permission_list:
|
||||
if permission.id not in seen_permission_ids:
|
||||
merged_permissions.append(permission)
|
||||
seen_permission_ids.add(permission.id)
|
||||
|
||||
owner_email = permission_info.get("owner_email")
|
||||
permission_ids = permission_info.get("permission_ids", [])
|
||||
if not permission_ids:
|
||||
return []
|
||||
|
||||
drive_service = get_drive_service(
|
||||
creds=google_drive_connector.creds,
|
||||
user_email=(owner_email or google_drive_connector.primary_admin_email),
|
||||
)
|
||||
|
||||
return get_permissions_by_ids(
|
||||
drive_service=drive_service,
|
||||
doc_id=doc_id,
|
||||
permission_ids=permission_ids,
|
||||
)
|
||||
return merged_permissions
|
||||
|
||||
|
||||
def _get_permissions_from_slim_doc(
|
||||
google_drive_connector: GoogleDriveConnector,
|
||||
slim_doc: SlimDocument,
|
||||
def get_external_access_for_raw_gdrive_file(
|
||||
file: GoogleDriveFileType,
|
||||
company_domain: str,
|
||||
retriever_drive_service: GoogleDriveService | None,
|
||||
admin_drive_service: GoogleDriveService,
|
||||
) -> ExternalAccess:
|
||||
permission_info = slim_doc.perm_sync_data or {}
|
||||
"""
|
||||
Get the external access for a raw Google Drive file.
|
||||
|
||||
Assumes the file we retrieved has EITHER `permissions` or `permission_ids`
|
||||
"""
|
||||
doc_id = file.get("id")
|
||||
if not doc_id:
|
||||
raise ValueError("No doc_id found in file")
|
||||
|
||||
permissions = file.get("permissions")
|
||||
permission_ids = file.get("permissionIds")
|
||||
drive_id = file.get("driveId")
|
||||
|
||||
permissions_list: list[GoogleDrivePermission] = []
|
||||
raw_permissions_list = permission_info.get("permissions", [])
|
||||
if not raw_permissions_list:
|
||||
permissions_list = _fetch_permissions_for_permission_ids(
|
||||
google_drive_connector=google_drive_connector,
|
||||
permission_info=permission_info,
|
||||
)
|
||||
if not permissions_list:
|
||||
logger.warning(f"No permissions found for document {slim_doc.id}")
|
||||
return ExternalAccess(
|
||||
external_user_emails=set(),
|
||||
external_user_group_ids=set(),
|
||||
is_public=False,
|
||||
)
|
||||
else:
|
||||
if permissions:
|
||||
permissions_list = [
|
||||
GoogleDrivePermission.from_drive_permission(p) for p in raw_permissions_list
|
||||
GoogleDrivePermission.from_drive_permission(p) for p in permissions
|
||||
]
|
||||
elif permission_ids:
|
||||
|
||||
def _get_permissions(
|
||||
drive_service: GoogleDriveService,
|
||||
) -> list[GoogleDrivePermission]:
|
||||
return get_permissions_by_ids(
|
||||
drive_service=drive_service,
|
||||
doc_id=doc_id,
|
||||
permission_ids=permission_ids,
|
||||
)
|
||||
|
||||
permissions_list = _get_permissions(
|
||||
retriever_drive_service or admin_drive_service
|
||||
)
|
||||
if len(permissions_list) != len(permission_ids) and retriever_drive_service:
|
||||
logger.warning(
|
||||
f"Failed to get all permissions for file {doc_id} with retriever service, "
|
||||
"trying admin service"
|
||||
)
|
||||
backup_permissions_list = _get_permissions(admin_drive_service)
|
||||
permissions_list = _merge_permissions_lists(
|
||||
[permissions_list, backup_permissions_list]
|
||||
)
|
||||
|
||||
company_domain = google_drive_connector.google_domain
|
||||
folder_ids_to_inherit_permissions_from: set[str] = set()
|
||||
user_emails: set[str] = set()
|
||||
group_emails: set[str] = set()
|
||||
@@ -107,13 +121,8 @@ def _get_permissions_from_slim_doc(
|
||||
# We could fetch all ancestors of the file to get the list of folders that
|
||||
# might affect the permissions of the file, but this will get replaced with
|
||||
# an audit-log based approach in the future so not doing it now.
|
||||
if (
|
||||
permission.permission_details
|
||||
and permission.permission_details.inherited_from
|
||||
):
|
||||
folder_ids_to_inherit_permissions_from.add(
|
||||
permission.permission_details.inherited_from
|
||||
)
|
||||
if permission.inherited_from:
|
||||
folder_ids_to_inherit_permissions_from.add(permission.inherited_from)
|
||||
|
||||
if permission.type == PermissionType.USER:
|
||||
if permission.email_address:
|
||||
@@ -121,7 +130,7 @@ def _get_permissions_from_slim_doc(
|
||||
else:
|
||||
logger.error(
|
||||
"Permission is type `user` but no email address is "
|
||||
f"provided for document {slim_doc.id}"
|
||||
f"provided for document {doc_id}"
|
||||
f"\n {permission}"
|
||||
)
|
||||
elif permission.type == PermissionType.GROUP:
|
||||
@@ -131,7 +140,7 @@ def _get_permissions_from_slim_doc(
|
||||
else:
|
||||
logger.error(
|
||||
"Permission is type `group` but no email address is "
|
||||
f"provided for document {slim_doc.id}"
|
||||
f"provided for document {doc_id}"
|
||||
f"\n {permission}"
|
||||
)
|
||||
elif permission.type == PermissionType.DOMAIN and company_domain:
|
||||
@@ -145,7 +154,6 @@ def _get_permissions_from_slim_doc(
|
||||
elif permission.type == PermissionType.ANYONE:
|
||||
public = True
|
||||
|
||||
drive_id = permission_info.get("drive_id")
|
||||
group_ids = (
|
||||
group_emails
|
||||
| folder_ids_to_inherit_permissions_from
|
||||
@@ -162,6 +170,7 @@ def _get_permissions_from_slim_doc(
|
||||
def gdrive_doc_sync(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
|
||||
fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction,
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
) -> Generator[DocExternalAccess, None, None]:
|
||||
"""
|
||||
@@ -177,7 +186,9 @@ def gdrive_doc_sync(
|
||||
|
||||
slim_doc_generator = _get_slim_doc_generator(cc_pair, google_drive_connector)
|
||||
|
||||
total_processed = 0
|
||||
for slim_doc_batch in slim_doc_generator:
|
||||
logger.info(f"Drive perm sync: Processing {len(slim_doc_batch)} documents")
|
||||
for slim_doc in slim_doc_batch:
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
@@ -185,11 +196,14 @@ def gdrive_doc_sync(
|
||||
|
||||
callback.progress("gdrive_doc_sync", 1)
|
||||
|
||||
ext_access = _get_permissions_from_slim_doc(
|
||||
google_drive_connector=google_drive_connector,
|
||||
slim_doc=slim_doc,
|
||||
)
|
||||
if slim_doc.external_access is None:
|
||||
raise ValueError(
|
||||
f"Drive perm sync: No external access for document {slim_doc.id}"
|
||||
)
|
||||
|
||||
yield DocExternalAccess(
|
||||
external_access=ext_access,
|
||||
external_access=slim_doc.external_access,
|
||||
doc_id=slim_doc.id,
|
||||
)
|
||||
total_processed += len(slim_doc_batch)
|
||||
logger.info(f"Drive perm sync: Processed {total_processed} total documents")
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
from googleapiclient.errors import HttpError # type: ignore
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -42,11 +44,17 @@ def _get_all_folders(
|
||||
|
||||
TODO: tweak things so we can fetch deltas.
|
||||
"""
|
||||
MAX_FAILED_PERCENTAGE = 0.5
|
||||
|
||||
all_folders: list[FolderInfo] = []
|
||||
seen_folder_ids: set[str] = set()
|
||||
|
||||
user_emails = google_drive_connector._get_all_user_emails()
|
||||
for user_email in user_emails:
|
||||
def _get_all_folders_for_user(
|
||||
google_drive_connector: GoogleDriveConnector,
|
||||
skip_folders_without_permissions: bool,
|
||||
user_email: str,
|
||||
) -> None:
|
||||
"""Helper to get folders for a specific user + update shared seen_folder_ids"""
|
||||
drive_service = get_drive_service(
|
||||
google_drive_connector.creds,
|
||||
user_email,
|
||||
@@ -60,6 +68,8 @@ def _get_all_folders(
|
||||
logger.debug(f"Folder {folder_id} has already been seen. Skipping.")
|
||||
continue
|
||||
|
||||
seen_folder_ids.add(folder_id)
|
||||
|
||||
# Check if the folder has permission IDs but no permissions
|
||||
permission_ids = folder.get("permissionIds", [])
|
||||
raw_permissions = folder.get("permissions", [])
|
||||
@@ -75,7 +85,16 @@ def _get_all_folders(
|
||||
for permission in raw_permissions
|
||||
]
|
||||
|
||||
# Don't include inherited permissions, those will be captured
|
||||
# by the folder/shared drive itself
|
||||
permissions = [
|
||||
permission
|
||||
for permission in permissions
|
||||
if permission.inherited_from is None
|
||||
]
|
||||
|
||||
if not permissions and skip_folders_without_permissions:
|
||||
logger.debug(f"Folder {folder_id} has no permissions. Skipping.")
|
||||
continue
|
||||
|
||||
all_folders.append(
|
||||
@@ -84,11 +103,62 @@ def _get_all_folders(
|
||||
permissions=permissions,
|
||||
)
|
||||
)
|
||||
seen_folder_ids.add(folder_id)
|
||||
|
||||
failed_count = 0
|
||||
user_emails = google_drive_connector._get_all_user_emails()
|
||||
for user_email in user_emails:
|
||||
try:
|
||||
_get_all_folders_for_user(
|
||||
google_drive_connector, skip_folders_without_permissions, user_email
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(f"Error getting folders for user {user_email}")
|
||||
failed_count += 1
|
||||
|
||||
if failed_count > MAX_FAILED_PERCENTAGE * len(user_emails):
|
||||
raise RuntimeError("Too many failed folder fetches during group sync")
|
||||
|
||||
return all_folders
|
||||
|
||||
|
||||
def _drive_folder_to_onyx_group(
|
||||
folder: FolderInfo,
|
||||
group_email_to_member_emails_map: dict[str, list[str]],
|
||||
) -> ExternalUserGroup:
|
||||
"""
|
||||
Converts a folder into an Onyx group.
|
||||
"""
|
||||
anyone_can_access = False
|
||||
folder_member_emails: set[str] = set()
|
||||
|
||||
for permission in folder.permissions:
|
||||
if permission.type == PermissionType.USER:
|
||||
if permission.email_address is None:
|
||||
logger.warning(
|
||||
f"User email is None for folder {folder.id} permission {permission}"
|
||||
)
|
||||
continue
|
||||
folder_member_emails.add(permission.email_address)
|
||||
elif permission.type == PermissionType.GROUP:
|
||||
if permission.email_address not in group_email_to_member_emails_map:
|
||||
logger.warning(
|
||||
f"Group email {permission.email_address} for folder {folder.id} "
|
||||
"not found in group_email_to_member_emails_map"
|
||||
)
|
||||
continue
|
||||
folder_member_emails.update(
|
||||
group_email_to_member_emails_map[permission.email_address]
|
||||
)
|
||||
elif permission.type == PermissionType.ANYONE:
|
||||
anyone_can_access = True
|
||||
|
||||
return ExternalUserGroup(
|
||||
id=folder.id,
|
||||
user_emails=list(folder_member_emails),
|
||||
gives_anyone_access=anyone_can_access,
|
||||
)
|
||||
|
||||
|
||||
"""Individual Shared Drive / My Drive Permission Sync"""
|
||||
|
||||
|
||||
@@ -157,7 +227,29 @@ def _get_drive_members(
|
||||
return drive_id_to_members_map
|
||||
|
||||
|
||||
def _get_all_groups(
|
||||
def _drive_member_map_to_onyx_groups(
|
||||
drive_id_to_members_map: dict[str, tuple[set[str], set[str]]],
|
||||
group_email_to_member_emails_map: dict[str, list[str]],
|
||||
) -> Generator[ExternalUserGroup, None, None]:
|
||||
"""The `user_emails` for the Shared Drive should be all individuals in the
|
||||
Shared Drive + the union of all flattened group emails."""
|
||||
for drive_id, (group_emails, user_emails) in drive_id_to_members_map.items():
|
||||
drive_member_emails: set[str] = user_emails
|
||||
for group_email in group_emails:
|
||||
if group_email not in group_email_to_member_emails_map:
|
||||
logger.warning(
|
||||
f"Group email {group_email} for drive {drive_id} not found in "
|
||||
"group_email_to_member_emails_map"
|
||||
)
|
||||
continue
|
||||
drive_member_emails.update(group_email_to_member_emails_map[group_email])
|
||||
yield ExternalUserGroup(
|
||||
id=drive_id,
|
||||
user_emails=list(drive_member_emails),
|
||||
)
|
||||
|
||||
|
||||
def _get_all_google_groups(
|
||||
admin_service: AdminService,
|
||||
google_domain: str,
|
||||
) -> set[str]:
|
||||
@@ -175,6 +267,28 @@ def _get_all_groups(
|
||||
return group_emails
|
||||
|
||||
|
||||
def _google_group_to_onyx_group(
|
||||
admin_service: AdminService,
|
||||
group_email: str,
|
||||
) -> ExternalUserGroup:
|
||||
"""
|
||||
This maps google group emails to their member emails.
|
||||
"""
|
||||
group_member_emails: set[str] = set()
|
||||
for member in execute_paginated_retrieval(
|
||||
admin_service.members().list,
|
||||
list_key="members",
|
||||
groupKey=group_email,
|
||||
fields="members(email),nextPageToken",
|
||||
):
|
||||
group_member_emails.add(member["email"])
|
||||
|
||||
return ExternalUserGroup(
|
||||
id=group_email,
|
||||
user_emails=list(group_member_emails),
|
||||
)
|
||||
|
||||
|
||||
def _map_group_email_to_member_emails(
|
||||
admin_service: AdminService,
|
||||
group_emails: set[str],
|
||||
@@ -272,7 +386,7 @@ def _build_onyx_groups(
|
||||
def gdrive_group_sync(
|
||||
tenant_id: str,
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> list[ExternalUserGroup]:
|
||||
) -> Generator[ExternalUserGroup, None, None]:
|
||||
# Initialize connector and build credential/service objects
|
||||
google_drive_connector = GoogleDriveConnector(
|
||||
**cc_pair.connector.connector_specific_config
|
||||
@@ -286,26 +400,27 @@ def gdrive_group_sync(
|
||||
drive_id_to_members_map = _get_drive_members(google_drive_connector, admin_service)
|
||||
|
||||
# Get all group emails
|
||||
all_group_emails = _get_all_groups(
|
||||
all_group_emails = _get_all_google_groups(
|
||||
admin_service, google_drive_connector.google_domain
|
||||
)
|
||||
|
||||
# Each google group is an Onyx group, yield those
|
||||
group_email_to_member_emails_map: dict[str, list[str]] = {}
|
||||
for group_email in all_group_emails:
|
||||
onyx_group = _google_group_to_onyx_group(admin_service, group_email)
|
||||
group_email_to_member_emails_map[group_email] = onyx_group.user_emails
|
||||
yield onyx_group
|
||||
|
||||
# Each drive is a group, yield those
|
||||
for onyx_group in _drive_member_map_to_onyx_groups(
|
||||
drive_id_to_members_map, group_email_to_member_emails_map
|
||||
):
|
||||
yield onyx_group
|
||||
|
||||
# Get all folder permissions
|
||||
folder_info = _get_all_folders(
|
||||
google_drive_connector=google_drive_connector,
|
||||
skip_folders_without_permissions=True,
|
||||
)
|
||||
|
||||
# Map group emails to their members
|
||||
group_email_to_member_emails_map = _map_group_email_to_member_emails(
|
||||
admin_service, all_group_emails
|
||||
)
|
||||
|
||||
# Convert the maps to onyx groups
|
||||
onyx_groups = _build_onyx_groups(
|
||||
drive_id_to_members_map=drive_id_to_members_map,
|
||||
group_email_to_member_emails_map=group_email_to_member_emails_map,
|
||||
folder_info=folder_info,
|
||||
)
|
||||
|
||||
return onyx_groups
|
||||
for folder in folder_info:
|
||||
yield _drive_folder_to_onyx_group(folder, group_email_to_member_emails_map)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user