Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-02-18 00:05:47 +00:00)

Compare commits: 267 commits in range double_ini...helm-reado
Commits in this range (abbreviated SHA1):

66779d5692, 42572479cb, accd363d3f, 8cf754a8b6, bf79220ac0, 4c9dc14e65, f8621f7ea9, e0e08427b9,
169df994da, d83eaf2efb, 4e1e30f751, 561f8c9c53, f625a4d0a7, 746d4b6d3c, fdd48c6588, 23a04f7b9c,
b7b0dde7aa, c40b78c7e9, 33c0133cc7, cca5bc13dc, d5ecaea8e7, b6d3c38ca9, b5fc1b4323, a1a9c42b0b,
e689e143e5, a7a168d934, 69f47fc3e3, 8a87140b4f, 53db0ddc4d, 087085403f, c040b1cb47, 1f4d0716b9,
aa4993873f, ce031c4394, d4dadb0dda, 0ded5813cd, 83137a19fb, be66f8dbeb, 70baecb402, c27ba6bad4,
61fda6ec58, 2c93841eaa, 879db3391b, 6dff3d41fa, e399eeb014, b133e8fcf0, ba9b24a477, cc7fb625a6,
2b812b7d7d, c5adbe4180, 21dc3a2456, 9631f373f0, cbd4d46fa5, dc4b9bc003, affb9e6941, dc542fd7fa,
85eeb21b77, 4bb3ee03a0, 1bb23d6837, f447359815, 851e0b05f2, 094cc940a4, 51be9000bb, 80ecdb711d,
a599176bbf, e0341b4c8a, 4c93fd448f, 84d916e210, f57ed2a8dd, 713889babf, 58c641d8ec, 94985e24c6,
4c71a5f5ff, b19e3a500b, 267fe027f5, 0d4d8c0d64, 6f9d8c0cff, 5031096a2b, 797e113000, edc2892785,
ef4d5dcec3, 0b5e3e5ee4, f5afb3621e, 9f72826143, ab7a4184df, 16a14bac89, baaf31513c, 0b01d7f848,
23ff3476bc, 0c7ba8e2ac, dad99cbec7, 3e78c2f087, e822afdcfa, b824951c89, ca20e527fc, c8e65cce1e,
6c349687da, 3b64793d4b, 9dbe12cea8, e78637d632, cac03c07f7, 95dabfaa18, e92c418e0f, 0593d045bf,
fff701b0bb, 0087a32d8b, 06312e485c, e0f5b95cfc, 10bc072b4b, b60884d3af, 95ae6d300c, b76e4754bf,
b1085039ca, d64f479c9f, fd735c9a3f, 2282f6a42e, 0262002883, 01ca9dc85d, 0735a98284, 8d2e170fc4,
f3e2795e69, 30d9ce1310, 2af2b7f130, 9d41820363, a44f289aed, 9c078b3acf, 349f2c6ed6, 0dc851a1cf,
f27fe068e8, f836cff935, 312e3b92bc, 0cc0964231, b82278e685, daa1746b4a, d8068f0a68, d91f776c2d,
a01135581f, 392b87fb4f, 551a05aef0, 6b9d0b5af9, b8f3ad3e5d, b19515e25d, 913f7cc7d4, 84566debab,
1a8b7abd00, 4c0423f27b, 7965fd9cbb, 91831f4d07, 1dd98a87cc, 0dd65cc839, 519aeb6a1f, 0eab6ab935,
ee09cb95af, 8a9a66947e, d744c0dab4, 70df685709, f85ef78238, 2d7e48d8e8, 8231328dc6, 7763e2fa23,
6085bff12d, 97d60a89ae, 79b981075e, 113876b276, 5c3820b39f, 55e4465782, 6d9693dc51, 75fa10cead,
0497bfdf78, 0db2ad2132, 49cd38fb2d, bd36b2ad6d, 6436b60763, a6cc1c84dc, 8515f4b57a, f68b74ff4a,
e254fdc066, f26de37878, 94de23fe87, dd242c9926, 47767c1666, 9be3da2357, 8961a3cc72, 47b9e7aa62,
eebfa5be18, 5047d256b4, 5db676967f, ea0664e203, bbd0874200, c6d100b415, 5ca7a7def9, 92b5e1adf4,
23c6e0f3bf, ad76e6ac9e, 151aabea73, d711680069, 9835d55ecb, 69c539df6e, df67ca18d8, 115cfb6ae9,
672f3a1c34, 13b71f559f, 2981b7a425, 37adf31a3b, 5d59850a17, 91d6b739a4, c83ee06062, c93cebe1ab,
ea1d3c1eda, c9a609b7d8, 07f04e35ec, d8b050026d, c76dc2ea2c, 5e11c635d9, 669b668463, 85fa083717,
420d2614d4, e3218d358d, ae632b5fab, 0d4c600852, eb569bf79d, f3d5303d93, b97628070e, 72d3a7ff21,
2111eccf07, 87478c5ca6, dc62d83a06, 5681df9095, 6666300f37, 7f99c54527, 4b8ef4b151, e5e0944049,
356336a842, 5bc059881e, fa80842afe, a8a5a82251, 953a4e3793, 04ebde7838, 6df1c6c72f, fe94bdf936,
2a9fd9342e, 597ad806e3, 5acae2dc80, 99455db26c, 0d12e96362, 7e7b6e08ff, 1dd32ebfce, c3ffaa19a4,
f4ea7e62a7, 2ac41c3719, a8cba7abae, ae9f8c3071, 742041d97a, 187b93275d, ca2aeac2cc, f7543c6285,
1430a18d44, 7c4487585d, e572ce95e7, 68c6c1f4f8, a5edc8aa0f, a377f6ffb6, 72ce2f75cc, 2683207a24,
e3aab8e85e, 65fd8b90a8, 6eaa774051
@@ -25,6 +25,10 @@ inputs:
   tags:
     description: 'Image tags'
     required: true
+  no-cache:
+    description: 'Read from cache'
+    required: false
+    default: 'false'
   cache-from:
     description: 'Cache sources'
     required: false
@@ -55,6 +59,7 @@ runs:
         push: ${{ inputs.push }}
         load: ${{ inputs.load }}
         tags: ${{ inputs.tags }}
+        no-cache: ${{ inputs.no-cache }}
         cache-from: ${{ inputs.cache-from }}
         cache-to: ${{ inputs.cache-to }}

@@ -77,6 +82,7 @@ runs:
         push: ${{ inputs.push }}
         load: ${{ inputs.load }}
         tags: ${{ inputs.tags }}
+        no-cache: ${{ inputs.no-cache }}
        cache-from: ${{ inputs.cache-from }}
         cache-to: ${{ inputs.cache-to }}

@@ -99,6 +105,7 @@ runs:
         push: ${{ inputs.push }}
         load: ${{ inputs.load }}
         tags: ${{ inputs.tags }}
+        no-cache: ${{ inputs.no-cache }}
         cache-from: ${{ inputs.cache-from }}
         cache-to: ${{ inputs.cache-to }}
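The new `no-cache` input is forwarded unchanged to `docker/build-push-action` in each branch of the composite action. For intuition, a minimal sketch of the equivalent BuildKit CLI invocation; the image name, context path, and cache location below are placeholders, not taken from the repo:

```bash
# Hedged sketch: what the action's inputs roughly translate to on the CLI.
# --no-cache corresponds to no-cache: 'true'; the action's default stays 'false'.
docker buildx build \
  --tag example/app:test \
  --no-cache \
  --cache-from type=local,src=/tmp/buildx-cache \
  --cache-to type=local,dest=/tmp/buildx-cache,mode=max \
  ./backend
```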
@@ -7,18 +7,47 @@ on:

 env:
   REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-backend-cloud' || 'onyxdotapp/onyx-backend' }}
-  LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
+  DEPLOYMENT: ${{ contains(github.ref_name, 'cloud') && 'cloud' || 'standalone' }}
+
+  # don't tag cloud images with "latest"
+  LATEST_TAG: ${{ contains(github.ref_name, 'latest') && !contains(github.ref_name, 'cloud') }}

 jobs:
   build-and-push:
-    # TODO: investigate a matrix build like the web container
     # See https://runs-on.com/runners/linux/
-    runs-on: [runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}"]
+    runs-on:
+      - runs-on
+      - runner=${{ matrix.platform == 'linux/amd64' && '8cpu-linux-x64' || '8cpu-linux-arm64' }}
+      - run-id=${{ github.run_id }}
+      - tag=platform-${{ matrix.platform }}
+    strategy:
+      fail-fast: false
+      matrix:
+        platform:
+          - linux/amd64
+          - linux/arm64

     steps:
+      - name: Prepare
+        run: |
+          platform=${{ matrix.platform }}
+          echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
+
       - name: Checkout code
         uses: actions/checkout@v4

+      - name: Docker meta
+        id: meta
+        uses: docker/metadata-action@v5
+        with:
+          images: ${{ env.REGISTRY_IMAGE }}
+          flavor: |
+            latest=false
+          tags: |
+            type=raw,value=${{ github.ref_name }}
+            type=raw,value=${{ env.LATEST_TAG == 'true' && 'latest' || '' }}
+
       - name: Set up Docker Buildx
         uses: docker/setup-buildx-action@v3

@@ -34,18 +63,80 @@ jobs:
           sudo apt-get install -y build-essential

       - name: Backend Image Docker Build and Push
-        uses: docker/build-push-action@v5
+        id: build
+        uses: docker/build-push-action@v6
         with:
           context: ./backend
           file: ./backend/Dockerfile
-          platforms: linux/amd64,linux/arm64
+          platforms: ${{ matrix.platform }}
           push: true
-          tags: |
-            ${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
-            ${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
           build-args: |
             ONYX_VERSION=${{ github.ref_name }}
+          labels: ${{ steps.meta.outputs.labels }}
+          outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
+          cache-from: type=s3,prefix=cache/${{ github.repository }}/${{ env.DEPLOYMENT }}/backend-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
+          cache-to: type=s3,prefix=cache/${{ github.repository }}/${{ env.DEPLOYMENT }}/backend-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
+
+      - name: Export digest
+        run: |
+          mkdir -p /tmp/digests
+          digest="${{ steps.build.outputs.digest }}"
+          touch "/tmp/digests/${digest#sha256:}"
+
+      - name: Upload digest
+        uses: actions/upload-artifact@v4
+        with:
+          name: backend-digests-${{ env.PLATFORM_PAIR }}-${{ github.run_id }}
+          path: /tmp/digests/*
+          if-no-files-found: error
+          retention-days: 1
+
+  merge:
+    runs-on: ubuntu-latest
+    needs:
+      - build-and-push
+    steps:
+      # Needed for trivyignore
+      - name: Checkout
+        uses: actions/checkout@v4
+
+      - name: Download digests
+        uses: actions/download-artifact@v4
+        with:
+          path: /tmp/digests
+          pattern: backend-digests-*-${{ github.run_id }}
+          merge-multiple: true
+
+      - name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@v3
+
+      - name: Docker meta
+        id: meta
+        uses: docker/metadata-action@v5
+        with:
+          images: ${{ env.REGISTRY_IMAGE }}
+          flavor: |
+            latest=false
+          tags: |
+            type=raw,value=${{ github.ref_name }}
+            type=raw,value=${{ env.LATEST_TAG == 'true' && 'latest' || '' }}
+
+      - name: Login to Docker Hub
+        uses: docker/login-action@v3
+        with:
+          username: ${{ secrets.DOCKER_USERNAME }}
+          password: ${{ secrets.DOCKER_TOKEN }}
+
+      - name: Create manifest list and push
+        working-directory: /tmp/digests
+        run: |
+          docker buildx imagetools create $(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \
+            $(printf '${{ env.REGISTRY_IMAGE }}@sha256:%s ' *)
+
+      - name: Inspect image
+        run: |
+          docker buildx imagetools inspect ${{ env.REGISTRY_IMAGE }}:${{ steps.meta.outputs.version }}
+
+      # trivy has their own rate limiting issues causing this action to flake
+      # we worked around it by hardcoding to different db repos in env
+      # can re-enable when they figure it out
@@ -56,6 +147,8 @@ jobs:
         env:
           TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
           TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
+          TRIVY_USERNAME: ${{ secrets.DOCKER_USERNAME }}
+          TRIVY_PASSWORD: ${{ secrets.DOCKER_TOKEN }}
         with:
           # To run locally: trivy image --severity HIGH,CRITICAL onyxdotapp/onyx-backend
           image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
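The backend build now pushes each per-platform image by digest only (`push-by-digest=true`), and the `merge` job stitches the digests into one tagged manifest list. A minimal sketch of the same flow, assuming a throwaway image name `example/app` and a made-up digest (the real workflow uses `$REGISTRY_IMAGE` and `steps.build.outputs.digest`):

```bash
# 1. Each matrix job records the digest of the untagged image it pushed.
digest="sha256:0123abcd"                  # placeholder digest for illustration
mkdir -p /tmp/digests
touch "/tmp/digests/${digest#sha256:}"    # file name = bare digest, no "sha256:" prefix

# 2. The merge job turns the collected digest files into one multi-arch tag.
cd /tmp/digests
docker buildx imagetools create -t example/app:v1.0.0 \
  $(printf 'example/app@sha256:%s ' *)

# 3. Both linux/amd64 and linux/arm64 should now appear under the tag.
docker buildx imagetools inspect example/app:v1.0.0
```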
@@ -4,12 +4,12 @@ name: Build and Push Cloud Web Image on Tag
 on:
   push:
     tags:
-      - "*"
+      - "*cloud*"

 env:
   REGISTRY_IMAGE: onyxdotapp/onyx-web-server-cloud
   LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
+  DEPLOYMENT: cloud

 jobs:
   build:
     runs-on:
@@ -38,9 +38,10 @@ jobs:
         uses: docker/metadata-action@v5
         with:
           images: ${{ env.REGISTRY_IMAGE }}
+          flavor: |
+            latest=false
           tags: |
-            type=raw,value=${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
-            type=raw,value=${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
+            type=raw,value=${{ github.ref_name }}

       - name: Set up Docker Buildx
         uses: docker/setup-buildx-action@v3
@@ -53,7 +54,7 @@ jobs:

       - name: Build and push by digest
         id: build
-        uses: docker/build-push-action@v5
+        uses: docker/build-push-action@v6
         with:
           context: ./web
           file: ./web/Dockerfile
@@ -70,10 +71,12 @@ jobs:
             NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
             NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
             NODE_OPTIONS=--max-old-space-size=8192
-          # needed due to weird interactions with the builds for different platforms
-          no-cache: true
           labels: ${{ steps.meta.outputs.labels }}
           outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
+          cache-from: type=s3,prefix=cache/${{ github.repository }}/${{ env.DEPLOYMENT }}/cloudweb-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
+          cache-to: type=s3,prefix=cache/${{ github.repository }}/${{ env.DEPLOYMENT }}/cloudweb-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
+          # no-cache needed due to weird interactions with the builds for different platforms
+          # NOTE(rkuo): this may not be true any more with the proper cache prefixing by architecture - currently testing with it off

       - name: Export digest
         run: |
@@ -84,7 +87,7 @@ jobs:
       - name: Upload digest
         uses: actions/upload-artifact@v4
         with:
-          name: digests-${{ env.PLATFORM_PAIR }}
+          name: cloudweb-digests-${{ env.PLATFORM_PAIR }}-${{ github.run_id }}
           path: /tmp/digests/*
           if-no-files-found: error
           retention-days: 1
@@ -98,7 +101,7 @@ jobs:
         uses: actions/download-artifact@v4
         with:
           path: /tmp/digests
-          pattern: digests-*
+          pattern: cloudweb-digests-*-${{ github.run_id }}
           merge-multiple: true

       - name: Set up Docker Buildx
@@ -109,6 +112,10 @@ jobs:
         uses: docker/metadata-action@v5
         with:
           images: ${{ env.REGISTRY_IMAGE }}
+          flavor: |
+            latest=false
+          tags: |
+            type=raw,value=${{ github.ref_name }}

       - name: Login to Docker Hub
         uses: docker/login-action@v3
@@ -136,6 +143,8 @@ jobs:
         env:
           TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
           TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
+          TRIVY_USERNAME: ${{ secrets.DOCKER_USERNAME }}
+          TRIVY_PASSWORD: ${{ secrets.DOCKER_TOKEN }}
         with:
           image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
           severity: "CRITICAL,HIGH"
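The S3 cache backend keys each platform's layer cache under its own prefix (`cloudweb-linux-amd64` versus `cloudweb-linux-arm64`), which is what makes dropping `no-cache: true` plausible, per the NOTE above. A sketch with a placeholder bucket, region, and prefix; in CI the real values come from the runs-on.com runner's `RUNS_ON_*` environment:

```bash
# mode=max exports intermediate layers too, trading S3 storage for much
# faster warm rebuilds. All values below are illustrative only.
docker buildx build \
  --cache-from type=s3,prefix=cache/example/cloud/cloudweb-linux-amd64/,region=us-east-1,bucket=example-cache-bucket \
  --cache-to type=s3,prefix=cache/example/cloud/cloudweb-linux-amd64/,region=us-east-1,bucket=example-cache-bucket,mode=max \
  ./web
```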
@@ -7,10 +7,13 @@ on:

 env:
   REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-model-server-cloud' || 'onyxdotapp/onyx-model-server' }}
-  LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
   DOCKER_BUILDKIT: 1
   BUILDKIT_PROGRESS: plain
+  DEPLOYMENT: ${{ contains(github.ref_name, 'cloud') && 'cloud' || 'standalone' }}
+
+  # don't tag cloud images with "latest"
+  LATEST_TAG: ${{ contains(github.ref_name, 'latest') && !contains(github.ref_name, 'cloud') }}

 jobs:
   # Bypassing this for now as the idea of not building is glitching
@@ -51,6 +54,8 @@ jobs:
     if: needs.check_model_server_changes.outputs.changed == 'true'
     runs-on:
       [runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}-amd64"]
+    env:
+      PLATFORM_PAIR: linux-amd64
     steps:
       - name: Checkout code
         uses: actions/checkout@v4
@@ -75,7 +80,7 @@ jobs:
           password: ${{ secrets.DOCKER_TOKEN }}

       - name: Build and Push AMD64
-        uses: docker/build-push-action@v5
+        uses: docker/build-push-action@v6
         with:
           context: ./backend
           file: ./backend/Dockerfile.model_server
@@ -86,12 +91,17 @@ jobs:
             DANSWER_VERSION=${{ github.ref_name }}
           outputs: type=registry
           provenance: false
+          cache-from: type=s3,prefix=cache/${{ github.repository }}/${{ env.DEPLOYMENT }}/model-server-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
+          cache-to: type=s3,prefix=cache/${{ github.repository }}/${{ env.DEPLOYMENT }}/model-server-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
+          # no-cache: true

   build-arm64:
     needs: [check_model_server_changes]
     if: needs.check_model_server_changes.outputs.changed == 'true'
     runs-on:
       [runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}-arm64"]
+    env:
+      PLATFORM_PAIR: linux-arm64
     steps:
       - name: Checkout code
         uses: actions/checkout@v4
@@ -116,7 +126,7 @@ jobs:
           password: ${{ secrets.DOCKER_TOKEN }}

       - name: Build and Push ARM64
-        uses: docker/build-push-action@v5
+        uses: docker/build-push-action@v6
         with:
           context: ./backend
           file: ./backend/Dockerfile.model_server
@@ -127,6 +137,8 @@ jobs:
             DANSWER_VERSION=${{ github.ref_name }}
           outputs: type=registry
           provenance: false
+          cache-from: type=s3,prefix=cache/${{ github.repository }}/${{ env.DEPLOYMENT }}/model-server-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
+          cache-to: type=s3,prefix=cache/${{ github.repository }}/${{ env.DEPLOYMENT }}/model-server-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max

   merge-and-scan:
     needs: [build-amd64, build-arm64, check_model_server_changes]
@@ -156,6 +168,8 @@ jobs:
         env:
           TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
           TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
+          TRIVY_USERNAME: ${{ secrets.DOCKER_USERNAME }}
+          TRIVY_PASSWORD: ${{ secrets.DOCKER_TOKEN }}
         with:
           image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
           severity: "CRITICAL,HIGH"
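The added `TRIVY_USERNAME`/`TRIVY_PASSWORD` let Trivy authenticate to Docker Hub, which together with the mirrored DB repos works around the rate-limit flakes described in the workflow comments. A local reproduction, following the workflow's own "To run locally" hint (the image tag here is just an example):

```bash
# Same scan the merge-and-scan job runs, pointed at a published image.
export TRIVY_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-db:2"
export TRIVY_JAVA_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-java-db:1"
trivy image --severity HIGH,CRITICAL docker.io/onyxdotapp/onyx-model-server:latest
```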
@@ -8,9 +8,25 @@ on:
 env:
   REGISTRY_IMAGE: onyxdotapp/onyx-web-server
   LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
+  DEPLOYMENT: standalone

 jobs:
+  precheck:
+    runs-on: [runs-on, runner=2cpu-linux-x64, "run-id=${{ github.run_id }}"]
+    outputs:
+      should-run: ${{ steps.set-output.outputs.should-run }}
+    steps:
+      - name: Check if tag contains "cloud"
+        id: set-output
+        run: |
+          if [[ "${{ github.ref_name }}" == *cloud* ]]; then
+            echo "should-run=false" >> "$GITHUB_OUTPUT"
+          else
+            echo "should-run=true" >> "$GITHUB_OUTPUT"
+          fi
   build:
+    needs: precheck
+    if: needs.precheck.outputs.should-run == 'true'
     runs-on:
       - runs-on
       - runner=${{ matrix.platform == 'linux/amd64' && '8cpu-linux-x64' || '8cpu-linux-arm64' }}
@@ -37,9 +53,11 @@ jobs:
         uses: docker/metadata-action@v5
         with:
           images: ${{ env.REGISTRY_IMAGE }}
+          flavor: |
+            latest=false
           tags: |
-            type=raw,value=${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
-            type=raw,value=${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
+            type=raw,value=${{ github.ref_name }}
+            type=raw,value=${{ env.LATEST_TAG == 'true' && 'latest' || '' }}

       - name: Set up Docker Buildx
         uses: docker/setup-buildx-action@v3
@@ -52,7 +70,7 @@ jobs:

       - name: Build and push by digest
         id: build
-        uses: docker/build-push-action@v5
+        uses: docker/build-push-action@v6
         with:
           context: ./web
           file: ./web/Dockerfile
@@ -62,11 +80,13 @@ jobs:
             ONYX_VERSION=${{ github.ref_name }}
             NODE_OPTIONS=--max-old-space-size=8192

-          # needed due to weird interactions with the builds for different platforms
-          no-cache: true
           labels: ${{ steps.meta.outputs.labels }}
           outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true

+          cache-from: type=s3,prefix=cache/${{ github.repository }}/${{ env.DEPLOYMENT }}/web-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
+          cache-to: type=s3,prefix=cache/${{ github.repository }}/${{ env.DEPLOYMENT }}/web-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
+          # no-cache needed due to weird interactions with the builds for different platforms
+          # NOTE(rkuo): this may not be true any more with the proper cache prefixing by architecture - currently testing with it off

       - name: Export digest
         run: |
           mkdir -p /tmp/digests
@@ -76,21 +96,22 @@ jobs:
       - name: Upload digest
         uses: actions/upload-artifact@v4
         with:
-          name: digests-${{ env.PLATFORM_PAIR }}
+          name: web-digests-${{ env.PLATFORM_PAIR }}-${{ github.run_id }}
           path: /tmp/digests/*
           if-no-files-found: error
           retention-days: 1

   merge:
-    runs-on: ubuntu-latest
     needs:
       - build
+    if: needs.precheck.outputs.should-run == 'true'
+    runs-on: ubuntu-latest
     steps:
       - name: Download digests
         uses: actions/download-artifact@v4
         with:
           path: /tmp/digests
-          pattern: digests-*
+          pattern: web-digests-*-${{ github.run_id }}
           merge-multiple: true

       - name: Set up Docker Buildx
@@ -101,6 +122,11 @@ jobs:
         uses: docker/metadata-action@v5
         with:
           images: ${{ env.REGISTRY_IMAGE }}
+          flavor: |
+            latest=false
+          tags: |
+            type=raw,value=${{ github.ref_name }}
+            type=raw,value=${{ env.LATEST_TAG == 'true' && 'latest' || '' }}

       - name: Login to Docker Hub
         uses: docker/login-action@v3
@@ -128,6 +154,8 @@ jobs:
         env:
           TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
           TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
+          TRIVY_USERNAME: ${{ secrets.DOCKER_USERNAME }}
+          TRIVY_PASSWORD: ${{ secrets.DOCKER_TOKEN }}
         with:
           image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
           severity: "CRITICAL,HIGH"
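The new `precheck` job keeps the standalone web image from being rebuilt for cloud tags, now that the cloud workflow above owns `*cloud*` tags exclusively. The gate is just a glob match, shown here standalone with illustrative ref names:

```bash
# Same test the precheck step runs, looped over example tags.
for ref in "v1.2.3" "v1.2.3-cloud.7"; do
  if [[ "$ref" == *cloud* ]]; then
    echo "$ref -> should-run=false (cloud tags are skipped here)"
  else
    echo "$ref -> should-run=true (standalone image gets built)"
  fi
done
```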
.github/workflows/pr-helm-chart-testing.yml (vendored, 5 changes)

@@ -37,6 +37,11 @@ jobs:
             echo "changed=true" >> "$GITHUB_OUTPUT"
           fi

+      # uncomment to force run chart-testing
+      # - name: Force run chart-testing (list-changed)
+      #   id: list-changed
+      #   run: echo "changed=true" >> $GITHUB_OUTPUT
+
       # lint all charts if any changes were detected
       - name: Run chart-testing (lint)
         if: steps.list-changed.outputs.changed == 'true'
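For reference, the chart-testing steps can be approximated locally with the `ct` CLI; the target branch and the lint-everything-on-any-change behavior mirror the workflow, but treat the exact flags as a sketch rather than the project's documented procedure:

```bash
# List charts changed relative to main, then lint all charts if any changed.
ct list-changed --target-branch main
ct lint --all
```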
.github/workflows/pr-integration-tests.yml (vendored, 59 changes)

@@ -16,15 +16,55 @@ env:
   CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
   CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
   CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
+  PLATFORM_PAIR: linux-amd64

 jobs:
   integration-tests:
     # See https://runs-on.com/runners/linux/
-    runs-on: [runs-on, runner=32cpu-linux-x64, "run-id=${{ github.run_id }}"]
+    runs-on:
+      [
+        runs-on,
+        runner=32cpu-linux-x64,
+        disk=large,
+        "run-id=${{ github.run_id }}",
+      ]
     steps:
       - name: Checkout code
         uses: actions/checkout@v4

+      - name: Setup Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: "3.11"
+          cache: "pip"
+          cache-dependency-path: |
+            backend/requirements/default.txt
+            backend/requirements/dev.txt
+            backend/requirements/ee.txt
+      - run: |
+          python -m pip install --upgrade pip
+          pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
+          pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
+          pip install --retries 5 --timeout 30 -r backend/requirements/ee.txt
+
+      - name: Generate OpenAPI schema
+        working-directory: ./backend
+        env:
+          PYTHONPATH: "."
+        run: |
+          python scripts/onyx_openapi_schema.py --filename generated/openapi.json
+
+      - name: Generate OpenAPI Python client
+        working-directory: ./backend
+        run: |
+          docker run --rm \
+            -v "${{ github.workspace }}/backend/generated:/local" \
+            openapitools/openapi-generator-cli generate \
+            -i /local/openapi.json \
+            -g python \
+            -o /local/onyx_openapi_client \
+            --package-name onyx_openapi_client
+
       - name: Set up Docker Buildx
         uses: docker/setup-buildx-action@v3

@@ -61,8 +101,8 @@ jobs:
           tags: onyxdotapp/onyx-backend:test
           push: false
           load: true
-          cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
-          cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
+          cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
+          cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max

       - name: Build Model Server Docker image
         uses: ./.github/actions/custom-build-and-push
@@ -73,8 +113,8 @@ jobs:
           tags: onyxdotapp/onyx-model-server:test
           push: false
           load: true
-          cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
-          cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
+          cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
+          cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max

       - name: Build integration test Docker image
         uses: ./.github/actions/custom-build-and-push
@@ -85,8 +125,8 @@ jobs:
           tags: onyxdotapp/onyx-integration:test
           push: false
           load: true
-          cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
-          cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
+          cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
+          cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max

       # Start containers for multi-tenant tests
       - name: Start Docker containers for multi-tenant tests
@@ -113,6 +153,8 @@ jobs:
           -e POSTGRES_HOST=relational_db \
           -e POSTGRES_USER=postgres \
           -e POSTGRES_PASSWORD=password \
+          -e DB_READONLY_USER=db_readonly_user \
+          -e DB_READONLY_PASSWORD=password \
           -e POSTGRES_DB=postgres \
           -e POSTGRES_USE_NULL_POOL=true \
           -e VESPA_HOST=index \
@@ -158,6 +200,7 @@ jobs:
           DISABLE_TELEMETRY=true \
           IMAGE_TAG=test \
           INTEGRATION_TESTS_MODE=true \
+          CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS=0.001 \
           docker compose -f docker-compose.dev.yml -p onyx-stack up -d
         id: start_docker
@@ -210,6 +253,8 @@ jobs:
           -e POSTGRES_HOST=relational_db \
           -e POSTGRES_USER=postgres \
           -e POSTGRES_PASSWORD=password \
+          -e DB_READONLY_USER=db_readonly_user \
+          -e DB_READONLY_PASSWORD=password \
           -e POSTGRES_DB=postgres \
           -e POSTGRES_POOL_PRE_PING=true \
           -e POSTGRES_USE_NULL_POOL=true \
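The OpenAPI schema and Python client generation can be run locally exactly as the workflow does; the only assumption below is that you start from the repo root:

```bash
cd backend
# Dump the FastAPI schema to JSON ...
PYTHONPATH=. python scripts/onyx_openapi_schema.py --filename generated/openapi.json
# ... then generate a Python client from it inside the generator container.
docker run --rm \
  -v "$PWD/generated:/local" \
  openapitools/openapi-generator-cli generate \
  -i /local/openapi.json \
  -g python \
  -o /local/onyx_openapi_client \
  --package-name onyx_openapi_client
```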
.github/workflows/pr-mit-integration-tests.yml (vendored, 55 changes)

@@ -16,15 +16,52 @@ env:
   CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
   CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
   CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
+
+  PLATFORM_PAIR: linux-amd64
 jobs:
   integration-tests-mit:
     # See https://runs-on.com/runners/linux/
-    runs-on: [runs-on, runner=32cpu-linux-x64, "run-id=${{ github.run_id }}"]
+    runs-on:
+      [
+        runs-on,
+        runner=32cpu-linux-x64,
+        disk=large,
+        "run-id=${{ github.run_id }}",
+      ]
     steps:
       - name: Checkout code
         uses: actions/checkout@v4

+      - name: Setup Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: "3.11"
+          cache: "pip"
+          cache-dependency-path: |
+            backend/requirements/default.txt
+            backend/requirements/dev.txt
+      - run: |
+          python -m pip install --upgrade pip
+          pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
+          pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
+
+      - name: Generate OpenAPI schema
+        working-directory: ./backend
+        env:
+          PYTHONPATH: "."
+        run: |
+          python scripts/onyx_openapi_schema.py --filename generated/openapi.json
+
+      - name: Generate OpenAPI Python client
+        working-directory: ./backend
+        run: |
+          docker run --rm \
+            -v "${{ github.workspace }}/backend/generated:/local" \
+            openapitools/openapi-generator-cli generate \
+            -i /local/openapi.json \
+            -g python \
+            -o /local/onyx_openapi_client \
+            --package-name onyx_openapi_client
+
       - name: Set up Docker Buildx
         uses: docker/setup-buildx-action@v3

@@ -61,8 +98,8 @@ jobs:
           tags: onyxdotapp/onyx-backend:test
           push: false
           load: true
-          cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
-          cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
+          cache-from: type=s3,prefix=cache/${{ github.repository }}/mit-integration-tests/backend-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
+          cache-to: type=s3,prefix=cache/${{ github.repository }}/mit-integration-tests/backend-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max

       - name: Build Model Server Docker image
         uses: ./.github/actions/custom-build-and-push
@@ -73,8 +110,8 @@ jobs:
           tags: onyxdotapp/onyx-model-server:test
           push: false
           load: true
-          cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
-          cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
+          cache-from: type=s3,prefix=cache/${{ github.repository }}/mit-integration-tests/model-server-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
+          cache-to: type=s3,prefix=cache/${{ github.repository }}/mit-integration-tests/model-server-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max

       - name: Build integration test Docker image
         uses: ./.github/actions/custom-build-and-push
@@ -85,8 +122,8 @@ jobs:
           tags: onyxdotapp/onyx-integration:test
           push: false
           load: true
-          cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
-          cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
+          cache-from: type=s3,prefix=cache/${{ github.repository }}/mit-integration-tests/integration-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
+          cache-to: type=s3,prefix=cache/${{ github.repository }}/mit-integration-tests/integration-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max

       # NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
       - name: Start Docker containers
@@ -152,6 +189,8 @@ jobs:
           -e POSTGRES_USER=postgres \
           -e POSTGRES_PASSWORD=password \
           -e POSTGRES_DB=postgres \
+          -e DB_READONLY_USER=db_readonly_user \
+          -e DB_READONLY_PASSWORD=password \
           -e POSTGRES_POOL_PRE_PING=true \
           -e POSTGRES_USE_NULL_POOL=true \
           -e VESPA_HOST=index \
.github/workflows/pr-playwright-tests.yml (vendored, 1 change)

@@ -10,6 +10,7 @@ env:
   SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
   GEN_AI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
   MOCK_LLM_RESPONSE: true
+  PYTEST_PLAYWRIGHT_SKIP_INITIAL_RESET: true

 jobs:
   playwright-tests:
.github/workflows/pr-python-checks.yml (vendored, 23 changes)

@@ -31,16 +31,29 @@ jobs:
           pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
           pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt

+      - name: Generate OpenAPI schema
+        working-directory: ./backend
+        env:
+          PYTHONPATH: "."
+        run: |
+          python scripts/onyx_openapi_schema.py --filename generated/openapi.json
+
+      - name: Generate OpenAPI Python client
+        working-directory: ./backend
+        run: |
+          docker run --rm \
+            -v "${{ github.workspace }}/backend/generated:/local" \
+            openapitools/openapi-generator-cli generate \
+            -i /local/openapi.json \
+            -g python \
+            -o /local/onyx_openapi_client \
+            --package-name onyx_openapi_client
+
       - name: Run MyPy
         run: |
           cd backend
           mypy .

       - name: Run ruff
         run: |
           cd backend
           ruff .

       - name: Check import order with reorder-python-imports
         run: |
           cd backend
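The same checks run locally from the repo root. One caveat worth knowing: `ruff .` is the legacy spelling, and current ruff releases expect `ruff check .` instead:

```bash
cd backend
mypy .
# Older ruff versions accepted plain "ruff ." as in the workflow above.
ruff check .
```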
.github/workflows/pr-python-connector-tests.yml (vendored, 34 changes)

@@ -12,7 +12,7 @@ env:
   # AWS
   AWS_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS: ${{ secrets.AWS_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS }}
   AWS_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS: ${{ secrets.AWS_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS }}

   # Confluence
   CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
   CONFLUENCE_TEST_SPACE: ${{ secrets.CONFLUENCE_TEST_SPACE }}
@@ -20,10 +20,12 @@ env:
   CONFLUENCE_TEST_PAGE_ID: ${{ secrets.CONFLUENCE_TEST_PAGE_ID }}
   CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
   CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}

+  # Jira
+  JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
+  JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
+
   # Gong
   GONG_ACCESS_KEY: ${{ secrets.GONG_ACCESS_KEY }}
   GONG_ACCESS_KEY_SECRET: ${{ secrets.GONG_ACCESS_KEY_SECRET }}

@@ -33,37 +35,57 @@ env:
   GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR }}
   GOOGLE_GMAIL_SERVICE_ACCOUNT_JSON_STR: ${{ secrets.GOOGLE_GMAIL_SERVICE_ACCOUNT_JSON_STR }}
   GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR }}

   # Slab
   SLAB_BOT_TOKEN: ${{ secrets.SLAB_BOT_TOKEN }}

   # Zendesk
   ZENDESK_SUBDOMAIN: ${{ secrets.ZENDESK_SUBDOMAIN }}
   ZENDESK_EMAIL: ${{ secrets.ZENDESK_EMAIL }}
   ZENDESK_TOKEN: ${{ secrets.ZENDESK_TOKEN }}

   # Salesforce
   SF_USERNAME: ${{ secrets.SF_USERNAME }}
   SF_PASSWORD: ${{ secrets.SF_PASSWORD }}
   SF_SECURITY_TOKEN: ${{ secrets.SF_SECURITY_TOKEN }}

   # Airtable
   AIRTABLE_TEST_BASE_ID: ${{ secrets.AIRTABLE_TEST_BASE_ID }}
   AIRTABLE_TEST_TABLE_ID: ${{ secrets.AIRTABLE_TEST_TABLE_ID }}
   AIRTABLE_TEST_TABLE_NAME: ${{ secrets.AIRTABLE_TEST_TABLE_NAME }}
   AIRTABLE_ACCESS_TOKEN: ${{ secrets.AIRTABLE_ACCESS_TOKEN }}

   # Sharepoint
   SHAREPOINT_CLIENT_ID: ${{ secrets.SHAREPOINT_CLIENT_ID }}
   SHAREPOINT_CLIENT_SECRET: ${{ secrets.SHAREPOINT_CLIENT_SECRET }}
   SHAREPOINT_CLIENT_DIRECTORY_ID: ${{ secrets.SHAREPOINT_CLIENT_DIRECTORY_ID }}
   SHAREPOINT_SITE: ${{ secrets.SHAREPOINT_SITE }}

   # Github
   ACCESS_TOKEN_GITHUB: ${{ secrets.ACCESS_TOKEN_GITHUB }}

   # Gitlab
   GITLAB_ACCESS_TOKEN: ${{ secrets.GITLAB_ACCESS_TOKEN }}

+  # Gitbook
+  GITBOOK_SPACE_ID: ${{ secrets.GITBOOK_SPACE_ID }}
+  GITBOOK_API_KEY: ${{ secrets.GITBOOK_API_KEY }}
+
+  # Notion
+  NOTION_INTEGRATION_TOKEN: ${{ secrets.NOTION_INTEGRATION_TOKEN }}
+
+  # Highspot
+  HIGHSPOT_KEY: ${{ secrets.HIGHSPOT_KEY }}
+  HIGHSPOT_SECRET: ${{ secrets.HIGHSPOT_SECRET }}
+
+  # Slack
+  SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
+
+  # Teams
+  TEAMS_APPLICATION_ID: ${{ secrets.TEAMS_APPLICATION_ID }}
+  TEAMS_DIRECTORY_ID: ${{ secrets.TEAMS_DIRECTORY_ID }}
+  TEAMS_SECRET: ${{ secrets.TEAMS_SECRET }}

 jobs:
   connectors-check:
     # See https://runs-on.com/runners/linux/
@@ -95,7 +117,15 @@

       - name: Run Tests
         shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
-        run: py.test -o junit_family=xunit2 -xv --ff backend/tests/daily/connectors
+        run: |
+          py.test \
+            -n 8 \
+            --dist loadfile \
+            --durations=8 \
+            -o junit_family=xunit2 \
+            -xv \
+            --ff \
+            backend/tests/daily/connectors

       - name: Alert on Failure
         if: failure() && github.event_name == 'schedule'
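The reworked test step leans on pytest-xdist: `-n 8` fans the connector tests out over eight workers, `--dist loadfile` pins every test in a file to the same worker (so per-file fixtures and connector state are never shared across processes), and `--durations=8` prints the slowest tests. The command runs as-is from the repo root, assuming the connector secrets above are set in your environment:

```bash
# Same invocation as the workflow's Run Tests step.
py.test \
  -n 8 \
  --dist loadfile \
  --durations=8 \
  -o junit_family=xunit2 \
  -xv \
  --ff \
  backend/tests/daily/connectors
```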
.github/workflows/pr-python-tests.yml (vendored, 3 changes)

@@ -15,6 +15,9 @@ jobs:
     env:
       PYTHONPATH: ./backend
       REDIS_CLOUD_PYTEST_PASSWORD: ${{ secrets.REDIS_CLOUD_PYTEST_PASSWORD }}
+      SF_USERNAME: ${{ secrets.SF_USERNAME }}
+      SF_PASSWORD: ${{ secrets.SF_PASSWORD }}
+      SF_SECURITY_TOKEN: ${{ secrets.SF_SECURITY_TOKEN }}

     steps:
       - name: Checkout code
.gitignore (vendored, 3 changes)

@@ -14,6 +14,9 @@
 /web/test-results/
 backend/onyx/agent_search/main/test_data.json
 backend/tests/regression/answer_quality/test_data.json
+backend/tests/regression/search_quality/eval-*
+backend/tests/regression/search_quality/search_eval_config.yaml
+backend/tests/regression/search_quality/*.json

 # secret files
 .env
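A quick way to confirm the new ignore rules match as intended; the file names below are invented for the check:

```bash
# -v prints which .gitignore line matched each path.
git check-ignore -v backend/tests/regression/search_quality/eval-2025-06-01
git check-ignore -v backend/tests/regression/search_quality/results.json
git check-ignore -v backend/tests/regression/search_quality/search_eval_config.yaml
```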
.vscode/launch.template.jsonc (vendored, 40 changes)

@@ -412,6 +412,46 @@
         "group": "3"
       }
     },
+    {
+      // script to generate the openapi schema
+      "name": "Onyx OpenAPI Schema Generator",
+      "type": "debugpy",
+      "request": "launch",
+      "program": "scripts/onyx_openapi_schema.py",
+      "cwd": "${workspaceFolder}/backend",
+      "envFile": "${workspaceFolder}/.env",
+      "env": {
+        "PYTHONUNBUFFERED": "1",
+        "PYTHONPATH": "."
+      },
+      "args": [
+        "--filename",
+        "generated/openapi.json",
+      ]
+    },
+    {
+      // script to debug multi tenant db issues
+      "name": "Onyx DB Manager (Top Chunks)",
+      "type": "debugpy",
+      "request": "launch",
+      "program": "scripts/debugging/onyx_db.py",
+      "cwd": "${workspaceFolder}/backend",
+      "envFile": "${workspaceFolder}/.env",
+      "env": {
+        "PYTHONUNBUFFERED": "1",
+        "PYTHONPATH": "."
+      },
+      "args": [
+        "--password",
+        "your_password_here",
+        "--port",
+        "5433",
+        "--report",
+        "top-chunks",
+        "--filename",
+        "generated/tenants_by_num_docs.csv"
+      ]
+    },
     {
       "name": "Debug React Web App in Chrome",
       "type": "chrome",
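Outside VS Code, the two launch configs above reduce to plain CLI invocations from `backend/`; the password and port values simply mirror the template's placeholders:

```bash
cd backend
PYTHONPATH=. python scripts/onyx_openapi_schema.py --filename generated/openapi.json
PYTHONPATH=. python scripts/debugging/onyx_db.py \
  --password your_password_here \
  --port 5433 \
  --report top-chunks \
  --filename generated/tenants_by_num_docs.csv
```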
.vscode/tasks.template.jsonc (vendored, new file, 101 lines)

@@ -0,0 +1,101 @@
{
  "version": "2.0.0",
  "tasks": [
    {
      "type": "austin",
      "label": "Profile celery beat",
      "envFile": "${workspaceFolder}/.env",
      "options": {
        "cwd": "${workspaceFolder}/backend"
      },
      "command": [
        "sudo",
        "-E"
      ],
      "args": [
        "celery",
        "-A",
        "onyx.background.celery.versioned_apps.beat",
        "beat",
        "--loglevel=INFO"
      ]
    },
    {
      "type": "shell",
      "label": "Generate Onyx OpenAPI Python client",
      "cwd": "${workspaceFolder}/backend",
      "envFile": "${workspaceFolder}/.env",
      "options": {
        "cwd": "${workspaceFolder}/backend"
      },
      "command": [
        "openapi-generator"
      ],
      "args": [
        "generate",
        "-i",
        "generated/openapi.json",
        "-g",
        "python",
        "-o",
        "generated/onyx_openapi_client",
        "--package-name",
        "onyx_openapi_client",
      ]
    },
    {
      "type": "shell",
      "label": "Generate Typescript Fetch client (openapi-generator)",
      "envFile": "${workspaceFolder}/.env",
      "options": {
        "cwd": "${workspaceFolder}"
      },
      "command": [
        "openapi-generator"
      ],
      "args": [
        "generate",
        "-i",
        "backend/generated/openapi.json",
        "-g",
        "typescript-fetch",
        "-o",
        "${workspaceFolder}/web/src/lib/generated/onyx_api",
        "--additional-properties=disallowAdditionalPropertiesIfNotPresent=false,legacyDiscriminatorBehavior=false,supportsES6=true",
      ]
    },
    {
      "type": "shell",
      "label": "Generate TypeScript Client (openapi-ts)",
      "envFile": "${workspaceFolder}/.env",
      "options": {
        "cwd": "${workspaceFolder}/web"
      },
      "command": [
        "npx"
      ],
      "args": [
        "openapi-typescript",
        "../backend/generated/openapi.json",
        "--output",
        "./src/lib/generated/onyx-schema.ts",
      ]
    },
    {
      "type": "shell",
      "label": "Generate TypeScript Client (orval)",
      "envFile": "${workspaceFolder}/.env",
      "options": {
        "cwd": "${workspaceFolder}/web"
      },
      "command": [
        "npx"
      ],
      "args": [
        "orval",
        "--config",
        "orval.config.js",
      ]
    }
  ]
}
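The two TypeScript-client tasks in the new template correspond to the following commands from `web/`, copied directly from the task definitions:

```bash
cd web
# Schema-only types via openapi-typescript ...
npx openapi-typescript ../backend/generated/openapi.json \
  --output ./src/lib/generated/onyx-schema.ts
# ... or a full generated client via orval, driven by orval.config.js.
npx orval --config orval.config.js
```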
CONTRIBUTING.md

@@ -1,4 +1,4 @@
-<!-- DANSWER_METADATA={"link": "https://github.com/onyx-dot-app/onyx/blob/main/CONTRIBUTING.md"} -->
+<!-- ONYX_METADATA={"link": "https://github.com/onyx-dot-app/onyx/blob/main/CONTRIBUTING.md"} -->

 # Contributing to Onyx

@@ -12,8 +12,8 @@ As an open source project in a rapidly changing space, we welcome all contributi

 The [GitHub Issues](https://github.com/onyx-dot-app/onyx/issues) page is a great place to start for contribution ideas.

-To ensure that your contribution is aligned with the project's direction, please reach out to Hagen (or any other maintainer) on the Onyx team
-via [Slack](https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA) /
+To ensure that your contribution is aligned with the project's direction, please reach out to any maintainer on the Onyx team
+via [Slack](https://join.slack.com/t/onyx-dot-app/shared_invite/zt-34lu4m7xg-TsKGO6h8PDvR5W27zTdyhA) /
 [Discord](https://discord.gg/TDJ59cGV2X) or [email](mailto:founders@onyx.app).

 Issues that have been explicitly approved by the maintainers (aligned with the direction of the project)

@@ -28,7 +28,7 @@ Your input is vital to making sure that Onyx moves in the right direction.
 Before starting on implementation, please raise a GitHub issue.

 Also, always feel free to message the founders (Chris Weaver / Yuhong Sun) on
-[Slack](https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA) /
+[Slack](https://join.slack.com/t/onyx-dot-app/shared_invite/zt-34lu4m7xg-TsKGO6h8PDvR5W27zTdyhA) /
 [Discord](https://discord.gg/TDJ59cGV2X) directly about anything at all.

 ### Contributing Code
README.md

@@ -1,4 +1,4 @@
-<!-- DANSWER_METADATA={"link": "https://github.com/onyx-dot-app/onyx/blob/main/README.md"} -->
+<!-- ONYX_METADATA={"link": "https://github.com/onyx-dot-app/onyx/blob/main/README.md"} -->

 <a name="readme-top"></a>

@@ -13,7 +13,7 @@
   <a href="https://docs.onyx.app/" target="_blank">
     <img src="https://img.shields.io/badge/docs-view-blue" alt="Documentation">
   </a>
-  <a href="https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA" target="_blank">
+  <a href="https://join.slack.com/t/onyx-dot-app/shared_invite/zt-34lu4m7xg-TsKGO6h8PDvR5W27zTdyhA" target="_blank">
     <img src="https://img.shields.io/badge/slack-join-blue.svg?logo=slack" alt="Slack">
   </a>
   <a href="https://discord.gg/TDJ59cGV2X" target="_blank">
backend/.gitignore (vendored, 4 changes)

@@ -9,4 +9,6 @@ api_keys.py
 vespa-app.zip
 dynamic_config_storage/
 celerybeat-schedule*
-onyx/connectors/salesforce/data/
+onyx/connectors/salesforce/data/
+.test.env
+/generated
@@ -37,7 +37,8 @@ RUN apt-get update && \
         pkg-config \
         gcc \
         nano \
-        vim && \
+        vim \
+        postgresql-client && \
     rm -rf /var/lib/apt/lists/* && \
     apt-get clean

@@ -85,7 +86,7 @@ Tokenizer.from_pretrained('nomic-ai/nomic-embed-text-v1')"
 # Pre-downloading NLTK for setups with limited egress
 RUN python -c "import nltk; \
     nltk.download('stopwords', quiet=True); \
-    nltk.download('punkt', quiet=True);"
+    nltk.download('punkt_tab', quiet=True);"
 # nltk.download('wordnet', quiet=True); introduce this back if lemmatization is needed

 # Set up application files
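NLTK split the punkt model into `punkt_tab` in its newer releases, hence the swap above. A quick sanity check that the pre-downloaded data resolves inside the built image, assuming a recent NLTK version:

```bash
# Downloads (if needed) and then verifies the tokenizer data is findable.
python -c "import nltk; nltk.download('punkt_tab', quiet=True); \
nltk.data.find('tokenizers/punkt_tab'); print('punkt_tab OK')"
```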
backend/alembic/README.md

@@ -1,4 +1,4 @@
-<!-- DANSWER_METADATA={"link": "https://github.com/onyx-dot-app/onyx/blob/main/backend/alembic/README.md"} -->
+<!-- ONYX_METADATA={"link": "https://github.com/onyx-dot-app/onyx/blob/main/backend/alembic/README.md"} -->

 # Alembic DB Migrations
backend/alembic/env.py

@@ -24,6 +24,7 @@ from onyx.configs.constants import SSL_CERT_FILE
 from shared_configs.configs import MULTI_TENANT, POSTGRES_DEFAULT_SCHEMA
 from onyx.db.models import Base
 from celery.backends.database.session import ResultModelBase  # type: ignore
+from onyx.db.engine import SqlEngine

 # Make sure in alembic.ini [logger_root] level=INFO is set or most logging will be
 # hidden! (defaults to level=WARN)

@@ -147,6 +148,9 @@ async def run_async_migrations() -> None:
         continue_on_error,
     ) = get_schema_options()

+    # without init_engine, subsequent engine calls fail hard intentionally
+    SqlEngine.init_engine(pool_size=20, max_overflow=5)
+
     engine = create_async_engine(
         build_connection_string(),
         poolclass=pool.NullPool,

@@ -180,10 +184,10 @@ async def run_async_migrations() -> None:
             except Exception as e:
                 logger.error(f"Error migrating schema {schema}: {e}")
                 if not continue_on_error:
-                    logger.error("--continue is not set, raising exception!")
+                    logger.error("--continue=true is not set, raising exception!")
                     raise

-                logger.warning("--continue is set, continuing to next schema.")
+                logger.warning("--continue=true is set, continuing to next schema.")

     else:
         try:

@@ -202,10 +206,21 @@ async def run_async_migrations() -> None:


 def run_migrations_offline() -> None:
-    """This doesn't really get used when we migrate in the cloud."""
+    """
+    NOTE(rkuo): This generates a sql script that can be used to migrate the database ...
+    instead of migrating the db live via an open connection
+
+    Not clear on when this would be used by us or if it even works.
+
+    If it is offline, then why are there calls to the db engine?
+
+    This doesn't really get used when we migrate in the cloud."""
+
+    logger.info("run_migrations_offline starting.")
+
+    # without init_engine, subsequent engine calls fail hard intentionally
+    SqlEngine.init_engine(pool_size=20, max_overflow=5)

     schema_name, _, upgrade_all_tenants, continue_on_error = get_schema_options()
     url = build_connection_string()
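Offline mode, per the expanded docstring, emits a SQL script instead of applying migrations over a live connection. The standard invocation is the one below, run from `backend/`; whether it fully works in this setup is exactly the open question the docstring raises:

```bash
cd backend
# --sql makes alembic print the migration DDL instead of executing it.
alembic upgrade head --sql > migrate.sql
```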
backend/alembic/versions/03bf8be6b53a_rework_kg_config.py (new file, 121 lines)

@@ -0,0 +1,121 @@
"""rework-kg-config

Revision ID: 03bf8be6b53a
Revises: 65bc6e0f8500
Create Date: 2025-06-16 10:52:34.815335

"""

import json

from datetime import datetime
from datetime import timedelta
from sqlalchemy.dialects import postgresql
from sqlalchemy import text
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "03bf8be6b53a"
down_revision = "65bc6e0f8500"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # get current config
    current_configs = (
        op.get_bind()
        .execute(text("SELECT kg_variable_name, kg_variable_values FROM kg_config"))
        .all()
    )
    current_config_dict = {
        config.kg_variable_name: (
            config.kg_variable_values[0]
            if config.kg_variable_name
            not in ("KG_VENDOR_DOMAINS", "KG_IGNORE_EMAIL_DOMAINS")
            else config.kg_variable_values
        )
        for config in current_configs
        if config.kg_variable_values
    }

    # not using the KGConfigSettings model here in case it changes in the future
    kg_config_settings = json.dumps(
        {
            "KG_EXPOSED": current_config_dict.get("KG_EXPOSED", False),
            "KG_ENABLED": current_config_dict.get("KG_ENABLED", False),
            "KG_VENDOR": current_config_dict.get("KG_VENDOR", None),
            "KG_VENDOR_DOMAINS": current_config_dict.get("KG_VENDOR_DOMAINS", []),
            "KG_IGNORE_EMAIL_DOMAINS": current_config_dict.get(
                "KG_IGNORE_EMAIL_DOMAINS", []
            ),
            "KG_COVERAGE_START": current_config_dict.get(
                "KG_COVERAGE_START",
                (datetime.now() - timedelta(days=90)).strftime("%Y-%m-%d"),
            ),
            "KG_MAX_COVERAGE_DAYS": current_config_dict.get("KG_MAX_COVERAGE_DAYS", 90),
            "KG_MAX_PARENT_RECURSION_DEPTH": current_config_dict.get(
                "KG_MAX_PARENT_RECURSION_DEPTH", 2
            ),
            "KG_BETA_PERSONA_ID": current_config_dict.get("KG_BETA_PERSONA_ID", None),
        }
    )
    op.execute(
        f"INSERT INTO key_value_store (key, value) VALUES ('kg_config', '{kg_config_settings}')"
    )

    # drop kg config table
    op.drop_table("kg_config")


def downgrade() -> None:
    # get current config
    current_config_dict = {
        "KG_EXPOSED": False,
        "KG_ENABLED": False,
        "KG_VENDOR": [],
        "KG_VENDOR_DOMAINS": [],
        "KG_IGNORE_EMAIL_DOMAINS": [],
        "KG_COVERAGE_START": (datetime.now() - timedelta(days=90)).strftime("%Y-%m-%d"),
        "KG_MAX_COVERAGE_DAYS": 90,
        "KG_MAX_PARENT_RECURSION_DEPTH": 2,
    }
    current_configs = (
        op.get_bind()
        .execute(text("SELECT value FROM key_value_store WHERE key = 'kg_config'"))
        .one_or_none()
    )
    if current_configs is not None:
        current_config_dict.update(current_configs[0])
    insert_values = [
        {
            "kg_variable_name": name,
            "kg_variable_values": (
                [str(val).lower() if isinstance(val, bool) else str(val)]
                if not isinstance(val, list)
                else val
            ),
        }
        for name, val in current_config_dict.items()
    ]

    op.create_table(
        "kg_config",
        sa.Column("id", sa.Integer(), primary_key=True, nullable=False, index=True),
        sa.Column("kg_variable_name", sa.String(), nullable=False, index=True),
        sa.Column("kg_variable_values", postgresql.ARRAY(sa.String()), nullable=False),
        sa.UniqueConstraint("kg_variable_name", name="uq_kg_config_variable_name"),
    )
    op.bulk_insert(
        sa.table(
            "kg_config",
            sa.column("kg_variable_name", sa.String),
            sa.column("kg_variable_values", postgresql.ARRAY(sa.String)),
        ),
        insert_values,
    )

    op.execute("DELETE FROM key_value_store WHERE key = 'kg_config'")
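After upgrading, the per-variable `kg_config` rows should be gone and the merged settings should appear as a single JSON row in `key_value_store`. A quick check against a local dev database; the connection details are placeholders:

```bash
psql -h localhost -p 5432 -U postgres -d postgres \
  -c "SELECT value FROM key_value_store WHERE key = 'kg_config';"
```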
@@ -0,0 +1,45 @@
"""Add foreign key to user__external_user_group_id

Revision ID: 238b84885828
Revises: a7688ab35c45
Create Date: 2025-05-19 17:15:33.424584

"""

from alembic import op


# revision identifiers, used by Alembic.
revision = "238b84885828"
down_revision = "a7688ab35c45"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # First, clean up any entries that don't have a valid cc_pair_id
    op.execute(
        """
        DELETE FROM user__external_user_group_id
        WHERE cc_pair_id NOT IN (SELECT id FROM connector_credential_pair)
        """
    )

    # Add foreign key constraint with cascade delete
    op.create_foreign_key(
        "fk_user__external_user_group_id_cc_pair_id",
        "user__external_user_group_id",
        "connector_credential_pair",
        ["cc_pair_id"],
        ["id"],
        ondelete="CASCADE",
    )


def downgrade() -> None:
    # Drop the foreign key constraint
    op.drop_constraint(
        "fk_user__external_user_group_id_cc_pair_id",
        "user__external_user_group_id",
        type_="foreignkey",
    )
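With ondelete="CASCADE" in place, Postgres removes the dependent user__external_user_group_id rows whenever their connector_credential_pair row is deleted, so callers no longer need a separate cleanup pass. A minimal sketch of the effect (the helper below is illustrative, not repo code):

from sqlalchemy import text
from sqlalchemy.engine import Connection


def delete_cc_pair(conn: Connection, cc_pair_id: int) -> None:
    # The FK created above makes Postgres delete the matching
    # user__external_user_group_id rows as part of the same statement.
    conn.execute(
        text("DELETE FROM connector_credential_pair WHERE id = :id"),
        {"id": cc_pair_id},
    )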
@@ -0,0 +1,150 @@
"""Fix invalid model-configurations state

Revision ID: 47a07e1a38f1
Revises: 7a70b7664e37
Create Date: 2025-04-23 15:39:43.159504

"""

from alembic import op
from pydantic import BaseModel, ConfigDict
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

from onyx.llm.llm_provider_options import (
    fetch_model_names_for_provider_as_set,
    fetch_visible_model_names_for_provider_as_set,
)


# revision identifiers, used by Alembic.
revision = "47a07e1a38f1"
down_revision = "7a70b7664e37"
branch_labels = None
depends_on = None


class _SimpleModelConfiguration(BaseModel):
    # Configure model to read from attributes
    model_config = ConfigDict(from_attributes=True)

    id: int
    llm_provider_id: int
    name: str
    is_visible: bool
    max_input_tokens: int | None


def upgrade() -> None:
    llm_provider_table = sa.sql.table(
        "llm_provider",
        sa.column("id", sa.Integer),
        sa.column("provider", sa.String),
        sa.column("model_names", postgresql.ARRAY(sa.String)),
        sa.column("display_model_names", postgresql.ARRAY(sa.String)),
        sa.column("default_model_name", sa.String),
        sa.column("fast_default_model_name", sa.String),
    )
    model_configuration_table = sa.sql.table(
        "model_configuration",
        sa.column("id", sa.Integer),
        sa.column("llm_provider_id", sa.Integer),
        sa.column("name", sa.String),
        sa.column("is_visible", sa.Boolean),
        sa.column("max_input_tokens", sa.Integer),
    )

    connection = op.get_bind()

    llm_providers = connection.execute(
        sa.select(
            llm_provider_table.c.id,
            llm_provider_table.c.provider,
        )
    ).fetchall()

    for llm_provider in llm_providers:
        llm_provider_id, provider_name = llm_provider

        default_models = fetch_model_names_for_provider_as_set(provider_name)
        display_models = fetch_visible_model_names_for_provider_as_set(
            provider_name=provider_name
        )

        # if `fetch_model_names_for_provider_as_set` returns `None`, then
        # that means that `provider_name` is not a well-known llm provider.
        if not default_models:
            continue

        if not display_models:
            raise RuntimeError(
                "If `default_models` is non-None, `display_models` must be non-None too."
            )

        model_configurations = [
            _SimpleModelConfiguration.model_validate(model_configuration)
            for model_configuration in connection.execute(
                sa.select(
                    model_configuration_table.c.id,
                    model_configuration_table.c.llm_provider_id,
                    model_configuration_table.c.name,
                    model_configuration_table.c.is_visible,
                    model_configuration_table.c.max_input_tokens,
                ).where(model_configuration_table.c.llm_provider_id == llm_provider_id)
            ).fetchall()
        ]

        if model_configurations:
            at_least_one_is_visible = any(
                [
                    model_configuration.is_visible
                    for model_configuration in model_configurations
                ]
            )

            # If there is at least one model which is public, this is a valid state.
            # Therefore, don't touch it and move on to the next one.
            if at_least_one_is_visible:
                continue

            existing_visible_model_names: set[str] = set(
                [
                    model_configuration.name
                    for model_configuration in model_configurations
                    if model_configuration.is_visible
                ]
            )

            difference = display_models.difference(existing_visible_model_names)

            for model_name in difference:
                if not model_name:
                    continue

                insert_statement = postgresql.insert(model_configuration_table).values(
                    llm_provider_id=llm_provider_id,
                    name=model_name,
                    is_visible=True,
                    max_input_tokens=None,
                )

                connection.execute(
                    insert_statement.on_conflict_do_update(
                        index_elements=["llm_provider_id", "name"],
                        set_={"is_visible": insert_statement.excluded.is_visible},
                    )
                )
        else:
            for model_name in default_models:
                connection.execute(
                    model_configuration_table.insert().values(
                        llm_provider_id=llm_provider_id,
                        name=model_name,
                        is_visible=model_name in display_models,
                        max_input_tokens=None,
                    )
                )


def downgrade() -> None:
    pass
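The on_conflict_do_update call above is what makes the repair idempotent: re-running it updates visibility rather than failing against the (llm_provider_id, name) unique constraint created by the earlier model-configuration migration. The generic shape of that upsert, as a standalone sketch with hypothetical values:

import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

model_config = sa.table(
    "model_configuration",
    sa.column("llm_provider_id", sa.Integer),
    sa.column("name", sa.String),
    sa.column("is_visible", sa.Boolean),
)
insert_stmt = postgresql.insert(model_config).values(
    llm_provider_id=1, name="gpt-4o", is_visible=True  # hypothetical values
)
upsert_stmt = insert_stmt.on_conflict_do_update(
    index_elements=["llm_provider_id", "name"],  # must match a unique constraint
    set_={"is_visible": insert_stmt.excluded.is_visible},
)
# connection.execute(upsert_stmt) inserts the row, or flips is_visible if the
# (llm_provider_id, name) pair already exists.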
@@ -0,0 +1,682 @@
"""create knowledge graph tables

Revision ID: 495cb26ce93e
Revises: ca04500b9ee8
Create Date: 2025-03-19 08:51:14.341989

"""

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from sqlalchemy import text
from datetime import datetime, timedelta

from onyx.configs.app_configs import DB_READONLY_USER
from onyx.configs.app_configs import DB_READONLY_PASSWORD
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE


# revision identifiers, used by Alembic.
revision = "495cb26ce93e"
down_revision = "ca04500b9ee8"
branch_labels = None
depends_on = None


def upgrade() -> None:

    # Create a new permission-less user to be later used for knowledge graph queries.
    # The user will later get temporary read privileges for a specific view that will be
    # ad hoc generated specific to a knowledge graph query.
    #
    # Note: in order for the migration to run, the DB_READONLY_USER and DB_READONLY_PASSWORD
    # environment variables MUST be set. Otherwise, an exception will be raised.

    if not MULTI_TENANT:

        # Enable pg_trgm extension if not already enabled
        op.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm")

        # Create read-only db user here only in single tenant mode. For multi-tenant mode,
        # the user is created in the alembic_tenants migration.
        if not (DB_READONLY_USER and DB_READONLY_PASSWORD):
            raise Exception("DB_READONLY_USER or DB_READONLY_PASSWORD is not set")

        op.execute(
            text(
                f"""
                DO $$
                BEGIN
                    -- Check if the read-only user already exists
                    IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
                        -- Create the read-only user with the specified password
                        EXECUTE format('CREATE USER %I WITH PASSWORD %L', '{DB_READONLY_USER}', '{DB_READONLY_PASSWORD}');
                        -- First revoke all privileges to ensure a clean slate
                        EXECUTE format('REVOKE ALL ON DATABASE %I FROM %I', current_database(), '{DB_READONLY_USER}');
                        -- Grant only the CONNECT privilege to allow the user to connect to the database
                        -- but not perform any operations without additional specific grants
                        EXECUTE format('GRANT CONNECT ON DATABASE %I TO %I', current_database(), '{DB_READONLY_USER}');
                    END IF;
                END
                $$;
                """
            )
        )

    # Grant usage on current schema to readonly user
    op.execute(
        text(
            f"""
            DO $$
            BEGIN
                IF EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
                    EXECUTE format('GRANT USAGE ON SCHEMA %I TO %I', current_schema(), '{DB_READONLY_USER}');
                END IF;
            END
            $$;
            """
        )
    )

    op.create_table(
        "kg_config",
        sa.Column("id", sa.Integer(), primary_key=True, nullable=False, index=True),
        sa.Column("kg_variable_name", sa.String(), nullable=False, index=True),
        sa.Column("kg_variable_values", postgresql.ARRAY(sa.String()), nullable=False),
        sa.UniqueConstraint("kg_variable_name", name="uq_kg_config_variable_name"),
    )

    # Insert initial data into kg_config table
    op.bulk_insert(
        sa.table(
            "kg_config",
            sa.column("kg_variable_name", sa.String),
            sa.column("kg_variable_values", postgresql.ARRAY(sa.String)),
        ),
        [
            {"kg_variable_name": "KG_EXPOSED", "kg_variable_values": ["false"]},
            {"kg_variable_name": "KG_ENABLED", "kg_variable_values": ["false"]},
            {"kg_variable_name": "KG_VENDOR", "kg_variable_values": []},
            {"kg_variable_name": "KG_VENDOR_DOMAINS", "kg_variable_values": []},
            {"kg_variable_name": "KG_IGNORE_EMAIL_DOMAINS", "kg_variable_values": []},
            {
                "kg_variable_name": "KG_EXTRACTION_IN_PROGRESS",
                "kg_variable_values": ["false"],
            },
            {
                "kg_variable_name": "KG_CLUSTERING_IN_PROGRESS",
                "kg_variable_values": ["false"],
            },
            {
                "kg_variable_name": "KG_COVERAGE_START",
                "kg_variable_values": [
                    (datetime.now() - timedelta(days=90)).strftime("%Y-%m-%d")
                ],
            },
            {"kg_variable_name": "KG_MAX_COVERAGE_DAYS", "kg_variable_values": ["90"]},
            {
                "kg_variable_name": "KG_MAX_PARENT_RECURSION_DEPTH",
                "kg_variable_values": ["2"],
            },
        ],
    )

    op.create_table(
        "kg_entity_type",
        sa.Column("id_name", sa.String(), primary_key=True, nullable=False, index=True),
        sa.Column("description", sa.String(), nullable=True),
        sa.Column("grounding", sa.String(), nullable=False),
        sa.Column(
            "attributes",
            postgresql.JSONB,
            nullable=False,
            server_default="{}",
        ),
        sa.Column("occurrences", sa.Integer(), server_default="1", nullable=False),
        sa.Column("active", sa.Boolean(), nullable=False, default=False),
        sa.Column("deep_extraction", sa.Boolean(), nullable=False, default=False),
        sa.Column(
            "time_updated",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            onupdate=sa.text("now()"),
        ),
        sa.Column(
            "time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
        ),
        sa.Column("grounded_source_name", sa.String(), nullable=True),
        sa.Column("entity_values", postgresql.ARRAY(sa.String()), nullable=True),
        sa.Column(
            "clustering",
            postgresql.JSONB,
            nullable=False,
            server_default="{}",
        ),
    )

    # Create KGRelationshipType table
    op.create_table(
        "kg_relationship_type",
        sa.Column("id_name", sa.String(), primary_key=True, nullable=False, index=True),
        sa.Column("name", sa.String(), nullable=False, index=True),
        sa.Column(
            "source_entity_type_id_name", sa.String(), nullable=False, index=True
        ),
        sa.Column(
            "target_entity_type_id_name", sa.String(), nullable=False, index=True
        ),
        sa.Column("definition", sa.Boolean(), nullable=False, default=False),
        sa.Column("occurrences", sa.Integer(), server_default="1", nullable=False),
        sa.Column("type", sa.String(), nullable=False, index=True),
        sa.Column("active", sa.Boolean(), nullable=False, default=True),
        sa.Column(
            "time_updated",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            onupdate=sa.text("now()"),
        ),
        sa.Column(
            "time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
        ),
        sa.Column(
            "clustering",
            postgresql.JSONB,
            nullable=False,
            server_default="{}",
        ),
        sa.ForeignKeyConstraint(
            ["source_entity_type_id_name"], ["kg_entity_type.id_name"]
        ),
        sa.ForeignKeyConstraint(
            ["target_entity_type_id_name"], ["kg_entity_type.id_name"]
        ),
    )

    # Create KGRelationshipTypeExtractionStaging table
    op.create_table(
        "kg_relationship_type_extraction_staging",
        sa.Column("id_name", sa.String(), primary_key=True, nullable=False, index=True),
        sa.Column("name", sa.String(), nullable=False, index=True),
        sa.Column(
            "source_entity_type_id_name", sa.String(), nullable=False, index=True
        ),
        sa.Column(
            "target_entity_type_id_name", sa.String(), nullable=False, index=True
        ),
        sa.Column("definition", sa.Boolean(), nullable=False, default=False),
        sa.Column("occurrences", sa.Integer(), server_default="1", nullable=False),
        sa.Column("type", sa.String(), nullable=False, index=True),
        sa.Column("active", sa.Boolean(), nullable=False, default=True),
        sa.Column(
            "time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
        ),
        sa.Column(
            "clustering",
            postgresql.JSONB,
            nullable=False,
            server_default="{}",
        ),
        sa.Column("transferred", sa.Boolean(), nullable=False, server_default="false"),
        sa.ForeignKeyConstraint(
            ["source_entity_type_id_name"], ["kg_entity_type.id_name"]
        ),
        sa.ForeignKeyConstraint(
            ["target_entity_type_id_name"], ["kg_entity_type.id_name"]
        ),
    )

    # Create KGEntity table
    op.create_table(
        "kg_entity",
        sa.Column("id_name", sa.String(), primary_key=True, nullable=False, index=True),
        sa.Column("name", sa.String(), nullable=False, index=True),
        sa.Column("entity_class", sa.String(), nullable=True, index=True),
        sa.Column("entity_subtype", sa.String(), nullable=True, index=True),
        sa.Column("entity_key", sa.String(), nullable=True, index=True),
        sa.Column("name_trigrams", postgresql.ARRAY(sa.String(3)), nullable=True),
        sa.Column("document_id", sa.String(), nullable=True, index=True),
        sa.Column(
            "alternative_names",
            postgresql.ARRAY(sa.String()),
            nullable=False,
            server_default="{}",
        ),
        sa.Column("entity_type_id_name", sa.String(), nullable=False, index=True),
        sa.Column("description", sa.String(), nullable=True),
        sa.Column(
            "keywords",
            postgresql.ARRAY(sa.String()),
            nullable=False,
            server_default="{}",
        ),
        sa.Column("occurrences", sa.Integer(), server_default="1", nullable=False),
        sa.Column(
            "acl", postgresql.ARRAY(sa.String()), nullable=False, server_default="{}"
        ),
        sa.Column("boosts", postgresql.JSONB, nullable=False, server_default="{}"),
        sa.Column("attributes", postgresql.JSONB, nullable=False, server_default="{}"),
        sa.Column("event_time", sa.DateTime(timezone=True), nullable=True),
        sa.Column(
            "time_updated",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            onupdate=sa.text("now()"),
        ),
        sa.Column(
            "time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
        ),
        sa.ForeignKeyConstraint(["entity_type_id_name"], ["kg_entity_type.id_name"]),
        sa.ForeignKeyConstraint(["document_id"], ["document.id"]),
        sa.UniqueConstraint(
            "name",
            "entity_type_id_name",
            "document_id",
            name="uq_kg_entity_name_type_doc",
        ),
    )
    op.create_index("ix_entity_type_acl", "kg_entity", ["entity_type_id_name", "acl"])
    op.create_index(
        "ix_entity_name_search", "kg_entity", ["name", "entity_type_id_name"]
    )

    # Create KGEntityExtractionStaging table
    op.create_table(
        "kg_entity_extraction_staging",
        sa.Column("id_name", sa.String(), primary_key=True, nullable=False, index=True),
        sa.Column("name", sa.String(), nullable=False, index=True),
        sa.Column("document_id", sa.String(), nullable=True, index=True),
        sa.Column(
            "alternative_names",
            postgresql.ARRAY(sa.String()),
            nullable=False,
            server_default="{}",
        ),
        sa.Column("entity_type_id_name", sa.String(), nullable=False, index=True),
        sa.Column("description", sa.String(), nullable=True),
        sa.Column(
            "keywords",
            postgresql.ARRAY(sa.String()),
            nullable=False,
            server_default="{}",
        ),
        sa.Column("occurrences", sa.Integer(), server_default="1", nullable=False),
        sa.Column(
            "acl", postgresql.ARRAY(sa.String()), nullable=False, server_default="{}"
        ),
        sa.Column("boosts", postgresql.JSONB, nullable=False, server_default="{}"),
        sa.Column("attributes", postgresql.JSONB, nullable=False, server_default="{}"),
        sa.Column("transferred_id_name", sa.String(), nullable=True, default=None),
        sa.Column("entity_class", sa.String(), nullable=True, index=True),
        sa.Column("entity_key", sa.String(), nullable=True, index=True),
        sa.Column("entity_subtype", sa.String(), nullable=True, index=True),
        sa.Column("parent_key", sa.String(), nullable=True, index=True),
        sa.Column("event_time", sa.DateTime(timezone=True), nullable=True),
        sa.Column(
            "time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
        ),
        sa.ForeignKeyConstraint(["entity_type_id_name"], ["kg_entity_type.id_name"]),
        sa.ForeignKeyConstraint(["document_id"], ["document.id"]),
    )
    op.create_index(
        "ix_entity_extraction_staging_acl",
        "kg_entity_extraction_staging",
        ["entity_type_id_name", "acl"],
    )
    op.create_index(
        "ix_entity_extraction_staging_name_search",
        "kg_entity_extraction_staging",
        ["name", "entity_type_id_name"],
    )

    # Create KGRelationship table
    op.create_table(
        "kg_relationship",
        sa.Column("id_name", sa.String(), nullable=False, index=True),
        sa.Column("source_node", sa.String(), nullable=False, index=True),
        sa.Column("target_node", sa.String(), nullable=False, index=True),
        sa.Column("source_node_type", sa.String(), nullable=False, index=True),
        sa.Column("target_node_type", sa.String(), nullable=False, index=True),
        sa.Column("source_document", sa.String(), nullable=True, index=True),
        sa.Column("type", sa.String(), nullable=False, index=True),
        sa.Column("relationship_type_id_name", sa.String(), nullable=False, index=True),
        sa.Column("occurrences", sa.Integer(), server_default="1", nullable=False),
        sa.Column(
            "time_updated",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            onupdate=sa.text("now()"),
        ),
        sa.Column(
            "time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
        ),
        sa.ForeignKeyConstraint(["source_node"], ["kg_entity.id_name"]),
        sa.ForeignKeyConstraint(["target_node"], ["kg_entity.id_name"]),
        sa.ForeignKeyConstraint(["source_node_type"], ["kg_entity_type.id_name"]),
        sa.ForeignKeyConstraint(["target_node_type"], ["kg_entity_type.id_name"]),
        sa.ForeignKeyConstraint(["source_document"], ["document.id"]),
        sa.ForeignKeyConstraint(
            ["relationship_type_id_name"], ["kg_relationship_type.id_name"]
        ),
        sa.UniqueConstraint(
            "source_node",
            "target_node",
            "type",
            name="uq_kg_relationship_source_target_type",
        ),
        sa.PrimaryKeyConstraint("id_name", "source_document"),
    )
    op.create_index(
        "ix_kg_relationship_nodes", "kg_relationship", ["source_node", "target_node"]
    )

    # Create KGRelationshipExtractionStaging table
    op.create_table(
        "kg_relationship_extraction_staging",
        sa.Column("id_name", sa.String(), nullable=False, index=True),
        sa.Column("source_node", sa.String(), nullable=False, index=True),
        sa.Column("target_node", sa.String(), nullable=False, index=True),
        sa.Column("source_node_type", sa.String(), nullable=False, index=True),
        sa.Column("target_node_type", sa.String(), nullable=False, index=True),
        sa.Column("source_document", sa.String(), nullable=True, index=True),
        sa.Column("type", sa.String(), nullable=False, index=True),
        sa.Column("relationship_type_id_name", sa.String(), nullable=False, index=True),
        sa.Column("occurrences", sa.Integer(), server_default="1", nullable=False),
        sa.Column("transferred", sa.Boolean(), nullable=False, server_default="false"),
        sa.Column(
            "time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
        ),
        sa.ForeignKeyConstraint(
            ["source_node"], ["kg_entity_extraction_staging.id_name"]
        ),
        sa.ForeignKeyConstraint(
            ["target_node"], ["kg_entity_extraction_staging.id_name"]
        ),
        sa.ForeignKeyConstraint(["source_node_type"], ["kg_entity_type.id_name"]),
        sa.ForeignKeyConstraint(["target_node_type"], ["kg_entity_type.id_name"]),
        sa.ForeignKeyConstraint(["source_document"], ["document.id"]),
        sa.ForeignKeyConstraint(
            ["relationship_type_id_name"],
            ["kg_relationship_type_extraction_staging.id_name"],
        ),
        sa.UniqueConstraint(
            "source_node",
            "target_node",
            "type",
            name="uq_kg_relationship_extraction_staging_source_target_type",
        ),
        sa.PrimaryKeyConstraint("id_name", "source_document"),
    )
    op.create_index(
        "ix_kg_relationship_extraction_staging_nodes",
        "kg_relationship_extraction_staging",
        ["source_node", "target_node"],
    )

    # Create KGTerm table
    op.create_table(
        "kg_term",
        sa.Column("id_term", sa.String(), primary_key=True, nullable=False, index=True),
        sa.Column(
            "entity_types",
            postgresql.ARRAY(sa.String()),
            nullable=False,
            server_default="{}",
        ),
        sa.Column(
            "time_updated",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            onupdate=sa.text("now()"),
        ),
        sa.Column(
            "time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
        ),
    )
    op.create_index("ix_search_term_entities", "kg_term", ["entity_types"])
    op.create_index("ix_search_term_term", "kg_term", ["id_term"])

    op.add_column(
        "document",
        sa.Column("kg_stage", sa.String(), nullable=True, index=True),
    )
    op.add_column(
        "document",
        sa.Column("kg_processing_time", sa.DateTime(timezone=True), nullable=True),
    )
    op.add_column(
        "connector",
        sa.Column(
            "kg_processing_enabled",
            sa.Boolean(),
            nullable=True,
            server_default="false",
        ),
    )

    op.add_column(
        "connector",
        sa.Column(
            "kg_coverage_days",
            sa.Integer(),
            nullable=True,
            server_default=None,
        ),
    )

    # Create GIN index for clustering and normalization
    op.execute(
        "CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_kg_entity_clustering_trigrams "
        f"ON kg_entity USING GIN (name {POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE}.gin_trgm_ops)"
    )
    op.execute(
        "CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_kg_entity_normalization_trigrams "
        "ON kg_entity USING GIN (name_trigrams)"
    )

    # Create kg_entity trigger to update kg_entity.name and its trigrams
    alphanum_pattern = r"[^a-z0-9]+"
    truncate_length = 1000
    function = "update_kg_entity_name"
    op.execute(
        text(
            f"""
            CREATE OR REPLACE FUNCTION {function}()
            RETURNS TRIGGER AS $$
            DECLARE
                name text;
                cleaned_name text;
            BEGIN
                -- Set name to semantic_id if document_id is not NULL
                IF NEW.document_id IS NOT NULL THEN
                    SELECT lower(semantic_id) INTO name
                    FROM document
                    WHERE id = NEW.document_id;
                ELSE
                    name = lower(NEW.name);
                END IF;

                -- Clean name and truncate if too long
                cleaned_name = regexp_replace(
                    name,
                    '{alphanum_pattern}', '', 'g'
                );
                IF length(cleaned_name) > {truncate_length} THEN
                    cleaned_name = left(cleaned_name, {truncate_length});
                END IF;

                -- Set name and name trigrams
                NEW.name = name;
                NEW.name_trigrams = {POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE}.show_trgm(cleaned_name);
                RETURN NEW;
            END;
            $$ LANGUAGE plpgsql;
            """
        )
    )
    trigger = f"{function}_trigger"
    op.execute(f"DROP TRIGGER IF EXISTS {trigger} ON kg_entity")
    op.execute(
        f"""
        CREATE TRIGGER {trigger}
        BEFORE INSERT OR UPDATE OF name
        ON kg_entity
        FOR EACH ROW
        EXECUTE FUNCTION {function}();
        """
    )

    # Create kg_entity trigger to update kg_entity.name and its trigrams
    function = "update_kg_entity_name_from_doc"
    op.execute(
        text(
            f"""
            CREATE OR REPLACE FUNCTION {function}()
            RETURNS TRIGGER AS $$
            DECLARE
                doc_name text;
                cleaned_name text;
            BEGIN
                doc_name = lower(NEW.semantic_id);

                -- Clean name and truncate if too long
                cleaned_name = regexp_replace(
                    doc_name,
                    '{alphanum_pattern}', '', 'g'
                );
                IF length(cleaned_name) > {truncate_length} THEN
                    cleaned_name = left(cleaned_name, {truncate_length});
                END IF;

                -- Set name and name trigrams for all entities referencing this document
                UPDATE kg_entity
                SET
                    name = doc_name,
                    name_trigrams = {POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE}.show_trgm(cleaned_name)
                WHERE document_id = NEW.id;
                RETURN NEW;
            END;
            $$ LANGUAGE plpgsql;
            """
        )
    )
    trigger = f"{function}_trigger"
    op.execute(f"DROP TRIGGER IF EXISTS {trigger} ON document")
    op.execute(
        f"""
        CREATE TRIGGER {trigger}
        AFTER UPDATE OF semantic_id
        ON document
        FOR EACH ROW
        EXECUTE FUNCTION {function}();
        """
    )


def downgrade() -> None:

    # Drop all views that start with 'kg_'
    op.execute(
        """
        DO $$
        DECLARE
            view_name text;
        BEGIN
            FOR view_name IN
                SELECT c.relname
                FROM pg_catalog.pg_class c
                JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
                WHERE c.relkind = 'v'
                AND n.nspname = current_schema()
                AND c.relname LIKE 'kg_relationships_with_access%'
            LOOP
                EXECUTE 'DROP VIEW IF EXISTS ' || quote_ident(view_name);
            END LOOP;
        END $$;
        """
    )

    op.execute(
        """
        DO $$
        DECLARE
            view_name text;
        BEGIN
            FOR view_name IN
                SELECT c.relname
                FROM pg_catalog.pg_class c
                JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
                WHERE c.relkind = 'v'
                AND n.nspname = current_schema()
                AND c.relname LIKE 'allowed_docs%'
            LOOP
                EXECUTE 'DROP VIEW IF EXISTS ' || quote_ident(view_name);
            END LOOP;
        END $$;
        """
    )

    for table, function in (
        ("kg_entity", "update_kg_entity_name"),
        ("document", "update_kg_entity_name_from_doc"),
    ):
        op.execute(f"DROP TRIGGER IF EXISTS {function}_trigger ON {table}")
        op.execute(f"DROP FUNCTION IF EXISTS {function}()")

    # Drop index
    op.execute("COMMIT")  # Commit to allow CONCURRENTLY
    op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_kg_entity_clustering_trigrams")
    op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_kg_entity_normalization_trigrams")

    # Drop tables in reverse order of creation to handle dependencies
    op.drop_table("kg_term")
    op.drop_table("kg_relationship")
    op.drop_table("kg_entity")
    op.drop_table("kg_relationship_type")
    op.drop_table("kg_relationship_extraction_staging")
    op.drop_table("kg_relationship_type_extraction_staging")
    op.drop_table("kg_entity_extraction_staging")
    op.drop_table("kg_entity_type")
    op.drop_column("connector", "kg_processing_enabled")
    op.drop_column("connector", "kg_coverage_days")
    op.drop_column("document", "kg_stage")
    op.drop_column("document", "kg_processing_time")
    op.drop_table("kg_config")

    # Revoke usage on current schema for the readonly user
    op.execute(
        text(
            f"""
            DO $$
            BEGIN
                IF EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
                    EXECUTE format('REVOKE ALL ON SCHEMA %I FROM %I', current_schema(), '{DB_READONLY_USER}');
                END IF;
            END
            $$;
            """
        )
    )

    if not MULTI_TENANT:
        # Drop read-only db user here only in single tenant mode. For multi-tenant mode,
        # the user is dropped in the alembic_tenants migration.

        op.execute(
            text(
                f"""
                DO $$
                BEGIN
                    IF EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
                        -- First revoke all privileges from the database
                        EXECUTE format('REVOKE ALL ON DATABASE %I FROM %I', current_database(), '{DB_READONLY_USER}');
                        -- Then drop the user
                        EXECUTE format('DROP USER %I', '{DB_READONLY_USER}');
                    END IF;
                END
                $$;
                """
            )
        )
        op.execute(text("DROP EXTENSION IF EXISTS pg_trgm"))
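The pair of triggers above keeps kg_entity.name_trigrams aligned with show_trgm() output, and the two GIN indexes exist so trigram matching does not need a sequential scan. As an illustrative sketch of a query-time lookup (standard pg_trgm usage, not code from this migration):

from sqlalchemy import text
from sqlalchemy.engine import Connection


def find_similar_entities(conn: Connection, query: str, limit: int = 10):
    # `%` is the pg_trgm similarity operator, which is what can use the GIN
    # index on kg_entity.name; some drivers require doubling it to `%%`.
    return conn.execute(
        text(
            "SELECT id_name, name FROM kg_entity "
            "WHERE name % :q ORDER BY similarity(name, :q) DESC LIMIT :n"
        ),
        {"q": query.lower(), "n": limit},
    ).fetchall()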
@@ -0,0 +1,24 @@
"""Add content type to UserFile

Revision ID: 5c448911b12f
Revises: 47a07e1a38f1
Create Date: 2025-04-25 16:59:48.182672

"""

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "5c448911b12f"
down_revision = "47a07e1a38f1"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.add_column("user_file", sa.Column("content_type", sa.String(), nullable=True))


def downgrade() -> None:
    op.drop_column("user_file", "content_type")
@@ -0,0 +1,41 @@
"""remove kg subtype from db

Revision ID: 65bc6e0f8500
Revises: cec7ec36c505
Create Date: 2025-06-13 10:04:27.705976

"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "65bc6e0f8500"
down_revision = "cec7ec36c505"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.drop_column("kg_entity", "entity_class")
    op.drop_column("kg_entity", "entity_subtype")
    op.drop_column("kg_entity_extraction_staging", "entity_class")
    op.drop_column("kg_entity_extraction_staging", "entity_subtype")


def downgrade() -> None:
    op.add_column(
        "kg_entity_extraction_staging",
        sa.Column("entity_subtype", sa.String(), nullable=True, index=True),
    )
    op.add_column(
        "kg_entity_extraction_staging",
        sa.Column("entity_class", sa.String(), nullable=True, index=True),
    )
    op.add_column(
        "kg_entity", sa.Column("entity_subtype", sa.String(), nullable=True, index=True)
    )
    op.add_column(
        "kg_entity", sa.Column("entity_class", sa.String(), nullable=True, index=True)
    )
@@ -6,12 +6,6 @@ Create Date: 2025-04-01 07:26:10.539362

"""

from alembic import op
import sqlalchemy as sa
from sqlalchemy import inspect
import datetime


# revision identifiers, used by Alembic.
revision = "6a804aeb4830"
down_revision = "8e1ac4f39a9f"
@@ -19,99 +13,10 @@ branch_labels = None
depends_on = None


# Leaving this around only because some people might be on this migration
# originally was a duplicate of the user files migration
def upgrade() -> None:
    # Check if user_file table already exists
    conn = op.get_bind()
    inspector = inspect(conn)

    if not inspector.has_table("user_file"):
        # Create user_folder table without parent_id
        op.create_table(
            "user_folder",
            sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
            sa.Column("user_id", sa.UUID(), sa.ForeignKey("user.id"), nullable=True),
            sa.Column("name", sa.String(length=255), nullable=True),
            sa.Column("description", sa.String(length=255), nullable=True),
            sa.Column("display_priority", sa.Integer(), nullable=True, default=0),
            sa.Column(
                "created_at", sa.DateTime(timezone=True), server_default=sa.func.now()
            ),
        )

        # Create user_file table with folder_id instead of parent_folder_id
        op.create_table(
            "user_file",
            sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
            sa.Column("user_id", sa.UUID(), sa.ForeignKey("user.id"), nullable=True),
            sa.Column(
                "folder_id",
                sa.Integer(),
                sa.ForeignKey("user_folder.id"),
                nullable=True,
            ),
            sa.Column("link_url", sa.String(), nullable=True),
            sa.Column("token_count", sa.Integer(), nullable=True),
            sa.Column("file_type", sa.String(), nullable=True),
            sa.Column("file_id", sa.String(length=255), nullable=False),
            sa.Column("document_id", sa.String(length=255), nullable=False),
            sa.Column("name", sa.String(length=255), nullable=False),
            sa.Column(
                "created_at",
                sa.DateTime(),
                default=datetime.datetime.utcnow,
            ),
            sa.Column(
                "cc_pair_id",
                sa.Integer(),
                sa.ForeignKey("connector_credential_pair.id"),
                nullable=True,
                unique=True,
            ),
        )

        # Create persona__user_file table
        op.create_table(
            "persona__user_file",
            sa.Column(
                "persona_id",
                sa.Integer(),
                sa.ForeignKey("persona.id"),
                primary_key=True,
            ),
            sa.Column(
                "user_file_id",
                sa.Integer(),
                sa.ForeignKey("user_file.id"),
                primary_key=True,
            ),
        )

        # Create persona__user_folder table
        op.create_table(
            "persona__user_folder",
            sa.Column(
                "persona_id",
                sa.Integer(),
                sa.ForeignKey("persona.id"),
                primary_key=True,
            ),
            sa.Column(
                "user_folder_id",
                sa.Integer(),
                sa.ForeignKey("user_folder.id"),
                primary_key=True,
            ),
        )

        op.add_column(
            "connector_credential_pair",
            sa.Column("is_user_file", sa.Boolean(), nullable=True, default=False),
        )

        # Update existing records to have is_user_file=False instead of NULL
        op.execute(
            "UPDATE connector_credential_pair SET is_user_file = FALSE WHERE is_user_file IS NULL"
        )
    pass


def downgrade() -> None:

@@ -6,11 +6,8 @@ Create Date: 2024-04-15 01:36:02.952809

"""

import json
from typing import cast
from alembic import op
import sqlalchemy as sa
from onyx.key_value_store.factory import get_kv_store

# revision identifiers, used by Alembic.
revision = "703313b75876"
@@ -54,27 +51,10 @@ def upgrade() -> None:
        sa.PrimaryKeyConstraint("rate_limit_id", "user_group_id"),
    )

    try:
        settings_json = cast(str, get_kv_store().load("token_budget_settings"))
        settings = json.loads(settings_json)

        is_enabled = settings.get("enable_token_budget", False)
        token_budget = settings.get("token_budget", -1)
        period_hours = settings.get("period_hours", -1)

        if is_enabled and token_budget > 0 and period_hours > 0:
            op.execute(
                f"INSERT INTO token_rate_limit \
                (enabled, token_budget, period_hours, scope) VALUES \
                ({is_enabled}, {token_budget}, {period_hours}, 'GLOBAL')"
            )

        # Delete the dynamic config
        get_kv_store().delete("token_budget_settings")

    except Exception:
        # Ignore if the dynamic config is not found
        pass
    # NOTE: rate limit settings used to be stored in the "token_budget_settings" key in the
    # KeyValueStore. This will now be lost. The KV store works differently than it used to
    # so the migration is fairly complicated and likely not worth it to support (pretty much
    # nobody will have it set)


def downgrade() -> None:

@@ -0,0 +1,237 @@
"""Add model-configuration table

Revision ID: 7a70b7664e37
Revises: d961aca62eb3
Create Date: 2025-04-10 15:00:35.984669

"""

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

from onyx.llm.llm_provider_options import (
    fetch_model_names_for_provider_as_set,
    fetch_visible_model_names_for_provider_as_set,
)

# revision identifiers, used by Alembic.
revision = "7a70b7664e37"
down_revision = "d961aca62eb3"
branch_labels = None
depends_on = None


def _resolve(
    provider_name: str,
    model_names: list[str] | None,
    display_model_names: list[str] | None,
    default_model_name: str,
    fast_default_model_name: str | None,
) -> set[tuple[str, bool]]:
    models = set(model_names) if model_names else None
    display_models = set(display_model_names) if display_model_names else None

    # If both are defined, we need to make sure that `model_names` is a superset of `display_model_names`.
    if models and display_models:
        models = display_models.union(models)

    # If only `model_names` is defined, then:
    # - If default-model-names are available for the `provider_name`, then set `display_model_names` to it
    #   and set `model_names` to the union of those default-model-names with itself.
    # - If no default-model-names are available, then set `display_models` to `models`.
    #
    # This preserves the invariant that `display_models` is a subset of `models`.
    elif models and not display_models:
        visible_default_models = fetch_visible_model_names_for_provider_as_set(
            provider_name=provider_name
        )
        if visible_default_models:
            display_models = set(visible_default_models)
            models = display_models.union(models)
        else:
            display_models = set(models)

    # If only the `display_model_names` are defined, then set `models` to the union of `display_model_names`
    # and the default-model-names for that provider.
    #
    # This will also preserve the invariant that `display_models` is a subset of `models`.
    elif not models and display_models:
        default_models = fetch_model_names_for_provider_as_set(
            provider_name=provider_name
        )
        if default_models:
            models = display_models.union(default_models)
        else:
            models = set(display_models)

    # If neither are defined, then set `models` and `display_models` to the default-model-names for the given provider.
    #
    # This will also preserve the invariant that `display_models` is a subset of `models`.
    else:
        default_models = fetch_model_names_for_provider_as_set(
            provider_name=provider_name
        )
        visible_default_models = fetch_visible_model_names_for_provider_as_set(
            provider_name=provider_name
        )

        if default_models:
            if not visible_default_models:
                raise RuntimeError(
                    "If `default_models` is non-None, `visible_default_models` must be non-None too."
                )
            models = default_models
            display_models = visible_default_models

        # This is not a well-known llm-provider; we can't provide any model suggestions.
        # Therefore, we set to the empty set and continue
        else:
            models = set()
            display_models = set()

    # It is possible that `default_model_name` is not in `models` and is not in `display_models`.
    # It is also possible that `fast_default_model_name` is not in `models` and is not in `display_models`.
    models.add(default_model_name)
    if fast_default_model_name:
        models.add(fast_default_model_name)
    display_models.add(default_model_name)
    if fast_default_model_name:
        display_models.add(fast_default_model_name)

    return set([(model, model in display_models) for model in models])


def upgrade() -> None:
    op.create_table(
        "model_configuration",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("llm_provider_id", sa.Integer(), nullable=False),
        sa.Column("name", sa.String(), nullable=False),
        sa.Column("is_visible", sa.Boolean(), nullable=False),
        sa.Column("max_input_tokens", sa.Integer(), nullable=True),
        sa.ForeignKeyConstraint(
            ["llm_provider_id"], ["llm_provider.id"], ondelete="CASCADE"
        ),
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint("llm_provider_id", "name"),
    )

    # Create temporary sqlalchemy references to tables for data migration
    llm_provider_table = sa.sql.table(
        "llm_provider",
        sa.column("id", sa.Integer),
        sa.column("provider", sa.Integer),
        sa.column("model_names", postgresql.ARRAY(sa.String)),
        sa.column("display_model_names", postgresql.ARRAY(sa.String)),
        sa.column("default_model_name", sa.String),
        sa.column("fast_default_model_name", sa.String),
    )
    model_configuration_table = sa.sql.table(
        "model_configuration",
        sa.column("id", sa.Integer),
        sa.column("llm_provider_id", sa.Integer),
        sa.column("name", sa.String),
        sa.column("is_visible", sa.Boolean),
        sa.column("max_input_tokens", sa.Integer),
    )
    connection = op.get_bind()
    llm_providers = connection.execute(
        sa.select(
            llm_provider_table.c.id,
            llm_provider_table.c.provider,
            llm_provider_table.c.model_names,
            llm_provider_table.c.display_model_names,
            llm_provider_table.c.default_model_name,
            llm_provider_table.c.fast_default_model_name,
        )
    ).fetchall()

    for llm_provider in llm_providers:
        provider_id = llm_provider[0]
        provider_name = llm_provider[1]
        model_names = llm_provider[2]
        display_model_names = llm_provider[3]
        default_model_name = llm_provider[4]
        fast_default_model_name = llm_provider[5]

        model_configurations = _resolve(
            provider_name=provider_name,
            model_names=model_names,
            display_model_names=display_model_names,
            default_model_name=default_model_name,
            fast_default_model_name=fast_default_model_name,
        )

        for model_name, is_visible in model_configurations:
            connection.execute(
                model_configuration_table.insert().values(
                    llm_provider_id=provider_id,
                    name=model_name,
                    is_visible=is_visible,
                    max_input_tokens=None,
                )
            )

    op.drop_column("llm_provider", "model_names")
    op.drop_column("llm_provider", "display_model_names")


def downgrade() -> None:
    llm_provider = sa.table(
        "llm_provider",
        sa.column("id", sa.Integer),
        sa.column("model_names", postgresql.ARRAY(sa.String)),
        sa.column("display_model_names", postgresql.ARRAY(sa.String)),
    )

    model_configuration = sa.table(
        "model_configuration",
        sa.column("id", sa.Integer),
        sa.column("llm_provider_id", sa.Integer),
        sa.column("name", sa.String),
        sa.column("is_visible", sa.Boolean),
        sa.column("max_input_tokens", sa.Integer),
    )
    op.add_column(
        "llm_provider",
        sa.Column(
            "model_names",
            postgresql.ARRAY(sa.VARCHAR()),
            autoincrement=False,
            nullable=True,
        ),
    )
    op.add_column(
        "llm_provider",
        sa.Column(
            "display_model_names",
            postgresql.ARRAY(sa.VARCHAR()),
            autoincrement=False,
            nullable=True,
        ),
    )

    connection = op.get_bind()
    provider_ids = connection.execute(sa.select(llm_provider.c.id)).fetchall()

    for (provider_id,) in provider_ids:
        # Get all models for this provider
        models = connection.execute(
            sa.select(
                model_configuration.c.name, model_configuration.c.is_visible
            ).where(model_configuration.c.llm_provider_id == provider_id)
        ).fetchall()

        all_models = [model[0] for model in models]
        visible_models = [model[0] for model in models if model[1]]

        # Update provider with arrays
        op.execute(
            llm_provider.update()
            .where(llm_provider.c.id == provider_id)
            .values(model_names=all_models, display_model_names=visible_models)
        )

    op.drop_table("model_configuration")
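To make the contract of the `_resolve` helper at the top of this migration concrete, here is a worked example with hypothetical inputs (commentary only, not part of the file):

# Both lists set: model_names is forced to be a superset of the display
# names, and both default models are force-added to both sets.
resolved = _resolve(
    provider_name="openai",  # hypothetical provider
    model_names=["gpt-4", "gpt-3.5-turbo"],
    display_model_names=["gpt-4"],
    default_model_name="gpt-4",
    fast_default_model_name="gpt-3.5-turbo",
)
# resolved == {("gpt-4", True), ("gpt-3.5-turbo", True)} - the fast default
# is promoted into the visible set even though it was not displayed before.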
@@ -103,6 +103,7 @@ def upgrade() -> None:


def downgrade() -> None:
    op.drop_column("connector_credential_pair", "is_user_file")
    # Drop the persona__user_folder table
    op.drop_table("persona__user_folder")
    # Drop the persona__user_file table
@@ -111,4 +112,3 @@ def downgrade() -> None:
    op.drop_table("user_file")
    # Drop the user_folder table
    op.drop_table("user_folder")
    op.drop_column("connector_credential_pair", "is_user_file")

@@ -0,0 +1,32 @@
"""Add public_external_user_group table

Revision ID: a7688ab35c45
Revises: 5c448911b12f
Create Date: 2025-05-06 20:55:12.747875

"""

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "a7688ab35c45"
down_revision = "5c448911b12f"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.create_table(
        "public_external_user_group",
        sa.Column("external_user_group_id", sa.String(), nullable=False),
        sa.Column("cc_pair_id", sa.Integer(), nullable=False),
        sa.PrimaryKeyConstraint("external_user_group_id", "cc_pair_id"),
        sa.ForeignKeyConstraint(
            ["cc_pair_id"], ["connector_credential_pair.id"], ondelete="CASCADE"
        ),
    )


def downgrade() -> None:
    op.drop_table("public_external_user_group")
@@ -0,0 +1,128 @@
"""add_cascade_deletes_to_agent_tables

Revision ID: ca04500b9ee8
Revises: 238b84885828
Create Date: 2025-05-30 16:03:51.112263

"""

from alembic import op


# revision identifiers, used by Alembic.
revision = "ca04500b9ee8"
down_revision = "238b84885828"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Drop existing foreign key constraints
    op.drop_constraint(
        "agent__sub_question_primary_question_id_fkey",
        "agent__sub_question",
        type_="foreignkey",
    )
    op.drop_constraint(
        "agent__sub_query_parent_question_id_fkey",
        "agent__sub_query",
        type_="foreignkey",
    )
    op.drop_constraint(
        "chat_message__standard_answer_chat_message_id_fkey",
        "chat_message__standard_answer",
        type_="foreignkey",
    )
    op.drop_constraint(
        "agent__sub_query__search_doc_sub_query_id_fkey",
        "agent__sub_query__search_doc",
        type_="foreignkey",
    )

    # Recreate foreign key constraints with CASCADE delete
    op.create_foreign_key(
        "agent__sub_question_primary_question_id_fkey",
        "agent__sub_question",
        "chat_message",
        ["primary_question_id"],
        ["id"],
        ondelete="CASCADE",
    )
    op.create_foreign_key(
        "agent__sub_query_parent_question_id_fkey",
        "agent__sub_query",
        "agent__sub_question",
        ["parent_question_id"],
        ["id"],
        ondelete="CASCADE",
    )
    op.create_foreign_key(
        "chat_message__standard_answer_chat_message_id_fkey",
        "chat_message__standard_answer",
        "chat_message",
        ["chat_message_id"],
        ["id"],
        ondelete="CASCADE",
    )
    op.create_foreign_key(
        "agent__sub_query__search_doc_sub_query_id_fkey",
        "agent__sub_query__search_doc",
        "agent__sub_query",
        ["sub_query_id"],
        ["id"],
        ondelete="CASCADE",
    )


def downgrade() -> None:
    # Drop CASCADE foreign key constraints
    op.drop_constraint(
        "agent__sub_question_primary_question_id_fkey",
        "agent__sub_question",
        type_="foreignkey",
    )
    op.drop_constraint(
        "agent__sub_query_parent_question_id_fkey",
        "agent__sub_query",
        type_="foreignkey",
    )
    op.drop_constraint(
        "chat_message__standard_answer_chat_message_id_fkey",
        "chat_message__standard_answer",
        type_="foreignkey",
    )
    op.drop_constraint(
        "agent__sub_query__search_doc_sub_query_id_fkey",
        "agent__sub_query__search_doc",
        type_="foreignkey",
    )

    # Recreate foreign key constraints without CASCADE delete
    op.create_foreign_key(
        "agent__sub_question_primary_question_id_fkey",
        "agent__sub_question",
        "chat_message",
        ["primary_question_id"],
        ["id"],
    )
    op.create_foreign_key(
        "agent__sub_query_parent_question_id_fkey",
        "agent__sub_query",
        "agent__sub_question",
        ["parent_question_id"],
        ["id"],
    )
    op.create_foreign_key(
        "chat_message__standard_answer_chat_message_id_fkey",
        "chat_message__standard_answer",
        "chat_message",
        ["chat_message_id"],
        ["id"],
    )
    op.create_foreign_key(
        "agent__sub_query__search_doc_sub_query_id_fkey",
        "agent__sub_query__search_doc",
        "agent__sub_query",
        ["sub_query_id"],
        ["id"],
    )
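The upgrade and downgrade above repeat the same drop-and-recreate pattern four times. If this migration were being refactored, a small helper could capture it (a sketch under that assumption, not how the repo writes it; passing ondelete=None recreates the plain constraint for downgrades):

from alembic import op


def swap_fk(
    name: str,
    source: str,
    referent: str,
    local_cols: list[str],
    remote_cols: list[str],
    ondelete: str | None,
) -> None:
    # Replace an existing FK with one that has the requested delete behavior.
    op.drop_constraint(name, source, type_="foreignkey")
    op.create_foreign_key(
        name, source, referent, local_cols, remote_cols, ondelete=ondelete
    )


# e.g. swap_fk("agent__sub_query_parent_question_id_fkey", "agent__sub_query",
#              "agent__sub_question", ["parent_question_id"], ["id"], "CASCADE")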
backend/alembic/versions/cec7ec36c505_kgentity_parent.py
@@ -0,0 +1,29 @@
"""kgentity_parent

Revision ID: cec7ec36c505
Revises: 495cb26ce93e
Create Date: 2025-06-07 20:07:46.400770

"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "cec7ec36c505"
down_revision = "495cb26ce93e"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.add_column(
        "kg_entity",
        sa.Column("parent_key", sa.String(), nullable=True, index=True),
    )
    # NOTE: you will have to reindex the KG after this migration as the parent_key will be null


def downgrade() -> None:
    op.drop_column("kg_entity", "parent_key")
@@ -0,0 +1,57 @@
"""Update status length

Revision ID: d961aca62eb3
Revises: cf90764725d8
Create Date: 2025-03-23 16:10:05.683965

"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "d961aca62eb3"
down_revision = "cf90764725d8"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Drop the existing enum type constraint
    op.execute("ALTER TABLE connector_credential_pair ALTER COLUMN status TYPE varchar")

    # Create new enum type with all values
    op.execute(
        "ALTER TABLE connector_credential_pair ALTER COLUMN status TYPE VARCHAR(20) USING status::varchar(20)"
    )

    # Update the enum type to include all possible values
    op.alter_column(
        "connector_credential_pair",
        "status",
        type_=sa.Enum(
            "SCHEDULED",
            "INITIAL_INDEXING",
            "ACTIVE",
            "PAUSED",
            "DELETING",
            "INVALID",
            name="connectorcredentialpairstatus",
            native_enum=False,
        ),
        existing_type=sa.String(20),
        nullable=False,
    )

    op.add_column(
        "connector_credential_pair",
        sa.Column(
            "in_repeated_error_state", sa.Boolean, default=False, server_default="false"
        ),
    )


def downgrade() -> None:
    # no need to convert back to the old enum type, since we're not using it anymore
    op.drop_column("connector_credential_pair", "in_repeated_error_state")
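Because the column is altered with native_enum=False, status ends up as a plain VARCHAR(20) rather than a native Postgres ENUM type, so introducing another status later is a Python-level change with no ALTER TYPE required. A sketch of the resulting type object, mirroring the values above:

import sqlalchemy as sa

# VARCHAR-backed enum: values are enforced by SQLAlchemy (and optionally a
# CHECK constraint), not by a native Postgres ENUM, so extending the list
# needs no database-side type surgery.
status_type = sa.Enum(
    "SCHEDULED",
    "INITIAL_INDEXING",
    "ACTIVE",
    "PAUSED",
    "DELETING",
    "INVALID",
    name="connectorcredentialpairstatus",
    native_enum=False,
)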
@@ -21,6 +21,9 @@ branch_labels = None
depends_on = None


PRESERVED_CONFIG_KEYS = ["comment_email_blacklist", "batch_size", "labels_to_skip"]


def upgrade() -> None:
    # Get all Jira connectors
    conn = op.get_bind()
@@ -62,6 +65,9 @@ def upgrade() -> None:
                f"WARNING: Jira connector {connector_id} has no project URL configured"
            )
            continue
        for old_key in PRESERVED_CONFIG_KEYS:
            if old_key in old_config:
                new_config[old_key] = old_config[old_key]

        # Update the connector config
        conn.execute(
@@ -108,6 +114,10 @@ def downgrade() -> None:
        else:
            continue

        for old_key in PRESERVED_CONFIG_KEYS:
            if old_key in new_config:
                old_config[old_key] = new_config[old_key]

        # Update the connector config
        conn.execute(
            sa.text(
@@ -117,5 +127,5 @@ def downgrade() -> None:
                WHERE id = :id
                """
            ),
            {"id": connector_id, "old_config": old_config},
            {"id": connector_id, "old_config": json.dumps(old_config)},
        )
@@ -10,12 +10,19 @@ from alembic import op
import sqlalchemy as sa
from sqlalchemy import table, column, String, Integer, Boolean

from onyx.db.search_settings import (
    get_new_default_embedding_model,
    get_old_default_embedding_model,
    user_has_overridden_embedding_model,
)
from onyx.configs.model_configs import ASYM_PASSAGE_PREFIX
from onyx.configs.model_configs import ASYM_QUERY_PREFIX
from onyx.configs.model_configs import DOC_EMBEDDING_DIM
from onyx.configs.model_configs import DOCUMENT_ENCODER_MODEL
from onyx.configs.model_configs import NORMALIZE_EMBEDDINGS
from onyx.configs.model_configs import OLD_DEFAULT_DOCUMENT_ENCODER_MODEL
from onyx.configs.model_configs import OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
from onyx.configs.model_configs import OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
from onyx.db.enums import EmbeddingPrecision
from onyx.db.models import IndexModelStatus
from onyx.db.search_settings import user_has_overridden_embedding_model
from onyx.indexing.models import IndexingSetting
from onyx.natural_language_processing.search_nlp_models import clean_model_name

# revision identifiers, used by Alembic.
revision = "dbaa756c2ccf"
@@ -24,6 +31,47 @@ branch_labels: None = None
depends_on: None = None


def _get_old_default_embedding_model() -> IndexingSetting:
    is_overridden = user_has_overridden_embedding_model()
    return IndexingSetting(
        model_name=(
            DOCUMENT_ENCODER_MODEL
            if is_overridden
            else OLD_DEFAULT_DOCUMENT_ENCODER_MODEL
        ),
        model_dim=(
            DOC_EMBEDDING_DIM if is_overridden else OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
        ),
        embedding_precision=(EmbeddingPrecision.FLOAT),
        normalize=(
            NORMALIZE_EMBEDDINGS
            if is_overridden
            else OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
        ),
        query_prefix=(ASYM_QUERY_PREFIX if is_overridden else ""),
        passage_prefix=(ASYM_PASSAGE_PREFIX if is_overridden else ""),
        index_name="danswer_chunk",
        multipass_indexing=False,
        enable_contextual_rag=False,
        api_url=None,
    )


def _get_new_default_embedding_model() -> IndexingSetting:
    return IndexingSetting(
        model_name=DOCUMENT_ENCODER_MODEL,
        model_dim=DOC_EMBEDDING_DIM,
        embedding_precision=(EmbeddingPrecision.BFLOAT16),
        normalize=NORMALIZE_EMBEDDINGS,
        query_prefix=ASYM_QUERY_PREFIX,
        passage_prefix=ASYM_PASSAGE_PREFIX,
        index_name=f"danswer_chunk_{clean_model_name(DOCUMENT_ENCODER_MODEL)}",
        multipass_indexing=False,
        enable_contextual_rag=False,
        api_url=None,
    )


def upgrade() -> None:
    op.create_table(
        "embedding_model",
@@ -61,7 +109,7 @@ def upgrade() -> None:
    # the user selected via env variables before this change. This is needed since
    # all index_attempts must be associated with an embedding model, so without this
    # we will run into violations of non-null contraints
    old_embedding_model = get_old_default_embedding_model()
    old_embedding_model = _get_old_default_embedding_model()
    op.bulk_insert(
        EmbeddingModel,
        [
@@ -79,7 +127,7 @@ def upgrade() -> None:
    # if the user has not overridden the default embedding model via env variables,
    # insert the new default model into the database to auto-upgrade them
    if not user_has_overridden_embedding_model():
        new_embedding_model = get_new_default_embedding_model()
        new_embedding_model = _get_new_default_embedding_model()
        op.bulk_insert(
            EmbeddingModel,
            [

@@ -0,0 +1,80 @@
|
||||
"""add_db_readonly_user
|
||||
|
||||
Revision ID: 3b9f09038764
|
||||
Revises: 3b45e0018bf1
|
||||
Create Date: 2025-05-11 11:05:11.436977
|
||||
|
||||
"""
|
||||
|
||||
from sqlalchemy import text
|
||||
|
||||
from alembic import op
|
||||
from onyx.configs.app_configs import DB_READONLY_PASSWORD
|
||||
from onyx.configs.app_configs import DB_READONLY_USER
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "3b9f09038764"
|
||||
down_revision = "3b45e0018bf1"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
if MULTI_TENANT:
|
||||
|
||||
# Enable pg_trgm extension if not already enabled
|
||||
op.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm")
|
||||
|
||||
# Create read-only db user here only in multi-tenant mode. For single-tenant mode,
|
||||
# the user is created in the standard migration.
|
||||
if not (DB_READONLY_USER and DB_READONLY_PASSWORD):
|
||||
raise Exception("DB_READONLY_USER or DB_READONLY_PASSWORD is not set")
|
||||
|
||||
op.execute(
|
||||
text(
|
||||
f"""
|
||||
DO $$
|
||||
BEGIN
|
||||
-- Check if the read-only user already exists
|
||||
IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
|
||||
-- Create the read-only user with the specified password
|
||||
EXECUTE format('CREATE USER %I WITH PASSWORD %L', '{DB_READONLY_USER}', '{DB_READONLY_PASSWORD}');
|
||||
-- First revoke all privileges to ensure a clean slate
|
||||
EXECUTE format('REVOKE ALL ON DATABASE %I FROM %I', current_database(), '{DB_READONLY_USER}');
|
||||
-- Grant only the CONNECT privilege to allow the user to connect to the database
|
||||
-- but not perform any operations without additional specific grants
|
||||
EXECUTE format('GRANT CONNECT ON DATABASE %I TO %I', current_database(), '{DB_READONLY_USER}');
|
||||
END IF;
|
||||
END
|
||||
$$;
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
if MULTI_TENANT:
|
||||
# Drop read-only db user here only in single tenant mode. For multi-tenant mode,
|
||||
# the user is dropped in the alembic_tenants migration.
|
||||
|
||||
op.execute(
|
||||
text(
|
||||
f"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
|
||||
-- First revoke all privileges from the database
|
||||
EXECUTE format('REVOKE ALL ON DATABASE %I FROM %I', current_database(), '{DB_READONLY_USER}');
|
||||
-- Then revoke all privileges from the public schema
|
||||
EXECUTE format('REVOKE ALL ON SCHEMA public FROM %I', '{DB_READONLY_USER}');
|
||||
-- Then drop the user
|
||||
EXECUTE format('DROP USER %I', '{DB_READONLY_USER}');
|
||||
END IF;
|
||||
END
|
||||
$$;
|
||||
"""
|
||||
)
|
||||
)
|
||||
op.execute(text("DROP EXTENSION IF EXISTS pg_trgm"))
|
||||
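The migration above builds its SQL with PostgreSQL's format() so that the role name is re-quoted as an identifier (%I) and the password as a literal (%L), even though the Python f-string interpolates them first. A minimal sketch of the same pattern, assuming a SQLAlchemy engine and config values from a trusted source (the helper name below is illustrative, not part of the codebase):

from sqlalchemy import text
from sqlalchemy.engine import Engine

def create_readonly_role(engine: Engine, user: str, password: str) -> None:
    # user/password are interpolated into the DO block, then re-quoted by
    # format(): %I treats the value as an identifier, %L as a literal
    sql = f"""
    DO $$
    BEGIN
        IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{user}') THEN
            EXECUTE format('CREATE USER %I WITH PASSWORD %L', '{user}', '{password}');
            EXECUTE format('GRANT CONNECT ON DATABASE %I TO %I', current_database(), '{user}');
        END IF;
    END
    $$;
    """
    with engine.begin() as conn:
        conn.execute(text(sql))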
@@ -1,12 +1,10 @@
from sqlalchemy.orm import Session

from ee.onyx.db.external_perm import fetch_external_groups_for_user
from ee.onyx.db.external_perm import fetch_public_external_group_ids
from ee.onyx.db.user_group import fetch_user_groups_for_documents
from ee.onyx.db.user_group import fetch_user_groups_for_user
from ee.onyx.external_permissions.post_query_censoring import (
    DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION,
)
from ee.onyx.external_permissions.sync_params import DOC_PERMISSIONS_FUNC_MAP
from ee.onyx.external_permissions.sync_params import get_source_perm_sync_config
from onyx.access.access import (
    _get_access_for_documents as get_access_for_documents_without_groups,
)
@@ -17,6 +15,10 @@ from onyx.access.utils import prefix_user_group
from onyx.db.document import get_document_sources
from onyx.db.document import get_documents_by_ids
from onyx.db.models import User
from onyx.utils.logger import setup_logger


logger = setup_logger()


def _get_access_for_document(
@@ -63,13 +65,21 @@ def _get_access_for_documents(
        document_ids=document_ids,
    )

    all_public_ext_u_group_ids = set(fetch_public_external_group_ids(db_session))

    access_map = {}
    for document_id, non_ee_access in non_ee_access_dict.items():
        document = doc_id_map[document_id]
        source = doc_id_to_source_map.get(document_id)
        if source is None:
            logger.error(f"Document {document_id} has no source")
            continue

        perm_sync_config = get_source_perm_sync_config(source)
        is_only_censored = (
            source in DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION
            and source not in DOC_PERMISSIONS_FUNC_MAP
            perm_sync_config
            and perm_sync_config.censoring_config is not None
            and perm_sync_config.doc_sync_config is None
        )

        ext_u_emails = (
@@ -89,7 +99,10 @@ def _get_access_for_documents(
        # If it's censored, then it's public anywhere during the search and then permissions are
        # applied after the search
        is_public_anywhere = (
            document.is_public or non_ee_access.is_public or is_only_censored
            document.is_public
            or non_ee_access.is_public
            or is_only_censored
            or any(u_group in all_public_ext_u_group_ids for u_group in ext_u_groups)
        )

        # To avoid collisions of group namings between connectors, they need to be prefixed

backend/ee/onyx/background/celery/apps/heavy.py (new file, 123 lines)
@@ -0,0 +1,123 @@
import csv
import io
from datetime import datetime

from celery import shared_task
from celery import Task

from ee.onyx.server.query_history.api import fetch_and_process_chat_session_history
from ee.onyx.server.query_history.api import ONYX_ANONYMIZED_EMAIL
from ee.onyx.server.query_history.models import QuestionAnswerPairSnapshot
from onyx.background.celery.apps.heavy import celery_app
from onyx.background.task_utils import construct_query_history_report_name
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
from onyx.configs.constants import FileOrigin
from onyx.configs.constants import FileType
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import QueryHistoryType
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.tasks import delete_task_with_id
from onyx.db.tasks import mark_task_as_finished_with_id
from onyx.db.tasks import mark_task_as_started_with_id
from onyx.file_store.file_store import get_default_file_store
from onyx.utils.logger import setup_logger


logger = setup_logger()


@shared_task(
    name=OnyxCeleryTask.EXPORT_QUERY_HISTORY_TASK,
    ignore_result=True,
    soft_time_limit=JOB_TIMEOUT,
    bind=True,
    trail=False,
)
def export_query_history_task(
    self: Task, *, start: datetime, end: datetime, start_time: datetime
) -> None:
    if not self.request.id:
        raise RuntimeError("No task id defined for this task; cannot identify it")

    task_id = self.request.id
    stream = io.StringIO()
    writer = csv.DictWriter(
        stream,
        fieldnames=list(QuestionAnswerPairSnapshot.model_fields.keys()),
    )
    writer.writeheader()

    with get_session_with_current_tenant() as db_session:
        try:
            mark_task_as_started_with_id(
                db_session=db_session,
                task_id=task_id,
            )

            snapshot_generator = fetch_and_process_chat_session_history(
                db_session=db_session,
                start=start,
                end=end,
            )

            for snapshot in snapshot_generator:
                if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED:
                    snapshot.user_email = ONYX_ANONYMIZED_EMAIL

                writer.writerows(
                    qa_pair.to_json()
                    for qa_pair in QuestionAnswerPairSnapshot.from_chat_session_snapshot(
                        snapshot
                    )
                )

        except Exception:
            logger.exception(f"Failed to export query history with {task_id=}")
            mark_task_as_finished_with_id(
                db_session=db_session,
                task_id=task_id,
                success=False,
            )
            raise

    report_name = construct_query_history_report_name(task_id)
    with get_session_with_current_tenant() as db_session:
        try:
            stream.seek(0)
            get_default_file_store(db_session).save_file(
                file_name=report_name,
                content=stream,
                display_name=report_name,
                file_origin=FileOrigin.QUERY_HISTORY_CSV,
                file_type=FileType.CSV,
                file_metadata={
                    "start": start.isoformat(),
                    "end": end.isoformat(),
                    "start_time": start_time.isoformat(),
                },
            )

            delete_task_with_id(
                db_session=db_session,
                task_id=task_id,
            )
        except Exception:
            logger.exception(
                f"Failed to save query history export file; {report_name=}"
            )
            mark_task_as_finished_with_id(
                db_session=db_session,
                task_id=task_id,
                success=False,
            )
            raise


celery_app.autodiscover_tasks(
    [
        "ee.onyx.background.celery.tasks.doc_permission_syncing",
        "ee.onyx.background.celery.tasks.external_group_syncing",
        "ee.onyx.background.celery.tasks.cleanup",
    ]
)
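Callers are not expected to invoke export_query_history_task directly. A sketch of how it could be dispatched to the heavy worker, assuming the app and task-name constant above are importable and the keyword-only signature shown in the diff (the 30-day window is illustrative):

from datetime import datetime, timedelta, timezone

from onyx.background.celery.apps.heavy import celery_app
from onyx.configs.constants import OnyxCeleryTask

end = datetime.now(tz=timezone.utc)
celery_app.send_task(
    OnyxCeleryTask.EXPORT_QUERY_HISTORY_TASK,
    kwargs={
        "start": end - timedelta(days=30),  # export the last 30 days
        "end": end,
        "start_time": end,  # when the export was requested
    },
)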
backend/ee/onyx/background/celery/apps/light.py (new file, 8 lines)
@@ -0,0 +1,8 @@
from onyx.background.celery.apps.light import celery_app

celery_app.autodiscover_tasks(
    [
        "ee.onyx.background.celery.tasks.doc_permission_syncing",
        "ee.onyx.background.celery.tasks.external_group_syncing",
    ]
)

backend/ee/onyx/background/celery/apps/monitoring.py (new file, 7 lines)
@@ -0,0 +1,7 @@
from onyx.background.celery.apps.monitoring import celery_app

celery_app.autodiscover_tasks(
    [
        "ee.onyx.background.celery.tasks.tenant_provisioning",
    ]
)

@@ -1,12 +1,22 @@
from datetime import datetime
from datetime import timezone
from uuid import UUID

from celery import shared_task
from celery import Task

from ee.onyx.background.celery_utils import should_perform_chat_ttl_check
from ee.onyx.background.task_name_builders import name_chat_ttl_task
from ee.onyx.server.reporting.usage_export_generation import create_new_usage_report
from onyx.background.celery.apps.primary import celery_app
from onyx.background.task_utils import build_celery_task_wrapper
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.chat import delete_chat_session
from onyx.db.chat import get_chat_sessions_older_than
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.enums import TaskStatus
from onyx.db.tasks import mark_task_as_finished_with_id
from onyx.db.tasks import register_task
from onyx.server.settings.store import load_settings
from onyx.utils.logger import setup_logger

@@ -15,18 +25,42 @@ logger = setup_logger()
# mark as EE for all tasks in this file


@build_celery_task_wrapper(name_chat_ttl_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def perform_ttl_management_task(retention_limit_days: int, *, tenant_id: str) -> None:
    with get_session_with_current_tenant() as db_session:
        old_chat_sessions = get_chat_sessions_older_than(
            retention_limit_days, db_session
        )
@shared_task(
    name=OnyxCeleryTask.PERFORM_TTL_MANAGEMENT_TASK,
    ignore_result=True,
    soft_time_limit=JOB_TIMEOUT,
    bind=True,
    trail=False,
)
def perform_ttl_management_task(
    self: Task, retention_limit_days: int, *, tenant_id: str
) -> None:
    task_id = self.request.id
    if not task_id:
        raise RuntimeError("No task id defined for this task; cannot identify it")

    for user_id, session_id in old_chat_sessions:
        # one session per delete so that we don't blow up if a deletion fails.
    start_time = datetime.now(tz=timezone.utc)

    user_id: UUID | None = None
    session_id: UUID | None = None
    try:
        with get_session_with_current_tenant() as db_session:
            try:
                # we generally want to move off this, but keeping for now
                register_task(
                    db_session=db_session,
                    task_name=name_chat_ttl_task(retention_limit_days, tenant_id),
                    task_id=task_id,
                    status=TaskStatus.STARTED,
                    start_time=start_time,
                )

            old_chat_sessions = get_chat_sessions_older_than(
                retention_limit_days, db_session
            )

        for user_id, session_id in old_chat_sessions:
            # one session per delete so that we don't blow up if a deletion fails.
            with get_session_with_current_tenant() as db_session:
                delete_chat_session(
                    user_id,
                    session_id,
@@ -34,11 +68,26 @@ def perform_ttl_management_task(retention_limit_days: int, *, tenant_id: str) ->
                    include_deleted=True,
                    hard_delete=True,
                )
    except Exception:
        logger.exception(
            "delete_chat_session exceptioned. "
            f"user_id={user_id} session_id={session_id}"
        )

        with get_session_with_current_tenant() as db_session:
            mark_task_as_finished_with_id(
                db_session=db_session,
                task_id=task_id,
                success=True,
            )

    except Exception:
        logger.exception(
            "delete_chat_session exceptioned. "
            f"user_id={user_id} session_id={session_id}"
        )
        with get_session_with_current_tenant() as db_session:
            mark_task_as_finished_with_id(
                db_session=db_session,
                task_id=task_id,
                success=False,
            )
        raise


#####
@@ -47,7 +96,7 @@ def perform_ttl_management_task(retention_limit_days: int, *, tenant_id: str) ->


@celery_app.task(
    name="check_ttl_management_task",
    name=OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
    ignore_result=True,
    soft_time_limit=JOB_TIMEOUT,
)
@@ -67,7 +116,7 @@ def check_ttl_management_task(*, tenant_id: str) -> None:


@celery_app.task(
    name="autogenerate_usage_report_task",
    name=OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
    ignore_result=True,
    soft_time_limit=JOB_TIMEOUT,
)
@@ -79,3 +128,12 @@ def autogenerate_usage_report_task(*, tenant_id: str) -> None:
        user_id=None,
        period=None,
    )


celery_app.autodiscover_tasks(
    [
        "ee.onyx.background.celery.tasks.doc_permission_syncing",
        "ee.onyx.background.celery.tasks.external_group_syncing",
        "ee.onyx.background.celery.tasks.cloud",
    ]
)

@@ -1,6 +1,7 @@
from datetime import timedelta
from typing import Any

from ee.onyx.configs.app_configs import CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS
from onyx.background.celery.tasks.beat_schedule import (
    beat_cloud_tasks as base_beat_system_tasks,
)
@@ -13,6 +14,7 @@ from onyx.background.celery.tasks.beat_schedule import (
    get_tasks_to_schedule as base_get_tasks_to_schedule,
)
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from shared_configs.configs import MULTI_TENANT

@@ -33,10 +35,20 @@ ee_beat_task_templates.extend(
    {
        "name": "check-ttl-management",
        "task": OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
        "schedule": timedelta(hours=CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS),
        "options": {
            "priority": OnyxCeleryPriority.MEDIUM,
            "expires": BEAT_EXPIRES_DEFAULT,
        },
    },
    {
        "name": "export-query-history-cleanup-task",
        "task": OnyxCeleryTask.EXPORT_QUERY_HISTORY_CLEANUP_TASK,
        "schedule": timedelta(hours=1),
        "options": {
            "priority": OnyxCeleryPriority.MEDIUM,
            "expires": BEAT_EXPIRES_DEFAULT,
            "queue": OnyxCeleryQueues.CSV_GENERATION,
        },
    },
]
@@ -58,10 +70,20 @@ if not MULTI_TENANT:
    {
        "name": "check-ttl-management",
        "task": OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
        "schedule": timedelta(hours=CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS),
        "options": {
            "priority": OnyxCeleryPriority.MEDIUM,
            "expires": BEAT_EXPIRES_DEFAULT,
        },
    },
    {
        "name": "export-query-history-cleanup-task",
        "task": OnyxCeleryTask.EXPORT_QUERY_HISTORY_CLEANUP_TASK,
        "schedule": timedelta(hours=1),
        "options": {
            "priority": OnyxCeleryPriority.MEDIUM,
            "expires": BEAT_EXPIRES_DEFAULT,
            "queue": OnyxCeleryQueues.CSV_GENERATION,
        },
    },
]

backend/ee/onyx/background/celery/tasks/cleanup/tasks.py (new file, 40 lines)
@@ -0,0 +1,40 @@
from datetime import datetime
from datetime import timedelta

from celery import shared_task

from ee.onyx.db.query_history import get_all_query_history_export_tasks
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import TaskStatus
from onyx.db.tasks import delete_task_with_id
from onyx.utils.logger import setup_logger


logger = setup_logger()


@shared_task(
    name=OnyxCeleryTask.EXPORT_QUERY_HISTORY_CLEANUP_TASK,
    ignore_result=True,
    soft_time_limit=JOB_TIMEOUT,
)
def export_query_history_cleanup_task(*, tenant_id: str) -> None:
    with get_session_with_tenant(tenant_id=tenant_id) as db_session:
        tasks = get_all_query_history_export_tasks(db_session=db_session)

        for task in tasks:
            if task.status == TaskStatus.SUCCESS:
                delete_task_with_id(db_session=db_session, task_id=task.task_id)
            elif task.status == TaskStatus.FAILURE:
                if task.start_time:
                    deadline = task.start_time + timedelta(hours=24)
                    now = datetime.now()
                    if now < deadline:
                        continue

                logger.error(
                    f"Task with {task.task_id=} failed; it is being deleted now"
                )
                delete_task_with_id(db_session=db_session, task_id=task.task_id)

backend/ee/onyx/background/celery/tasks/cloud/tasks.py (new file, 104 lines)
@@ -0,0 +1,104 @@
import time

from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis.lock import Lock as RedisLock

from ee.onyx.server.tenants.product_gating import get_gated_tenants
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.tasks.beat_schedule import BEAT_EXPIRES_DEFAULT
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.engine import get_all_tenant_ids
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import redis_lock_dump
from shared_configs.configs import IGNORED_SYNCING_TENANT_LIST


@shared_task(
    name=OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
    ignore_result=True,
    trail=False,
    bind=True,
)
def cloud_beat_task_generator(
    self: Task,
    task_name: str,
    queue: str = OnyxCeleryTask.DEFAULT,
    priority: int = OnyxCeleryPriority.MEDIUM,
    expires: int = BEAT_EXPIRES_DEFAULT,
) -> bool | None:
    """a lightweight task used to kick off individual beat tasks per tenant."""
    time_start = time.monotonic()

    redis_client = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)

    lock_beat: RedisLock = redis_client.lock(
        f"{OnyxRedisLocks.CLOUD_BEAT_TASK_GENERATOR_LOCK}:{task_name}",
        timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
    )

    # these tasks should never overlap
    if not lock_beat.acquire(blocking=False):
        return None

    last_lock_time = time.monotonic()
    tenant_ids: list[str] = []
    num_processed_tenants = 0

    try:
        tenant_ids = get_all_tenant_ids()
        gated_tenants = get_gated_tenants()
        for tenant_id in tenant_ids:
            if tenant_id in gated_tenants:
                continue

            current_time = time.monotonic()
            if current_time - last_lock_time >= (CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4):
                lock_beat.reacquire()
                last_lock_time = current_time

            # needed in the cloud
            if IGNORED_SYNCING_TENANT_LIST and tenant_id in IGNORED_SYNCING_TENANT_LIST:
                continue

            self.app.send_task(
                task_name,
                kwargs=dict(
                    tenant_id=tenant_id,
                ),
                queue=queue,
                priority=priority,
                expires=expires,
                ignore_result=True,
            )

            num_processed_tenants += 1
    except SoftTimeLimitExceeded:
        task_logger.info(
            "Soft time limit exceeded, task is being terminated gracefully."
        )
    except Exception:
        task_logger.exception("Unexpected exception during cloud_beat_task_generator")
    finally:
        if not lock_beat.owned():
            task_logger.error(
                "cloud_beat_task_generator - Lock not owned on completion"
            )
            redis_lock_dump(lock_beat, redis_client)
        else:
            lock_beat.release()

    time_elapsed = time.monotonic() - time_start
    task_logger.info(
        f"cloud_beat_task_generator finished: "
        f"task={task_name} "
        f"num_processed_tenants={num_processed_tenants} "
        f"num_tenants={len(tenant_ids)} "
        f"elapsed={time_elapsed:.2f}"
    )
    return True
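The generator fans one logical beat task out to every non-gated tenant. The cloud-side schedule that feeds it is not shown in this diff, so the entry below is a hypothetical sketch of how a schedule could route a per-tenant task through it, reusing only names that appear above:

from datetime import timedelta

from onyx.background.celery.tasks.beat_schedule import BEAT_EXPIRES_DEFAULT
from onyx.configs.constants import OnyxCeleryPriority, OnyxCeleryQueues, OnyxCeleryTask

# hypothetical beat entry: every hour, fan the cleanup task out to all tenants
cloud_export_cleanup_entry = {
    "name": "cloud_export-query-history-cleanup",  # illustrative name
    "task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
    "schedule": timedelta(hours=1),
    "kwargs": {
        "task_name": OnyxCeleryTask.EXPORT_QUERY_HISTORY_CLEANUP_TASK,
        "queue": OnyxCeleryQueues.CSV_GENERATION,
    },
    "options": {
        "priority": OnyxCeleryPriority.MEDIUM,
        "expires": BEAT_EXPIRES_DEFAULT,
    },
}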
@@ -16,22 +16,20 @@ from redis import Redis
from redis.exceptions import LockError
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from tenacity import retry
from tenacity import retry_if_exception
from tenacity import stop_after_delay
from tenacity import wait_random_exponential

from ee.onyx.configs.app_configs import DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs
from ee.onyx.db.document import upsert_document_external_perms
from ee.onyx.external_permissions.sync_params import DOC_PERMISSION_SYNC_PERIODS
from ee.onyx.external_permissions.sync_params import DOC_PERMISSIONS_FUNC_MAP
from ee.onyx.external_permissions.sync_params import (
    DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION,
)
from ee.onyx.external_permissions.sync_params import get_source_perm_sync_config
from onyx.access.models import DocExternalAccess
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.background.celery.tasks.shared.tasks import OnyxCeleryTaskCompletionStatus
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT
@@ -47,8 +45,10 @@ from onyx.configs.constants import OnyxRedisSignals
from onyx.connectors.factory import validate_ccpair_for_user
from onyx.db.connector import mark_cc_pair_as_permissions_synced
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.document import get_document_ids_for_connector_credential_pair
from onyx.db.document import upsert_document_by_connector_credential_pair
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncStatus
@@ -57,6 +57,7 @@ from onyx.db.models import ConnectorCredentialPair
from onyx.db.sync_record import insert_sync_record
from onyx.db.sync_record import update_sync_record_status
from onyx.db.users import batch_add_ext_perm_user_if_not_exists
from onyx.db.utils import is_retryable_sqlalchemy_error
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
@@ -73,11 +74,12 @@ from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType


logger = setup_logger()


DOCUMENT_PERMISSIONS_UPDATE_MAX_RETRIES = 3
DOCUMENT_PERMISSIONS_UPDATE_STOP_AFTER = 10 * 60
DOCUMENT_PERMISSIONS_UPDATE_MAX_WAIT = 60


# 5 seconds more than RetryDocumentIndex STOP_AFTER+MAX_WAIT
@@ -98,16 +100,29 @@ def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> b
    if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
        return False

    sync_config = get_source_perm_sync_config(cc_pair.connector.source)
    if sync_config is None:
        logger.error(f"No sync config found for {cc_pair.connector.source}")
        return False

    if sync_config.doc_sync_config is None:
        logger.error(f"No doc sync config found for {cc_pair.connector.source}")
        return False

    # if indexing also does perm sync, don't start running doc_sync until at
    # least one indexing is done
    if (
        sync_config.doc_sync_config.initial_index_should_sync
        and cc_pair.last_successful_index_time is None
    ):
        return False

    # If the last sync is None, it has never been run so we run the sync
    last_perm_sync = cc_pair.last_time_perm_sync
    if last_perm_sync is None:
        return True

    source_sync_period = DOC_PERMISSION_SYNC_PERIODS.get(cc_pair.connector.source)

    if not source_sync_period:
        source_sync_period = DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY

    source_sync_period = sync_config.doc_sync_config.doc_sync_frequency
    source_sync_period *= int(OnyxRuntime.get_doc_permission_sync_multiplier())

    # If the last sync is greater than the full fetch period, we run the sync
@@ -425,11 +440,15 @@ def connector_permission_sync_generator_task(
            raise

        source_type = cc_pair.connector.source
        sync_config = get_source_perm_sync_config(source_type)
        if sync_config is None:
            logger.error(f"No sync config found for {source_type}")
            return None

        doc_sync_func = DOC_PERMISSIONS_FUNC_MAP.get(source_type)
        if doc_sync_func is None:
            if source_type in DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION:
        if sync_config.doc_sync_config is None:
            if sync_config.censoring_config:
                return None

            raise ValueError(
                f"No doc sync func found for {source_type} with cc_pair={cc_pair_id}"
            )
@@ -449,7 +468,22 @@ def connector_permission_sync_generator_task(
        redis_connector.permissions.set_fence(new_payload)

        callback = PermissionSyncCallback(redis_connector, lock, r)
        document_external_accesses = doc_sync_func(cc_pair, callback)

        # pass in the capability to fetch all existing docs for the cc_pair.
        # this can be used to determine documents that are "missing" and thus
        # should no longer be accessible. The decision as to whether we should find
        # every document during the doc sync process is connector-specific.
        def fetch_all_existing_docs_fn() -> list[str]:
            return get_document_ids_for_connector_credential_pair(
                db_session=db_session,
                connector_id=cc_pair.connector.id,
                credential_id=cc_pair.credential.id,
            )

        doc_sync_func = sync_config.doc_sync_config.doc_sync_func
        document_external_accesses = doc_sync_func(
            cc_pair, fetch_all_existing_docs_fn, callback
        )

        task_logger.info(
            f"RedisConnector.permissions.generate_tasks starting. cc_pair={cc_pair_id}"
@@ -457,13 +491,13 @@ def connector_permission_sync_generator_task(

        tasks_generated = 0
        for doc_external_access in document_external_accesses:
            redis_connector.permissions.generate_tasks(
                celery_app=self.app,
            redis_connector.permissions.update_db(
                lock=lock,
                new_permissions=[doc_external_access],
                source_string=source_type,
                connector_id=cc_pair.connector.id,
                credential_id=cc_pair.credential.id,
                task_logger=task_logger,
            )
            tasks_generated += 1

@@ -476,6 +510,7 @@ def connector_permission_sync_generator_task(

    except Exception as e:
        error_msg = format_error_for_logging(e)

        task_logger.warning(
            f"Permission sync exceptioned: cc_pair={cc_pair_id} payload_id={payload_id} {error_msg}"
        )
@@ -496,33 +531,28 @@ def connector_permission_sync_generator_task(
    )


@shared_task(
    name=OnyxCeleryTask.UPDATE_EXTERNAL_DOCUMENT_PERMISSIONS_TASK,
    soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
    time_limit=LIGHT_TIME_LIMIT,
    max_retries=DOCUMENT_PERMISSIONS_UPDATE_MAX_RETRIES,
    bind=True,
# NOTE(rkuo): this should probably move to the db layer
@retry(
    retry=retry_if_exception(is_retryable_sqlalchemy_error),
    wait=wait_random_exponential(
        multiplier=1, max=DOCUMENT_PERMISSIONS_UPDATE_MAX_WAIT
    ),
    stop=stop_after_delay(DOCUMENT_PERMISSIONS_UPDATE_STOP_AFTER),
)
def update_external_document_permissions_task(
    self: Task,
def document_update_permissions(
    tenant_id: str,
    serialized_doc_external_access: dict,
    source_string: str,
    permissions: DocExternalAccess,
    source_type_str: str,
    connector_id: int,
    credential_id: int,
) -> bool:
    start = time.monotonic()

    completion_status = OnyxCeleryTaskCompletionStatus.UNDEFINED

    document_external_access = DocExternalAccess.from_dict(
        serialized_doc_external_access
    )
    doc_id = document_external_access.doc_id
    external_access = document_external_access.external_access
    doc_id = permissions.doc_id
    external_access = permissions.external_access

    try:
        with get_session_with_current_tenant() as db_session:
        with get_session_with_tenant(tenant_id=tenant_id) as db_session:
            # Add the users to the DB if they don't exist
            batch_add_ext_perm_user_if_not_exists(
                db_session=db_session,
@@ -534,7 +564,7 @@ def update_external_document_permissions_task(
                db_session=db_session,
                doc_id=doc_id,
                external_access=external_access,
                source_type=DocumentSource(source_string),
                source_type=DocumentSource(source_type_str),
            )

            if created_new_doc:
@@ -553,32 +583,105 @@ def update_external_document_permissions_task(
            f"action=update_permissions "
            f"elapsed={elapsed:.2f}"
        )

        completion_status = OnyxCeleryTaskCompletionStatus.SUCCEEDED
    except Exception as e:
        error_msg = format_error_for_logging(e)
        task_logger.warning(
            f"Exception in update_external_document_permissions_task: connector_id={connector_id} doc_id={doc_id} {error_msg}"
        )
        task_logger.exception(
            f"update_external_document_permissions_task exceptioned: "
            f"document_update_permissions exceptioned: "
            f"connector_id={connector_id} doc_id={doc_id}"
        )
        completion_status = OnyxCeleryTaskCompletionStatus.NON_RETRYABLE_EXCEPTION
        raise e
    finally:
        task_logger.info(
            f"update_external_document_permissions_task completed: status={completion_status.value} doc={doc_id}"
            f"document_update_permissions completed: connector_id={connector_id} doc={doc_id}"
        )

    if completion_status != OnyxCeleryTaskCompletionStatus.SUCCEEDED:
        return False

    task_logger.info(
        f"update_external_document_permissions_task finished: connector_id={connector_id} doc_id={doc_id}"
    )
    return True


# NOTE(rkuo): Deprecating this due to degenerate behavior in Redis from sending
# large permissions through celery (over 1MB in size)
# @shared_task(
#     name=OnyxCeleryTask.UPDATE_EXTERNAL_DOCUMENT_PERMISSIONS_TASK,
#     soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
#     time_limit=LIGHT_TIME_LIMIT,
#     max_retries=DOCUMENT_PERMISSIONS_UPDATE_MAX_RETRIES,
#     bind=True,
# )
# def update_external_document_permissions_task(
#     self: Task,
#     tenant_id: str,
#     serialized_doc_external_access: dict,
#     source_string: str,
#     connector_id: int,
#     credential_id: int,
# ) -> bool:
#     start = time.monotonic()

#     completion_status = OnyxCeleryTaskCompletionStatus.UNDEFINED

#     document_external_access = DocExternalAccess.from_dict(
#         serialized_doc_external_access
#     )
#     doc_id = document_external_access.doc_id
#     external_access = document_external_access.external_access

#     try:
#         with get_session_with_current_tenant() as db_session:
#             # Add the users to the DB if they don't exist
#             batch_add_ext_perm_user_if_not_exists(
#                 db_session=db_session,
#                 emails=list(external_access.external_user_emails),
#                 continue_on_error=True,
#             )
#             # Then upsert the document's external permissions
#             created_new_doc = upsert_document_external_perms(
#                 db_session=db_session,
#                 doc_id=doc_id,
#                 external_access=external_access,
#                 source_type=DocumentSource(source_string),
#             )

#             if created_new_doc:
#                 # If a new document was created, we associate it with the cc_pair
#                 upsert_document_by_connector_credential_pair(
#                     db_session=db_session,
#                     connector_id=connector_id,
#                     credential_id=credential_id,
#                     document_ids=[doc_id],
#                 )

#             elapsed = time.monotonic() - start
#             task_logger.info(
#                 f"connector_id={connector_id} "
#                 f"doc={doc_id} "
#                 f"action=update_permissions "
#                 f"elapsed={elapsed:.2f}"
#             )

#             completion_status = OnyxCeleryTaskCompletionStatus.SUCCEEDED
#     except Exception as e:
#         error_msg = format_error_for_logging(e)
#         task_logger.warning(
#             f"Exception in update_external_document_permissions_task: connector_id={connector_id} doc_id={doc_id} {error_msg}"
#         )
#         task_logger.exception(
#             f"update_external_document_permissions_task exceptioned: "
#             f"connector_id={connector_id} doc_id={doc_id}"
#         )
#         completion_status = OnyxCeleryTaskCompletionStatus.NON_RETRYABLE_EXCEPTION
#     finally:
#         task_logger.info(
#             f"update_external_document_permissions_task completed: status={completion_status.value} doc={doc_id}"
#         )

#     if completion_status != OnyxCeleryTaskCompletionStatus.SUCCEEDED:
#         return False

#     task_logger.info(
#         f"update_external_document_permissions_task finished: connector_id={connector_id} doc_id={doc_id}"
#     )
#     return True


def validate_permission_sync_fences(
    tenant_id: str,
    r: Redis,

@@ -0,0 +1,30 @@
from sqlalchemy.orm import Session

from ee.onyx.external_permissions.sync_params import (
    source_group_sync_is_cc_pair_agnostic,
)
from onyx.db.connector import mark_cc_pair_as_external_group_synced
from onyx.db.connector_credential_pair import get_connector_credential_pairs_for_source
from onyx.db.models import ConnectorCredentialPair


def _get_all_cc_pair_ids_to_mark_as_group_synced(
    db_session: Session, cc_pair: ConnectorCredentialPair
) -> list[int]:
    if not source_group_sync_is_cc_pair_agnostic(cc_pair.connector.source):
        return [cc_pair.id]

    cc_pairs = get_connector_credential_pairs_for_source(
        db_session, cc_pair.connector.source
    )
    return [cc_pair.id for cc_pair in cc_pairs]


def mark_all_relevant_cc_pairs_as_external_group_synced(
    db_session: Session, cc_pair: ConnectorCredentialPair
) -> None:
    """For some source types, one successful group sync run should count for all
    cc pairs of that type. This function handles that case."""
    cc_pair_ids = _get_all_cc_pair_ids_to_mark_as_group_synced(db_session, cc_pair)
    for cc_pair_id in cc_pair_ids:
        mark_cc_pair_as_external_group_synced(db_session, cc_pair_id)
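A caller only needs the cc_pair it actually synced; the helper above fans the bookkeeping out to every cc_pair of the same source when group sync is cc-pair agnostic. A minimal usage sketch, assuming a ConnectorCredentialPair already loaded in the current tenant's session (the wrapper name is illustrative):

from ee.onyx.background.celery.tasks.external_group_syncing.group_sync_utils import (
    mark_all_relevant_cc_pairs_as_external_group_synced,
)
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.models import ConnectorCredentialPair

def finish_group_sync(cc_pair: ConnectorCredentialPair) -> None:
    # after a successful run for `cc_pair`, mark every cc_pair the run counts
    # for: all of them when the source is cc-pair agnostic, otherwise just this one
    with get_session_with_current_tenant() as db_session:
        mark_all_relevant_cc_pairs_as_external_group_synced(db_session, cc_pair)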
@@ -14,15 +14,17 @@ from pydantic import ValidationError
|
||||
from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
|
||||
from ee.onyx.background.celery.tasks.external_group_syncing.group_sync_utils import (
|
||||
mark_all_relevant_cc_pairs_as_external_group_synced,
|
||||
)
|
||||
from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs
|
||||
from ee.onyx.db.connector_credential_pair import get_cc_pairs_by_source
|
||||
from ee.onyx.db.external_perm import ExternalUserGroup
|
||||
from ee.onyx.db.external_perm import replace_user__ext_group_for_cc_pair
|
||||
from ee.onyx.external_permissions.sync_params import EXTERNAL_GROUP_SYNC_PERIODS
|
||||
from ee.onyx.external_permissions.sync_params import GROUP_PERMISSIONS_FUNC_MAP
|
||||
from ee.onyx.external_permissions.sync_params import (
|
||||
GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC,
|
||||
get_all_cc_pair_agnostic_group_sync_sources,
|
||||
)
|
||||
from ee.onyx.external_permissions.sync_params import get_source_perm_sync_config
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_find_task
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
@@ -38,8 +40,6 @@ from onyx.configs.constants import OnyxRedisConstants
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.configs.constants import OnyxRedisSignals
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.factory import validate_ccpair_for_user
|
||||
from onyx.db.connector import mark_cc_pair_as_external_group_synced
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import AccessType
|
||||
@@ -88,12 +88,20 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
)
|
||||
return False
|
||||
|
||||
# If there is not group sync function for the connector, we don't run the sync
|
||||
# This is fine because all sources dont necessarily have a concept of groups
|
||||
if not GROUP_PERMISSIONS_FUNC_MAP.get(cc_pair.connector.source):
|
||||
sync_config = get_source_perm_sync_config(cc_pair.connector.source)
|
||||
if sync_config is None:
|
||||
task_logger.debug(
|
||||
f"Skipping group sync for CC Pair {cc_pair.id} - "
|
||||
f"no group sync function for {cc_pair.connector.source}"
|
||||
f"no sync config found for {cc_pair.connector.source}"
|
||||
)
|
||||
return False
|
||||
|
||||
# If there is not group sync function for the connector, we don't run the sync
|
||||
# This is fine because all sources dont necessarily have a concept of groups
|
||||
if sync_config.group_sync_config is None:
|
||||
task_logger.debug(
|
||||
f"Skipping group sync for CC Pair {cc_pair.id} - "
|
||||
f"no group sync config found for {cc_pair.connector.source}"
|
||||
)
|
||||
return False
|
||||
|
||||
@@ -102,11 +110,7 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
if last_ext_group_sync is None:
|
||||
return True
|
||||
|
||||
source_sync_period = EXTERNAL_GROUP_SYNC_PERIODS.get(cc_pair.connector.source)
|
||||
|
||||
# If EXTERNAL_GROUP_SYNC_PERIODS is None, we always run the sync.
|
||||
if not source_sync_period:
|
||||
return True
|
||||
source_sync_period = sync_config.group_sync_config.group_sync_frequency
|
||||
|
||||
# If the last sync is greater than the full fetch period, we run the sync
|
||||
next_sync = last_ext_group_sync + timedelta(seconds=source_sync_period)
|
||||
@@ -146,9 +150,8 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str) -> bool | None:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
|
||||
|
||||
# We only want to sync one cc_pair per source type in
|
||||
# GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC
|
||||
for source in GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC:
|
||||
# For some sources, we only want to sync one cc_pair per source type
|
||||
for source in get_all_cc_pair_agnostic_group_sync_sources():
|
||||
# These are ordered by cc_pair id so the first one is the one we want
|
||||
cc_pairs_to_dedupe = get_cc_pairs_by_source(
|
||||
db_session,
|
||||
@@ -156,8 +159,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str) -> bool | None:
|
||||
access_type=AccessType.SYNC,
|
||||
status=ConnectorCredentialPairStatus.ACTIVE,
|
||||
)
|
||||
# We only want to sync one cc_pair per source type
|
||||
# in GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC so we dedupe here
|
||||
# dedupe cc_pairs to only keep the first one
|
||||
for cc_pair_to_remove in cc_pairs_to_dedupe[1:]:
|
||||
cc_pairs = [
|
||||
cc_pair
|
||||
@@ -386,32 +388,22 @@ def connector_external_group_sync_generator_task(
|
||||
f"No connector credential pair found for id: {cc_pair_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
created = validate_ccpair_for_user(
|
||||
cc_pair.connector.id,
|
||||
cc_pair.credential.id,
|
||||
db_session,
|
||||
enforce_creation=False,
|
||||
)
|
||||
if not created:
|
||||
task_logger.warning(
|
||||
f"Unable to create connector credential pair for id: {cc_pair_id}"
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
f"validate_ccpair_permissions_sync exceptioned: cc_pair={cc_pair_id}"
|
||||
)
|
||||
# TODO: add some notification to the admins here
|
||||
raise
|
||||
|
||||
source_type = cc_pair.connector.source
|
||||
|
||||
ext_group_sync_func = GROUP_PERMISSIONS_FUNC_MAP.get(source_type)
|
||||
if ext_group_sync_func is None:
|
||||
msg = f"No external group sync func found for {source_type} for cc_pair: {cc_pair_id}"
|
||||
sync_config = get_source_perm_sync_config(source_type)
|
||||
if sync_config is None:
|
||||
msg = (
|
||||
f"No sync config found for {source_type} for cc_pair: {cc_pair_id}"
|
||||
)
|
||||
emit_background_error(msg, cc_pair_id=cc_pair_id)
|
||||
raise ValueError(msg)
|
||||
|
||||
if sync_config.group_sync_config is None:
|
||||
msg = f"No group sync config found for {source_type} for cc_pair: {cc_pair_id}"
|
||||
emit_background_error(msg, cc_pair_id=cc_pair_id)
|
||||
raise ValueError(msg)
|
||||
|
||||
ext_group_sync_func = sync_config.group_sync_config.group_sync_func
|
||||
|
||||
logger.info(
|
||||
f"Syncing external groups for {source_type} for cc_pair: {cc_pair_id}"
|
||||
)
|
||||
@@ -440,7 +432,7 @@ def connector_external_group_sync_generator_task(
|
||||
f"Synced {len(external_user_groups)} external user groups for {source_type}"
|
||||
)
|
||||
|
||||
mark_cc_pair_as_external_group_synced(db_session, cc_pair.id)
|
||||
mark_all_relevant_cc_pairs_as_external_group_synced(db_session, cc_pair)
|
||||
|
||||
update_sync_record_status(
|
||||
db_session=db_session,
|
||||
@@ -14,9 +14,8 @@ from ee.onyx.server.tenants.provisioning import setup_tenant
|
||||
from ee.onyx.server.tenants.schema_management import create_schema_if_not_exists
|
||||
from ee.onyx.server.tenants.schema_management import get_current_alembic_version
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.app_configs import TARGET_AVAILABLE_TENANTS
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
@@ -39,7 +38,8 @@ _TENANT_PROVISIONING_TIME_LIMIT = 60 * 10 # 10 minutes
|
||||
name=OnyxCeleryTask.CLOUD_CHECK_AVAILABLE_TENANTS,
|
||||
queue=OnyxCeleryQueues.MONITORING,
|
||||
ignore_result=True,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
soft_time_limit=_TENANT_PROVISIONING_SOFT_TIME_LIMIT,
|
||||
time_limit=_TENANT_PROVISIONING_TIME_LIMIT,
|
||||
trail=False,
|
||||
bind=True,
|
||||
)
|
||||
@@ -55,7 +55,7 @@ def check_available_tenants(self: Task) -> None:
|
||||
)
|
||||
return
|
||||
|
||||
r = get_redis_client()
|
||||
r = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
|
||||
lock_check: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_AVAILABLE_TENANTS_LOCK,
|
||||
timeout=_TENANT_PROVISIONING_SOFT_TIME_LIMIT,
|
||||
@@ -71,32 +71,28 @@ def check_available_tenants(self: Task) -> None:
|
||||
try:
|
||||
# Get the current count of available tenants
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
available_tenants_count = db_session.query(AvailableTenant).count()
|
||||
num_available_tenants = db_session.query(AvailableTenant).count()
|
||||
|
||||
# Get the target number of available tenants
|
||||
target_available_tenants = getattr(
|
||||
num_minimum_available_tenants = getattr(
|
||||
TARGET_AVAILABLE_TENANTS, "value", DEFAULT_TARGET_AVAILABLE_TENANTS
|
||||
)
|
||||
|
||||
# Calculate how many new tenants we need to provision
|
||||
tenants_to_provision = max(
|
||||
0, target_available_tenants - available_tenants_count
|
||||
)
|
||||
if num_available_tenants < num_minimum_available_tenants:
|
||||
tenants_to_provision = num_minimum_available_tenants - num_available_tenants
|
||||
else:
|
||||
tenants_to_provision = 0
|
||||
|
||||
task_logger.info(
|
||||
f"Available tenants: {available_tenants_count}, "
|
||||
f"Target: {target_available_tenants}, "
|
||||
f"Available tenants: {num_available_tenants}, "
|
||||
f"Target minimum available tenants: {num_minimum_available_tenants}, "
|
||||
f"To provision: {tenants_to_provision}"
|
||||
)
|
||||
|
||||
# Trigger pre-provisioning tasks for each tenant needed
|
||||
for _ in range(tenants_to_provision):
|
||||
from celery import current_app
|
||||
|
||||
current_app.send_task(
|
||||
OnyxCeleryTask.PRE_PROVISION_TENANT,
|
||||
priority=OnyxCeleryPriority.LOW,
|
||||
)
|
||||
# just provision one tenant each time we run this ... increase if needed.
|
||||
if tenants_to_provision > 0:
|
||||
pre_provision_tenant()
|
||||
|
||||
except Exception:
|
||||
task_logger.exception("Error in check_available_tenants task")
|
||||
@@ -105,15 +101,7 @@ def check_available_tenants(self: Task) -> None:
|
||||
lock_check.release()
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.PRE_PROVISION_TENANT,
|
||||
ignore_result=True,
|
||||
soft_time_limit=_TENANT_PROVISIONING_SOFT_TIME_LIMIT,
|
||||
time_limit=_TENANT_PROVISIONING_TIME_LIMIT,
|
||||
queue=OnyxCeleryQueues.MONITORING,
|
||||
bind=True,
|
||||
)
|
||||
def pre_provision_tenant(self: Task) -> None:
|
||||
def pre_provision_tenant() -> None:
|
||||
"""
|
||||
Pre-provision a new tenant and store it in the NewAvailableTenant table.
|
||||
This function fully sets up the tenant with all necessary configurations,
|
||||
@@ -122,9 +110,9 @@ def pre_provision_tenant(self: Task) -> None:
|
||||
# The MULTI_TENANT check is now done at the caller level (check_available_tenants)
|
||||
# rather than inside this function
|
||||
|
||||
r = get_redis_client()
|
||||
r = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
|
||||
lock_provision: RedisLock = r.lock(
|
||||
OnyxRedisLocks.PRE_PROVISION_TENANT_LOCK,
|
||||
OnyxRedisLocks.CLOUD_PRE_PROVISION_TENANT_LOCK,
|
||||
timeout=_TENANT_PROVISIONING_SOFT_TIME_LIMIT,
|
||||
)
|
||||
|
||||
@@ -9,7 +9,7 @@ logger = setup_logger()
|
||||
|
||||
|
||||
def should_perform_chat_ttl_check(
|
||||
retention_limit_days: int | None, db_session: Session
|
||||
retention_limit_days: float | None, db_session: Session
|
||||
) -> bool:
|
||||
# TODO: make this a check for None and add behavior for 0 day TTL
|
||||
if not retention_limit_days:
|
||||
|
||||
@@ -1,2 +1,16 @@
|
||||
def name_chat_ttl_task(retention_limit_days: int, tenant_id: str | None = None) -> str:
|
||||
from datetime import datetime
|
||||
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
|
||||
|
||||
QUERY_HISTORY_TASK_NAME_PREFIX = OnyxCeleryTask.EXPORT_QUERY_HISTORY_TASK
|
||||
|
||||
|
||||
def name_chat_ttl_task(
|
||||
retention_limit_days: float, tenant_id: str | None = None
|
||||
) -> str:
|
||||
return f"chat_ttl_{retention_limit_days}_days"
|
||||
|
||||
|
||||
def query_history_task_name(start: datetime, end: datetime) -> str:
|
||||
return f"{QUERY_HISTORY_TASK_NAME_PREFIX}_{start}_{end}"
|
||||
|
||||
@@ -25,13 +25,25 @@ SAML_CONF_DIR = os.environ.get("SAML_CONF_DIR") or "/app/ee/onyx/configs/saml_co
|
||||
#####
|
||||
# Auto Permission Sync
|
||||
#####
|
||||
# should generally only be used for sources that support polling of permissions
|
||||
# e.g. can pull in only permission changes rather than having to go through all
|
||||
# documents every time
|
||||
DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY = int(
|
||||
os.environ.get("DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
|
||||
)
|
||||
|
||||
# In seconds, default is 5 minutes
|
||||
|
||||
#####
|
||||
# Confluence
|
||||
#####
|
||||
|
||||
# In seconds, default is 30 minutes
|
||||
CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY = int(
|
||||
os.environ.get("CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY") or 5 * 60
|
||||
os.environ.get("CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY") or 30 * 60
|
||||
)
|
||||
# In seconds, default is 30 minutes
|
||||
CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY = int(
|
||||
os.environ.get("CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY") or 30 * 60
|
||||
)
|
||||
# This is a boolean that determines if anonymous access is public
|
||||
# Default behavior is to not make the page public and instead add a group
|
||||
@@ -39,14 +51,34 @@ CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY = int(
|
||||
CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC = (
|
||||
os.environ.get("CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC", "").lower() == "true"
|
||||
)
|
||||
# In seconds, default is 5 minutes
|
||||
CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY = int(
|
||||
os.environ.get("CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
|
||||
|
||||
|
||||
#####
|
||||
# Google Drive
|
||||
#####
|
||||
GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY = int(
|
||||
os.environ.get("GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY") or 5 * 60
|
||||
)
|
||||
|
||||
|
||||
#####
|
||||
# Slack
|
||||
#####
|
||||
SLACK_PERMISSION_DOC_SYNC_FREQUENCY = int(
|
||||
os.environ.get("SLACK_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
|
||||
)
|
||||
|
||||
NUM_PERMISSION_WORKERS = int(os.environ.get("NUM_PERMISSION_WORKERS") or 2)
|
||||
|
||||
|
||||
####
|
||||
# Celery Job Frequency
|
||||
####
|
||||
CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS = float(
|
||||
os.environ.get("CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS") or 1
|
||||
) # float for easier testing
|
||||
|
||||
|
||||
STRIPE_SECRET_KEY = os.environ.get("STRIPE_SECRET_KEY")
|
||||
STRIPE_PRICE_ID = os.environ.get("STRIPE_PRICE")
|
||||
|
@@ -62,29 +94,6 @@ JWT_PUBLIC_KEY_URL: str | None = os.getenv("JWT_PUBLIC_KEY_URL", None)
SUPER_USERS = json.loads(os.environ.get("SUPER_USERS", "[]"))
SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key")

OAUTH_SLACK_CLIENT_ID = os.environ.get("OAUTH_SLACK_CLIENT_ID", "")
OAUTH_SLACK_CLIENT_SECRET = os.environ.get("OAUTH_SLACK_CLIENT_SECRET", "")
OAUTH_CONFLUENCE_CLOUD_CLIENT_ID = os.environ.get(
    "OAUTH_CONFLUENCE_CLOUD_CLIENT_ID", ""
)
OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET = os.environ.get(
    "OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET", ""
)
OAUTH_JIRA_CLOUD_CLIENT_ID = os.environ.get("OAUTH_JIRA_CLOUD_CLIENT_ID", "")
OAUTH_JIRA_CLOUD_CLIENT_SECRET = os.environ.get("OAUTH_JIRA_CLOUD_CLIENT_SECRET", "")
OAUTH_GOOGLE_DRIVE_CLIENT_ID = os.environ.get("OAUTH_GOOGLE_DRIVE_CLIENT_ID", "")
OAUTH_GOOGLE_DRIVE_CLIENT_SECRET = os.environ.get(
    "OAUTH_GOOGLE_DRIVE_CLIENT_SECRET", ""
)

GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY = int(
    os.environ.get("GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY") or 5 * 60
)

SLACK_PERMISSION_DOC_SYNC_FREQUENCY = int(
    os.environ.get("SLACK_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
)

# The posthog client does not accept empty API keys or hosts; however, it fails
# silently when capture is called. These defaults prevent Posthog issues from
# breaking the Onyx app.
POSTHOG_API_KEY = os.environ.get("POSTHOG_API_KEY") or "FooBar"
@@ -92,6 +101,4 @@ POSTHOG_HOST = os.environ.get("POSTHOG_HOST") or "https://us.i.posthog.com"

HUBSPOT_TRACKING_URL = os.environ.get("HUBSPOT_TRACKING_URL")

ANONYMOUS_USER_COOKIE_NAME = "onyx_anonymous_user"

GATED_TENANTS_KEY = "gated_tenants"

@@ -140,7 +140,7 @@ def fetch_onyxbot_analytics(
            (
                or_(
                    ChatMessageFeedback.is_positive.is_(False),
                    ChatMessageFeedback.required_followup,
                    ChatMessageFeedback.required_followup.is_(True),
                ),
                1,
            ),
@@ -173,7 +173,7 @@ def fetch_onyxbot_analytics(
        .all()
    )

    return results
    return [tuple(row) for row in results]

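Two small but deliberate changes here: required_followup.is_(True) renders an explicit IS true comparison instead of relying on the bare column, and the list comprehension converts SQLAlchemy Row objects into plain tuples for callers. A self-contained sketch of the SQL difference (the Feedback model is a stand-in, not the repo's):

from sqlalchemy import Boolean, Column, Integer, or_
from sqlalchemy.orm import declarative_base

Base = declarative_base()

class Feedback(Base):  # hypothetical stand-in for ChatMessageFeedback
    __tablename__ = "feedback"
    id = Column(Integer, primary_key=True)
    is_positive = Column(Boolean)
    required_followup = Column(Boolean)

# Bare column: renders as just "feedback.required_followup"; how NULLs
# behave is left to the database.
implicit = or_(Feedback.is_positive.is_(False), Feedback.required_followup)

# .is_(True): renders as "feedback.required_followup IS true", which is
# explicit about NULL handling.
explicit = or_(Feedback.is_positive.is_(False), Feedback.required_followup.is_(True))

print(implicit)  # roughly: feedback.is_positive IS false OR feedback.required_followup
print(explicit)  # roughly: feedback.is_positive IS false OR feedback.required_followup IS true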

def fetch_persona_message_analytics(

@@ -8,6 +8,7 @@ from sqlalchemy.orm import Session

from onyx.access.utils import build_ext_group_name_for_onyx
from onyx.configs.constants import DocumentSource
from onyx.db.models import PublicExternalUserGroup
from onyx.db.models import User
from onyx.db.models import User__ExternalUserGroupId
from onyx.db.users import batch_add_ext_perm_user_if_not_exists
@@ -20,6 +21,12 @@ logger = setup_logger()
class ExternalUserGroup(BaseModel):
    id: str
    user_emails: list[str]
    # `True` for cases like a folder in Google Drive that gives domain-wide
    # or "Anyone with link" access to all files in the folder.
    # If this is set, `user_emails` doesn't really matter.
    # When this is `True`, this `ExternalUserGroup` object doesn't really
    # represent an actual "group" in the source.
    gives_anyone_access: bool = False


def delete_user__ext_group_for_user__no_commit(
@@ -44,6 +51,17 @@ def delete_user__ext_group_for_cc_pair__no_commit(
    )


def delete_public_external_group_for_cc_pair__no_commit(
    db_session: Session,
    cc_pair_id: int,
) -> None:
    db_session.execute(
        delete(PublicExternalUserGroup).where(
            PublicExternalUserGroup.cc_pair_id == cc_pair_id
        )
    )


def replace_user__ext_group_for_cc_pair(
    db_session: Session,
    cc_pair_id: int,
@@ -72,13 +90,22 @@ def replace_user__ext_group_for_cc_pair(
        db_session=db_session,
        cc_pair_id=cc_pair_id,
    )
    delete_public_external_group_for_cc_pair__no_commit(
        db_session=db_session,
        cc_pair_id=cc_pair_id,
    )

    # map emails to ids
    email_id_map = {user.email: user.id for user in all_group_members}

    # use these ids to create new external user group relations relating group_id to user_ids
    new_external_permissions = []
    new_external_permissions: list[User__ExternalUserGroupId] = []
    new_public_external_groups: list[PublicExternalUserGroup] = []
    for external_group in group_defs:
        external_group_id = build_ext_group_name_for_onyx(
            ext_group_name=external_group.id,
            source=source,
        )
        for user_email in external_group.user_emails:
            user_id = email_id_map.get(user_email.lower())
            if user_id is None:
@@ -87,10 +114,6 @@ def replace_user__ext_group_for_cc_pair(
                    f" with email {user_email} not found"
                )
                continue
            external_group_id = build_ext_group_name_for_onyx(
                ext_group_name=external_group.id,
                source=source,
            )
            new_external_permissions.append(
                User__ExternalUserGroupId(
                    user_id=user_id,
@@ -99,7 +122,16 @@ def replace_user__ext_group_for_cc_pair(
                )
            )

        if external_group.gives_anyone_access:
            new_public_external_groups.append(
                PublicExternalUserGroup(
                    external_user_group_id=external_group_id,
                    cc_pair_id=cc_pair_id,
                )
            )

    db_session.add_all(new_external_permissions)
    db_session.add_all(new_public_external_groups)
    db_session.commit()


@@ -130,3 +162,11 @@ def fetch_external_groups_for_user_email_and_group_ids(
        )
    ).all()
    return list(user_ext_groups)


def fetch_public_external_group_ids(
    db_session: Session,
) -> list[str]:
    return list(
        db_session.scalars(select(PublicExternalUserGroup.external_user_group_id)).all()
    )
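The new gives_anyone_access flag is what ties ExternalUserGroup to the PublicExternalUserGroup table above: membership rows are always written, and groups flagged as anyone-access are additionally recorded as public. A toy sketch with made-up group data:

groups = [
    ExternalUserGroup(id="eng-team", user_emails=["a@example.com", "b@example.com"]),
    ExternalUserGroup(id="drive-folder-123", user_emails=[], gives_anyone_access=True),
]

membership_rows = [(g.id, email) for g in groups for email in g.user_emails]
public_group_ids = [g.id for g in groups if g.gives_anyone_access]
# membership_rows -> [("eng-team", "a@example.com"), ("eng-team", "b@example.com")]
# public_group_ids -> ["drive-folder-123"]: any document shared with this
# "group" is treated as visible to everyone with access to the source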
@@ -11,6 +11,7 @@ from onyx.server.features.persona.models import PersonaSharedNotificationData

def make_persona_private(
    persona_id: int,
    creator_user_id: UUID | None,
    user_ids: list[UUID] | None,
    group_ids: list[int] | None,
    db_session: Session,
@@ -29,15 +30,15 @@ def make_persona_private(
        user_ids_set = set(user_ids)
        for user_id in user_ids_set:
            db_session.add(Persona__User(persona_id=persona_id, user_id=user_id))

            create_notification(
                user_id=user_id,
                notif_type=NotificationType.PERSONA_SHARED,
                db_session=db_session,
                additional_data=PersonaSharedNotificationData(
                    persona_id=persona_id,
                ).model_dump(),
            )
            if user_id != creator_user_id:
                create_notification(
                    user_id=user_id,
                    notif_type=NotificationType.PERSONA_SHARED,
                    db_session=db_session,
                    additional_data=PersonaSharedNotificationData(
                        persona_id=persona_id,
                    ).model_dump(),
                )

    if group_ids:
        group_ids_set = set(group_ids)

@@ -15,10 +15,13 @@ from sqlalchemy.sql import select
from sqlalchemy.sql.expression import literal
from sqlalchemy.sql.expression import UnaryExpression

from ee.onyx.background.task_name_builders import QUERY_HISTORY_TASK_NAME_PREFIX
from onyx.configs.constants import QAFeedbackType
from onyx.db.models import ChatMessage
from onyx.db.models import ChatMessageFeedback
from onyx.db.models import ChatSession
from onyx.db.models import TaskQueueState
from onyx.db.tasks import get_all_tasks_with_prefix


def _build_filter_conditions(
@@ -171,3 +174,9 @@ def fetch_chat_sessions_eagerly_by_time(
    chat_sessions = query.all()

    return chat_sessions


def get_all_query_history_export_tasks(
    db_session: Session,
) -> list[TaskQueueState]:
    return get_all_tasks_with_prefix(db_session, QUERY_HISTORY_TASK_NAME_PREFIX)

@@ -5,6 +5,8 @@ from typing import IO
from typing import Optional

from fastapi_users_db_sqlalchemy import UUID_ID
from sqlalchemy import cast
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import Session

from ee.onyx.db.query_history import fetch_chat_sessions_eagerly_by_time
@@ -13,6 +15,7 @@ from ee.onyx.server.reporting.usage_export_models import FlowType
from ee.onyx.server.reporting.usage_export_models import UsageReportMetadata
from onyx.configs.constants import MessageType
from onyx.db.models import UsageReport
from onyx.db.models import User
from onyx.file_store.file_store import get_default_file_store


@@ -86,15 +89,27 @@ def get_all_empty_chat_message_entries(


def get_all_usage_reports(db_session: Session) -> list[UsageReportMetadata]:
    # Get the user emails
    usage_reports = db_session.query(UsageReport).all()
    user_ids = {r.requestor_user_id for r in usage_reports if r.requestor_user_id}
    user_emails = {
        user.id: user.email
        for user in db_session.query(User)
        .filter(cast(User.id, UUID).in_(user_ids))
        .all()
    }

    return [
        UsageReportMetadata(
            report_name=r.report_name,
            requestor=str(r.requestor_user_id) if r.requestor_user_id else None,
            requestor=(
                user_emails.get(r.requestor_user_id) if r.requestor_user_id else None
            ),
            time_created=r.time_created,
            period_from=r.period_from,
            period_to=r.period_to,
        )
        for r in db_session.query(UsageReport).all()
        for r in usage_reports
    ]

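The rewrite of get_all_usage_reports replaces a raw UUID string in `requestor` with the user's email, resolved through a single batched query instead of one query per report. The lookup-table shape, reduced to plain dicts (data is made up):

reports = [
    {"report_name": "march", "requestor_user_id": "u-1"},
    {"report_name": "april", "requestor_user_id": None},
]
user_emails = {"u-1": "alice@example.com"}  # built once, from one query

rows = [
    (r["report_name"],
     user_emails.get(r["requestor_user_id"]) if r["requestor_user_id"] else None)
    for r in reports
]
# [("march", "alice@example.com"), ("april", None)] -- no per-report user query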
@@ -0,0 +1,48 @@
<?xml version="1.0" encoding="utf-8"?>
<services version="1.0">
  <container id="default" version="1.0">
    <document-api />
    <search />
    <http>
      <server id="default" port="4080" />
    </http>
    <nodes count="[2, 4]">
      <resources vcpu="4.0" memory="16Gb" architecture="arm64" storage-type="remote"
                 disk="48Gb" />
    </nodes>

  </container>
  <content id="danswer_index" version="1.0">
    <documents>
      <!-- <document type="danswer_chunk" mode="index" /> -->
      {{ document_elements }}
    </documents>
    <nodes count="75">
      <resources vcpu="8.0" memory="64.0Gb" architecture="arm64" storage-type="local"
                 disk="474.0Gb" />
    </nodes>
    <engine>
      <proton>
        <tuning>
          <searchnode>
            <requestthreads>
              <persearch>2</persearch>
            </requestthreads>
          </searchnode>
        </tuning>
      </proton>
    </engine>

    <config name="vespa.config.search.summary.juniperrc">
      <max_matches>3</max_matches>
      <length>750</length>
      <surround_max>350</surround_max>
      <min_length>300</min_length>
    </config>

    <min-redundancy>2</min-redundancy>

  </content>
</services>
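This services.xml is a Vespa deployment template rather than a finished config: count="[2, 4]" is Vespa Cloud's autoscaling range syntax, and {{ document_elements }} is a placeholder for the document definitions filled in at deploy time. One plausible rendering step, assuming simple string substitution (the filename and helper below are guesses, not the repo's actual deploy code):

# Hypothetical rendering step -- names and mechanism are assumptions.
document_elements = "\n".join(
    f'      <document type="{schema}" mode="index" />'
    for schema in ("danswer_chunk",)  # assumed schema name, per the comment above
)

with open("services.xml") as f:  # assumed template location
    template = f.read()

rendered = template.replace("{{ document_elements }}", document_elements)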
@@ -2,3 +2,6 @@
# Instead of setting a page to public, we just add this group so that the page
# is only accessible to users who have confluence accounts.
ALL_CONF_EMAILS_GROUP_NAME = "All_Confluence_Users_Found_By_Onyx"

VIEWSPACE_PERMISSION_TYPE = "VIEWSPACE"
REQUEST_PAGINATION_LIMIT = 5000

@@ -4,17 +4,11 @@ https://confluence.atlassian.com/conf85/check-who-can-view-a-page-1283360557.htm
"""

from collections.abc import Generator
from typing import Any

from ee.onyx.configs.app_configs import CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC
from ee.onyx.external_permissions.confluence.constants import ALL_CONF_EMAILS_GROUP_NAME
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
from onyx.access.models import DocExternalAccess
from onyx.access.models import ExternalAccess
from onyx.connectors.confluence.connector import ConfluenceConnector
from onyx.connectors.confluence.onyx_confluence import (
    get_user_email_from_username__server,
)
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
from onyx.connectors.models import SlimDocument
from onyx.db.models import ConnectorCredentialPair
@@ -24,336 +18,18 @@ from shared_configs.contextvars import get_current_tenant_id

logger = setup_logger()

_VIEWSPACE_PERMISSION_TYPE = "VIEWSPACE"
_REQUEST_PAGINATION_LIMIT = 5000


def _get_server_space_permissions(
    confluence_client: OnyxConfluence, space_key: str
) -> ExternalAccess:
    space_permissions = confluence_client.get_all_space_permissions_server(
        space_key=space_key
    )

    viewspace_permissions = []
    for permission_category in space_permissions:
        if permission_category.get("type") == _VIEWSPACE_PERMISSION_TYPE:
            viewspace_permissions.extend(
                permission_category.get("spacePermissions", [])
            )

    is_public = False
    user_names = set()
    group_names = set()
    for permission in viewspace_permissions:
        user_name = permission.get("userName")
        if user_name:
            user_names.add(user_name)
        group_name = permission.get("groupName")
        if group_name:
            group_names.add(group_name)

        # It seems that if anonymous access is turned on for the site and space,
        # then the space is publicly accessible.
        # For Confluence Server, we make a group that contains all users that
        # exist in Confluence and add that group to the space permissions when
        # anonymous access is turned on for the site and space. If the env
        # variable CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC is set to True, we set
        # is_public = True instead, so that we can support Confluence Server
        # deployments that want anonymous access to be public (we can't test
        # this ourselves because it's paywalled).
        if user_name is None and group_name is None:
            # Defaults to False
            if CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC:
                is_public = True
            else:
                group_names.add(ALL_CONF_EMAILS_GROUP_NAME)

    user_emails = set()
    for user_name in user_names:
        user_email = get_user_email_from_username__server(confluence_client, user_name)
        if user_email:
            user_emails.add(user_email)
        else:
            logger.warning(f"Email for user {user_name} not found in Confluence")

    if not user_emails and not group_names:
        logger.warning(
            "No user emails or group names found in Confluence space permissions"
            f"\nSpace key: {space_key}"
            f"\nSpace permissions: {space_permissions}"
        )

    return ExternalAccess(
        external_user_emails=user_emails,
        external_user_group_ids=group_names,
        is_public=is_public,
    )


def _get_cloud_space_permissions(
    confluence_client: OnyxConfluence, space_key: str
) -> ExternalAccess:
    space_permissions_result = confluence_client.get_space(
        space_key=space_key, expand="permissions"
    )
    space_permissions = space_permissions_result.get("permissions", [])

    user_emails = set()
    group_names = set()
    is_externally_public = False
    for permission in space_permissions:
        subs = permission.get("subjects")
        if subs:
            # If there are subjects, then there are explicit users or groups with access
            if email := subs.get("user", {}).get("results", [{}])[0].get("email"):
                user_emails.add(email)
            if group_name := subs.get("group", {}).get("results", [{}])[0].get("name"):
                group_names.add(group_name)
        else:
            # If there are no subjects, then the permission is for everyone
            if permission.get("operation", {}).get(
                "operation"
            ) == "read" and permission.get("anonymousAccess", False):
                # If the permission specifies read access for anonymous users, then
                # the space is publicly accessible
                is_externally_public = True

    return ExternalAccess(
        external_user_emails=user_emails,
        external_user_group_ids=group_names,
        is_public=is_externally_public,
    )


def _get_space_permissions(
    confluence_client: OnyxConfluence,
    is_cloud: bool,
) -> dict[str, ExternalAccess]:
    logger.debug("Getting space permissions")
    # Gets all the spaces in the Confluence instance
    all_space_keys = []
    start = 0
    while True:
        spaces_batch = confluence_client.get_all_spaces(
            start=start, limit=_REQUEST_PAGINATION_LIMIT
        )
        for space in spaces_batch.get("results", []):
            all_space_keys.append(space.get("key"))

        if len(spaces_batch.get("results", [])) < _REQUEST_PAGINATION_LIMIT:
            break

        start += len(spaces_batch.get("results", []))

    # Gets the permissions for each space
    logger.debug(f"Got {len(all_space_keys)} spaces from confluence")
    space_permissions_by_space_key: dict[str, ExternalAccess] = {}
    for space_key in all_space_keys:
        if is_cloud:
            space_permissions = _get_cloud_space_permissions(
                confluence_client=confluence_client, space_key=space_key
            )
        else:
            space_permissions = _get_server_space_permissions(
                confluence_client=confluence_client, space_key=space_key
            )

        # Stores the permissions for each space
        space_permissions_by_space_key[space_key] = space_permissions
        logger.info(
            f"Found space permissions for space '{space_key}': {space_permissions}"
        )

    return space_permissions_by_space_key


def _extract_read_access_restrictions(
    confluence_client: OnyxConfluence, restrictions: dict[str, Any]
) -> tuple[set[str], set[str]]:
    """
    Extracts the users and groups with read access from a page's
    restrictions dict. Returns empty sets if there are no restrictions.
    """
    read_access = restrictions.get("read", {})
    read_access_restrictions = read_access.get("restrictions", {})

    # Extract the users with read access
    read_access_user = read_access_restrictions.get("user", {})
    read_access_user_jsons = read_access_user.get("results", [])
    read_access_user_emails = []
    for user in read_access_user_jsons:
        # If the user has an email, then add it to the list
        if user.get("email"):
            read_access_user_emails.append(user["email"])
        # If the user has a username and not an email, then get the email from Confluence
        elif user.get("username"):
            email = get_user_email_from_username__server(
                confluence_client=confluence_client, user_name=user["username"]
            )
            if email:
                read_access_user_emails.append(email)
            else:
                logger.warning(
                    f"Email for user {user['username']} not found in Confluence"
                )
        else:
            if user.get("email") is not None:
                logger.warning(f"Can't find email for user {user.get('displayName')}")
                logger.warning(
                    "This user needs to make their email accessible in Confluence Settings"
                )

            logger.warning(f"no user email or username for {user}")

    # Extract the groups with read access
    read_access_group = read_access_restrictions.get("group", {})
    read_access_group_jsons = read_access_group.get("results", [])
    read_access_group_names = [
        group["name"] for group in read_access_group_jsons if group.get("name")
    ]

    return set(read_access_user_emails), set(read_access_group_names)


def _get_all_page_restrictions(
    confluence_client: OnyxConfluence,
    perm_sync_data: dict[str, Any],
) -> ExternalAccess | None:
    """
    This function gets the restrictions for a page by taking the intersection
    of the page's restrictions and the restrictions of all the ancestors
    of the page.
    If the page/ancestor has no restrictions, then it is ignored (no intersection).
    If no restrictions are found anywhere, then return None, indicating that the page
    should inherit the space's restrictions.
    """
    found_user_emails: set[str] = set()
    found_group_names: set[str] = set()

    found_user_emails, found_group_names = _extract_read_access_restrictions(
        confluence_client=confluence_client,
        restrictions=perm_sync_data.get("restrictions", {}),
    )

    ancestors: list[dict[str, Any]] = perm_sync_data.get("ancestors", [])
    for ancestor in ancestors:
        ancestor_user_emails, ancestor_group_names = _extract_read_access_restrictions(
            confluence_client=confluence_client,
            restrictions=ancestor.get("restrictions", {}),
        )
        if not ancestor_user_emails and not ancestor_group_names:
            # This ancestor has no restrictions, so it has no effect on
            # the page's restrictions, so we ignore it
            continue

        found_user_emails.intersection_update(ancestor_user_emails)
        found_group_names.intersection_update(ancestor_group_names)

    # If there are no restrictions found, then the page
    # inherits the space's restrictions so return None
    if not found_user_emails and not found_group_names:
        return None

    return ExternalAccess(
        external_user_emails=found_user_emails,
        external_user_group_ids=found_group_names,
        # there is no way for a page to be individually public if the space isn't public
        is_public=False,
    )


def _fetch_all_page_restrictions(
    confluence_client: OnyxConfluence,
    slim_docs: list[SlimDocument],
    space_permissions_by_space_key: dict[str, ExternalAccess],
    is_cloud: bool,
    callback: IndexingHeartbeatInterface | None,
) -> Generator[DocExternalAccess, None, None]:
    """
    For all pages, if a page has restrictions, then use those restrictions.
    Otherwise, use the space's restrictions.
    """
    for slim_doc in slim_docs:
        if callback:
            if callback.should_stop():
                raise RuntimeError("confluence_doc_sync: Stop signal detected")

            callback.progress("confluence_doc_sync:fetch_all_page_restrictions", 1)

        if slim_doc.perm_sync_data is None:
            raise ValueError(
                f"No permission sync data found for document {slim_doc.id}"
            )

        if restrictions := _get_all_page_restrictions(
            confluence_client=confluence_client,
            perm_sync_data=slim_doc.perm_sync_data,
        ):
            yield DocExternalAccess(
                doc_id=slim_doc.id,
                external_access=restrictions,
            )
            # If there are restrictions, then we don't need to use the space's restrictions
            continue

        space_key = slim_doc.perm_sync_data.get("space_key")
        if not (space_permissions := space_permissions_by_space_key.get(space_key)):
            logger.debug(
                f"Individually fetching space permissions for space {space_key}"
            )
            try:
                # If the space permissions are not in the cache, then fetch them
                if is_cloud:
                    retrieved_space_permissions = _get_cloud_space_permissions(
                        confluence_client=confluence_client, space_key=space_key
                    )
                else:
                    retrieved_space_permissions = _get_server_space_permissions(
                        confluence_client=confluence_client, space_key=space_key
                    )
                space_permissions_by_space_key[space_key] = retrieved_space_permissions
                space_permissions = retrieved_space_permissions
            except Exception as e:
                logger.warning(
                    f"Error fetching space permissions for space {space_key}: {e}"
                )

        if not space_permissions:
            logger.warning(
                f"No permissions found for document {slim_doc.id} in space {space_key}"
            )
            continue

        # If there are no restrictions, then use the space's restrictions
        yield DocExternalAccess(
            doc_id=slim_doc.id,
            external_access=space_permissions,
        )
        if (
            not space_permissions.is_public
            and not space_permissions.external_user_emails
            and not space_permissions.external_user_group_ids
        ):
            logger.warning(
                f"Permissions are empty for document: {slim_doc.id}\n"
                "This means space permissions are likely wrong for"
                f" Space key: {space_key}"
            )

    logger.debug("Finished fetching all page restrictions for space")


def confluence_doc_sync(
    cc_pair: ConnectorCredentialPair,
    fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
    callback: IndexingHeartbeatInterface | None,
) -> Generator[DocExternalAccess, None, None]:
    """
    Adds the external permissions to the documents in postgres;
    if the document doesn't already exist in postgres, we create
    it in postgres so that when it gets created later, the permissions are
    already populated
    Fetches document permissions from Confluence and yields DocExternalAccess objects.
    Compares fetched documents against existing documents in the DB for the connector.
    If a document exists in the DB but not in the Confluence fetch, it's marked as restricted.
    """
    logger.debug("Starting confluence doc sync")
    logger.info(f"Starting confluence doc sync for CC Pair ID: {cc_pair.id}")
    confluence_connector = ConfluenceConnector(
        **cc_pair.connector.connector_specific_config
    )
@@ -363,19 +39,12 @@ def confluence_doc_sync(
    )
    confluence_connector.set_credentials_provider(provider)

    is_cloud = cc_pair.connector.connector_specific_config.get("is_cloud", False)

    space_permissions_by_space_key = _get_space_permissions(
        confluence_client=confluence_connector.confluence_client,
        is_cloud=is_cloud,
    )

    slim_docs = []
    logger.debug("Fetching all slim documents from confluence")
    slim_docs: list[SlimDocument] = []
    logger.info("Fetching all slim documents from confluence")
    for doc_batch in confluence_connector.retrieve_all_slim_documents(
        callback=callback
    ):
        logger.debug(f"Got {len(doc_batch)} slim documents from confluence")
        logger.info(f"Got {len(doc_batch)} slim documents from confluence")
        if callback:
            if callback.should_stop():
                raise RuntimeError("confluence_doc_sync: Stop signal detected")
@@ -384,11 +53,38 @@ def confluence_doc_sync(

        slim_docs.extend(doc_batch)

    logger.debug("Fetching all page restrictions for space")
    yield from _fetch_all_page_restrictions(
        confluence_client=confluence_connector.confluence_client,
        slim_docs=slim_docs,
        space_permissions_by_space_key=space_permissions_by_space_key,
        is_cloud=is_cloud,
        callback=callback,
    )
    # Find documents that are no longer accessible in Confluence
    logger.info(f"Querying existing document IDs for CC Pair ID: {cc_pair.id}")
    existing_doc_ids = fetch_all_existing_docs_fn()

    # Find missing doc IDs
    fetched_doc_ids = {doc.id for doc in slim_docs}
    missing_doc_ids = set(existing_doc_ids) - fetched_doc_ids

    # Yield access removal for missing docs. Better to be safe.
    if missing_doc_ids:
        logger.warning(
            f"Found {len(missing_doc_ids)} documents that are in the DB but "
            "not present in Confluence fetch. Making them inaccessible."
        )
        for missing_id in missing_doc_ids:
            logger.warning(f"Removing access for document ID: {missing_id}")
            yield DocExternalAccess(
                doc_id=missing_id,
                external_access=ExternalAccess(
                    external_user_emails=set(),
                    external_user_group_ids=set(),
                    is_public=False,
                ),
            )

    for doc in slim_docs:
        if not doc.external_access:
            raise RuntimeError(f"No external access found for document ID: {doc.id}")

        yield DocExternalAccess(
            doc_id=doc.id,
            external_access=doc.external_access,
        )

    logger.info("Finished confluence doc sync")

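The new tail of confluence_doc_sync is a fail-safe: any document the DB knows about that Confluence no longer returns gets an empty ExternalAccess, which locks it down without deleting it. The core of that pattern, isolated with toy ids:

from onyx.access.models import ExternalAccess

existing_doc_ids = {"doc-1", "doc-2", "doc-3"}  # what the DB believes exists
fetched_doc_ids = {"doc-1", "doc-3"}            # what Confluence just returned

for missing_id in existing_doc_ids - fetched_doc_ids:  # {"doc-2"}
    # No users, no groups, not public: the document stays indexed
    # but becomes invisible to everyone.
    revoked = ExternalAccess(
        external_user_emails=set(),
        external_user_group_ids=set(),
        is_public=False,
    )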
133  backend/ee/onyx/external_permissions/confluence/page_access.py  Normal file
@@ -0,0 +1,133 @@
from typing import Any

from onyx.access.models import ExternalAccess
from onyx.connectors.confluence.onyx_confluence import (
    get_user_email_from_username__server,
)
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
from onyx.utils.logger import setup_logger

logger = setup_logger()


def _extract_read_access_restrictions(
    confluence_client: OnyxConfluence, restrictions: dict[str, Any]
) -> tuple[set[str], set[str], bool]:
    """
    Extracts the users and groups with read access from a page's restrictions
    dict, along with whether any read restriction was found at all.
    """
    read_access = restrictions.get("read", {})
    read_access_restrictions = read_access.get("restrictions", {})

    # Extract the users with read access
    read_access_user = read_access_restrictions.get("user", {})
    read_access_user_jsons = read_access_user.get("results", [])
    # any items found means that there is a restriction
    found_any_restriction = bool(read_access_user_jsons)

    read_access_user_emails = []
    for user in read_access_user_jsons:
        # If the user has an email, then add it to the list
        if user.get("email"):
            read_access_user_emails.append(user["email"])
        # If the user has a username and not an email, then get the email from Confluence
        elif user.get("username"):
            email = get_user_email_from_username__server(
                confluence_client=confluence_client, user_name=user["username"]
            )
            if email:
                read_access_user_emails.append(email)
            else:
                logger.warning(
                    f"Email for user {user['username']} not found in Confluence"
                )
        else:
            if user.get("email") is not None:
                logger.warning(f"Can't find email for user {user.get('displayName')}")
                logger.warning(
                    "This user needs to make their email accessible in Confluence Settings"
                )

            logger.warning(f"no user email or username for {user}")

    # Extract the groups with read access
    read_access_group = read_access_restrictions.get("group", {})
    read_access_group_jsons = read_access_group.get("results", [])
    # any items found means that there is a restriction
    found_any_restriction |= bool(read_access_group_jsons)
    read_access_group_names = [
        group["name"] for group in read_access_group_jsons if group.get("name")
    ]

    return (
        set(read_access_user_emails),
        set(read_access_group_names),
        found_any_restriction,
    )


def get_page_restrictions(
    confluence_client: OnyxConfluence,
    page_id: str,
    page_restrictions: dict[str, Any],
    ancestors: list[dict[str, Any]],
) -> ExternalAccess | None:
    """
    This function gets the restrictions for a page. In Confluence, a child can have
    at MOST the same level of accessibility as its immediate parent.

    If no restrictions are found anywhere, then return None, indicating that the page
    should inherit the space's restrictions.
    """
    found_user_emails: set[str] = set()
    found_group_names: set[str] = set()

    # NOTE: need the found_any_restriction flag, since we can find restrictions
    # but not be able to extract any user emails or group names.
    # In this case, we should just give no access.
    found_user_emails, found_group_names, found_any_page_level_restriction = (
        _extract_read_access_restrictions(
            confluence_client=confluence_client,
            restrictions=page_restrictions,
        )
    )
    # if there are individual page-level restrictions, then this is the accurate
    # restriction for the page. You cannot both have page-level restrictions AND
    # inherit restrictions from the parent.
    if found_any_page_level_restriction:
        return ExternalAccess(
            external_user_emails=found_user_emails,
            external_user_group_ids=found_group_names,
            is_public=False,
        )

    # ancestors seem to be in order from root to immediate parent
    # https://community.atlassian.com/forums/Confluence-questions/Order-of-ancestors-in-REST-API-response-Confluence-Server-amp/qaq-p/2385981
    # we want the restrictions from the immediate parent to take precedence, so we should
    # reverse the list
    for ancestor in reversed(ancestors):
        (
            ancestor_user_emails,
            ancestor_group_names,
            found_any_restrictions_in_ancestor,
        ) = _extract_read_access_restrictions(
            confluence_client=confluence_client,
            restrictions=ancestor.get("restrictions", {}),
        )
        if found_any_restrictions_in_ancestor:
            # if inheriting restrictions from the parent, then the first one we run into
            # should be applied (the reason why we'd traverse more than one ancestor is if
            # the ancestor also is in "inherit" mode.)
            logger.debug(
                f"Found user restrictions {ancestor_user_emails} and group restrictions {ancestor_group_names}"
                f" for document {page_id} based on ancestor {ancestor}"
            )
            return ExternalAccess(
                external_user_emails=ancestor_user_emails,
                external_user_group_ids=ancestor_group_names,
                is_public=False,
            )

    # we didn't find any restrictions, so the page inherits the space's restrictions
    return None
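get_page_restrictions encodes a precedence rule: page-level restrictions always win; otherwise the nearest restricted ancestor wins; otherwise the space decides. A minimal model of that rule, with each restriction reduced to a set of emails or None (toy data, not the repo's types):

def effective_restriction(page, ancestors):
    if page is not None:
        return page  # page-level restrictions always win
    # ancestors run root -> immediate parent, so walk them in reverse:
    # the closest restricted ancestor is the one that applies
    for ancestor in reversed(ancestors):
        if ancestor is not None:
            return ancestor
    return None  # no restrictions anywhere: defer to the space

assert effective_restriction({"page@x.com"}, [{"root@x.com"}]) == {"page@x.com"}
assert effective_restriction(None, [{"root@x.com"}, None]) == {"root@x.com"}
assert effective_restriction(None, [None, None]) is None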
165  backend/ee/onyx/external_permissions/confluence/space_access.py  Normal file
@@ -0,0 +1,165 @@
from ee.onyx.configs.app_configs import CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC
from ee.onyx.external_permissions.confluence.constants import ALL_CONF_EMAILS_GROUP_NAME
from ee.onyx.external_permissions.confluence.constants import REQUEST_PAGINATION_LIMIT
from ee.onyx.external_permissions.confluence.constants import VIEWSPACE_PERMISSION_TYPE
from onyx.access.models import ExternalAccess
from onyx.connectors.confluence.onyx_confluence import (
    get_user_email_from_username__server,
)
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
from onyx.utils.logger import setup_logger


logger = setup_logger()


def _get_server_space_permissions(
    confluence_client: OnyxConfluence, space_key: str
) -> ExternalAccess:
    space_permissions = confluence_client.get_all_space_permissions_server(
        space_key=space_key
    )

    viewspace_permissions = []
    for permission_category in space_permissions:
        if permission_category.get("type") == VIEWSPACE_PERMISSION_TYPE:
            viewspace_permissions.extend(
                permission_category.get("spacePermissions", [])
            )

    is_public = False
    user_names = set()
    group_names = set()
    for permission in viewspace_permissions:
        if user_name := permission.get("userName"):
            user_names.add(user_name)
        if group_name := permission.get("groupName"):
            group_names.add(group_name)

        # It seems that if anonymous access is turned on for the site and space,
        # then the space is publicly accessible.
        # For Confluence Server, we make a group that contains all users that
        # exist in Confluence and add that group to the space permissions when
        # anonymous access is turned on for the site and space. If the env
        # variable CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC is set to True, we set
        # is_public = True instead, so that we can support Confluence Server
        # deployments that want anonymous access to be public (we can't test
        # this ourselves because it's paywalled).
        if user_name is None and group_name is None:
            # Defaults to False
            if CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC:
                is_public = True
            else:
                group_names.add(ALL_CONF_EMAILS_GROUP_NAME)

    user_emails = set()
    for user_name in user_names:
        user_email = get_user_email_from_username__server(confluence_client, user_name)
        if user_email:
            user_emails.add(user_email)
        else:
            logger.warning(f"Email for user {user_name} not found in Confluence")

    if not user_emails and not group_names:
        logger.warning(
            "No user emails or group names found in Confluence space permissions"
            f"\nSpace key: {space_key}"
            f"\nSpace permissions: {space_permissions}"
        )

    return ExternalAccess(
        external_user_emails=user_emails,
        external_user_group_ids=group_names,
        is_public=is_public,
    )


def _get_cloud_space_permissions(
    confluence_client: OnyxConfluence, space_key: str
) -> ExternalAccess:
    space_permissions_result = confluence_client.get_space(
        space_key=space_key, expand="permissions"
    )
    space_permissions = space_permissions_result.get("permissions", [])

    user_emails = set()
    group_names = set()
    is_externally_public = False
    for permission in space_permissions:
        subs = permission.get("subjects")
        if subs:
            # If there are subjects, then there are explicit users or groups with access
            if email := subs.get("user", {}).get("results", [{}])[0].get("email"):
                user_emails.add(email)
            if group_name := subs.get("group", {}).get("results", [{}])[0].get("name"):
                group_names.add(group_name)
        else:
            # If there are no subjects, then the permission is for everyone
            if permission.get("operation", {}).get(
                "operation"
            ) == "read" and permission.get("anonymousAccess", False):
                # If the permission specifies read access for anonymous users, then
                # the space is publicly accessible
                is_externally_public = True

    return ExternalAccess(
        external_user_emails=user_emails,
        external_user_group_ids=group_names,
        is_public=is_externally_public,
    )


def get_space_permission(
    confluence_client: OnyxConfluence,
    space_key: str,
    is_cloud: bool,
) -> ExternalAccess:
    if is_cloud:
        space_permissions = _get_cloud_space_permissions(confluence_client, space_key)
    else:
        space_permissions = _get_server_space_permissions(confluence_client, space_key)

    if (
        not space_permissions.is_public
        and not space_permissions.external_user_emails
        and not space_permissions.external_user_group_ids
    ):
        logger.warning(
            f"No permissions found for space '{space_key}'. This is very unlikely "
            "to be correct and is more likely caused by an access token with "
            "insufficient permissions. Make sure that the access token has Admin "
            f"permissions for space '{space_key}'"
        )

    return space_permissions


def get_all_space_permissions(
    confluence_client: OnyxConfluence,
    is_cloud: bool,
) -> dict[str, ExternalAccess]:
    logger.debug("Getting space permissions")
    # Gets all the spaces in the Confluence instance
    all_space_keys = []
    start = 0
    while True:
        spaces_batch = confluence_client.get_all_spaces(
            start=start, limit=REQUEST_PAGINATION_LIMIT
        )
        for space in spaces_batch.get("results", []):
            all_space_keys.append(space.get("key"))

        if len(spaces_batch.get("results", [])) < REQUEST_PAGINATION_LIMIT:
            break

        start += len(spaces_batch.get("results", []))

    # Gets the permissions for each space
    logger.debug(f"Got {len(all_space_keys)} spaces from confluence")
    space_permissions_by_space_key: dict[str, ExternalAccess] = {}
    for space_key in all_space_keys:
        space_permissions = get_space_permission(confluence_client, space_key, is_cloud)

        # Stores the permissions for each space
        space_permissions_by_space_key[space_key] = space_permissions

    return space_permissions_by_space_key
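get_all_space_permissions uses the standard offset-pagination loop: fetch a page, stop when a page comes back shorter than the limit. Extracted into a reusable shape (fetch_page stands in for confluence_client.get_all_spaces; this is a distillation, not repo code):

def iterate_all(fetch_page, limit):
    start = 0
    while True:
        results = fetch_page(start=start, limit=limit).get("results", [])
        yield from results
        if len(results) < limit:
            break  # a short (or empty) page means we've reached the end
        start += len(results)

# e.g. keys = [s.get("key") for s in iterate_all(confluence_client.get_all_spaces, 5000)]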
@@ -2,8 +2,8 @@ from collections.abc import Generator
from datetime import datetime
from datetime import timezone

from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
from onyx.access.models import DocExternalAccess
from onyx.access.models import ExternalAccess
from onyx.connectors.gmail.connector import GmailConnector
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.db.models import ConnectorCredentialPair
@@ -34,6 +34,7 @@ def _get_slim_doc_generator(

def gmail_doc_sync(
    cc_pair: ConnectorCredentialPair,
    fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
    callback: IndexingHeartbeatInterface | None,
) -> Generator[DocExternalAccess, None, None]:
    """
@@ -57,17 +58,11 @@ def gmail_doc_sync(

            callback.progress("gmail_doc_sync", 1)

        if slim_doc.perm_sync_data is None:
        if slim_doc.external_access is None:
            logger.warning(f"No permissions found for document {slim_doc.id}")
            continue

        if user_email := slim_doc.perm_sync_data.get("user_email"):
            ext_access = ExternalAccess(
                external_user_emails=set([user_email]),
                external_user_group_ids=set(),
                is_public=False,
            )
            yield DocExternalAccess(
                doc_id=slim_doc.id,
                external_access=ext_access,
            )
        yield DocExternalAccess(
            doc_id=slim_doc.id,
            external_access=slim_doc.external_access,
        )

@@ -1,23 +1,25 @@
from collections.abc import Generator
from datetime import datetime
from datetime import timezone
from typing import Any

from ee.onyx.external_permissions.google_drive.models import GoogleDrivePermission
from ee.onyx.external_permissions.google_drive.models import PermissionType
from ee.onyx.external_permissions.google_drive.permission_retrieval import (
    get_permissions_by_ids,
)
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
from onyx.access.models import DocExternalAccess
from onyx.access.models import ExternalAccess
from onyx.connectors.google_drive.connector import GoogleDriveConnector
from onyx.connectors.google_utils.google_utils import execute_paginated_retrieval
from onyx.connectors.google_utils.resources import get_drive_service
from onyx.connectors.google_drive.models import GoogleDriveFileType
from onyx.connectors.google_utils.resources import GoogleDriveService
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.models import SlimDocument
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger

logger = setup_logger()

_PERMISSION_ID_PERMISSION_MAP: dict[str, dict[str, Any]] = {}


def _get_slim_doc_generator(
    cc_pair: ConnectorCredentialPair,
@@ -38,105 +40,87 @@ def _get_slim_doc_generator(
    )


def _fetch_permissions_for_permission_ids(
    google_drive_connector: GoogleDriveConnector,
    permission_ids: list[str],
    permission_info: dict[str, Any],
) -> list[dict[str, Any]]:
    doc_id = permission_info.get("doc_id")
    if not permission_info or not doc_id:
        return []

    permissions = [
        _PERMISSION_ID_PERMISSION_MAP[pid]
        for pid in permission_ids
        if pid in _PERMISSION_ID_PERMISSION_MAP
    ]

    if len(permissions) == len(permission_ids):
        return permissions

    owner_email = permission_info.get("owner_email")

    drive_service = get_drive_service(
        creds=google_drive_connector.creds,
        user_email=(owner_email or google_drive_connector.primary_admin_email),
    )

    # We continue on 404 or 403 because the document may not exist or the user may not have access to it
    fetched_permissions = execute_paginated_retrieval(
        retrieval_function=drive_service.permissions().list,
        list_key="permissions",
        fileId=doc_id,
        fields="permissions(id, emailAddress, type, domain)",
        supportsAllDrives=True,
        continue_on_404_or_403=True,
    )

    permissions_for_doc_id = []
    for permission in fetched_permissions:
        permissions_for_doc_id.append(permission)
        _PERMISSION_ID_PERMISSION_MAP[permission["id"]] = permission

    return permissions_for_doc_id


def _get_permissions_from_slim_doc(
    google_drive_connector: GoogleDriveConnector,
    slim_doc: SlimDocument,
def get_external_access_for_raw_gdrive_file(
    file: GoogleDriveFileType, company_domain: str, drive_service: GoogleDriveService
) -> ExternalAccess:
    permission_info = slim_doc.perm_sync_data or {}
    """
    Get the external access for a raw Google Drive file.

    permissions_list = permission_info.get("permissions", [])
    if not permissions_list:
        if permission_ids := permission_info.get("permission_ids"):
            permissions_list = _fetch_permissions_for_permission_ids(
                google_drive_connector=google_drive_connector,
                permission_ids=permission_ids,
                permission_info=permission_info,
            )
        if not permissions_list:
            logger.warning(f"No permissions found for document {slim_doc.id}")
            return ExternalAccess(
                external_user_emails=set(),
                external_user_group_ids=set(),
                is_public=False,
            )
    Assumes the file we retrieved has EITHER `permissions` or `permission_ids`
    """
    doc_id = file.get("id")
    if not doc_id:
        raise ValueError("No doc_id found in file")

    company_domain = google_drive_connector.google_domain
    permissions = file.get("permissions")
    permission_ids = file.get("permissionIds")
    drive_id = file.get("driveId")

    permissions_list: list[GoogleDrivePermission] = []
    if permissions:
        permissions_list = [
            GoogleDrivePermission.from_drive_permission(p) for p in permissions
        ]
    elif permission_ids:
        permissions_list = get_permissions_by_ids(
            drive_service=drive_service,
            doc_id=doc_id,
            permission_ids=permission_ids,
        )

    folder_ids_to_inherit_permissions_from: set[str] = set()
    user_emails: set[str] = set()
    group_emails: set[str] = set()
    public = False
    skipped_permissions = 0

    for permission in permissions_list:
        if not permission:
            skipped_permissions += 1
            continue
        # if the permission is inherited, do not add it directly to the file
        # instead, add the folder ID as a group that has access to the file
        # we will then handle mapping that folder to the list of Onyx users
        # in the group sync job
        # NOTE: this doesn't handle the case where a folder initially has no
        # permissioning, but then later that folder is shared with a user or group.
        # We could fetch all ancestors of the file to get the list of folders that
        # might affect the permissions of the file, but this will get replaced with
        # an audit-log based approach in the future so not doing it now.
        if permission.inherited_from:
            folder_ids_to_inherit_permissions_from.add(permission.inherited_from)

        permission_type = permission["type"]
        if permission_type == "user":
            user_emails.add(permission["emailAddress"])
        elif permission_type == "group":
            group_emails.add(permission["emailAddress"])
        elif permission_type == "domain" and company_domain:
            if permission.get("domain") == company_domain:
        if permission.type == PermissionType.USER:
            if permission.email_address:
                user_emails.add(permission.email_address)
            else:
                logger.error(
                    "Permission is type `user` but no email address is "
                    f"provided for document {doc_id}"
                    f"\n {permission}"
                )
        elif permission.type == PermissionType.GROUP:
            # groups are represented as email addresses within Drive
            if permission.email_address:
                group_emails.add(permission.email_address)
            else:
                logger.error(
                    "Permission is type `group` but no email address is "
                    f"provided for document {doc_id}"
                    f"\n {permission}"
                )
        elif permission.type == PermissionType.DOMAIN and company_domain:
            if permission.domain == company_domain:
                public = True
            else:
                logger.warning(
                    "Permission is type domain but does not match company domain:"
                    f"\n {permission}"
                )
        elif permission_type == "anyone":
        elif permission.type == PermissionType.ANYONE:
            public = True

    if skipped_permissions > 0:
        logger.warning(
            f"Skipped {skipped_permissions} permissions of {len(permissions_list)} for document {slim_doc.id}"
        )

    drive_id = permission_info.get("drive_id")
    group_ids = group_emails | ({drive_id} if drive_id is not None else set())
    group_ids = (
        group_emails
        | folder_ids_to_inherit_permissions_from
        | ({drive_id} if drive_id is not None else set())
    )

    return ExternalAccess(
        external_user_emails=user_emails,
@@ -147,6 +131,7 @@ def _get_permissions_from_slim_doc(

def gdrive_doc_sync(
    cc_pair: ConnectorCredentialPair,
    fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
    callback: IndexingHeartbeatInterface | None,
) -> Generator[DocExternalAccess, None, None]:
    """
@@ -162,7 +147,9 @@ def gdrive_doc_sync(

    slim_doc_generator = _get_slim_doc_generator(cc_pair, google_drive_connector)

    total_processed = 0
    for slim_doc_batch in slim_doc_generator:
        logger.info(f"Drive perm sync: Processing {len(slim_doc_batch)} documents")
        for slim_doc in slim_doc_batch:
            if callback:
                if callback.should_stop():
@@ -170,11 +157,14 @@ def gdrive_doc_sync(

                callback.progress("gdrive_doc_sync", 1)

            ext_access = _get_permissions_from_slim_doc(
                google_drive_connector=google_drive_connector,
                slim_doc=slim_doc,
            )
            if slim_doc.external_access is None:
                raise ValueError(
                    f"Drive perm sync: No external access for document {slim_doc.id}"
                )

            yield DocExternalAccess(
                external_access=ext_access,
                external_access=slim_doc.external_access,
                doc_id=slim_doc.id,
            )
        total_processed += len(slim_doc_batch)
    logger.info(f"Drive perm sync: Processed {total_processed} total documents")

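The key behavioral change in this file: inherited Drive permissions are no longer resolved inline. Instead the folder (or shared drive) they come from is recorded as a pseudo-group, which the group sync job later expands into concrete users. How the effective group ids for one file are assembled (values made up):

group_emails = {"eng@company.com"}                       # explicit group grants
folder_ids_to_inherit_permissions_from = {"folder-abc"}  # inherited permissions
drive_id = "drive-xyz"                                   # containing shared drive; may be None

group_ids = (
    group_emails
    | folder_ids_to_inherit_permissions_from
    | ({drive_id} if drive_id is not None else set())
)
# {"eng@company.com", "folder-abc", "drive-xyz"}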
@@ -0,0 +1,84 @@
from collections.abc import Iterator

from googleapiclient.discovery import Resource  # type: ignore

from ee.onyx.external_permissions.google_drive.models import GoogleDrivePermission
from ee.onyx.external_permissions.google_drive.permission_retrieval import (
    get_permissions_by_ids,
)
from onyx.connectors.google_drive.constants import DRIVE_FOLDER_TYPE
from onyx.connectors.google_drive.file_retrieval import generate_time_range_filter
from onyx.connectors.google_drive.models import GoogleDriveFileType
from onyx.connectors.google_utils.google_utils import execute_paginated_retrieval
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.utils.logger import setup_logger

logger = setup_logger()

# Only include fields we need - folder ID and permissions
# IMPORTANT: must fetch permissionIds, since sometimes the drive API
# seems to miss permissions when requesting them directly
FOLDER_PERMISSION_FIELDS = (
    "nextPageToken, files(id, name, permissionIds, "
    "permissions(id, emailAddress, type, domain, permissionDetails))"
)


def get_folder_permissions_by_ids(
    service: Resource,
    folder_id: str,
    permission_ids: list[str],
) -> list[GoogleDrivePermission]:
    """
    Retrieves permissions for a specific folder filtered by permission IDs.

    Args:
        service: The Google Drive service instance
        folder_id: The ID of the folder to fetch permissions for
        permission_ids: A list of permission IDs to filter by

    Returns:
        A list of permissions matching the provided permission IDs
    """
    return get_permissions_by_ids(
        drive_service=service,
        doc_id=folder_id,
        permission_ids=permission_ids,
    )


def get_modified_folders(
    service: Resource,
    start: SecondsSinceUnixEpoch | None = None,
    end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
    """
    Retrieves all folders that were modified within the specified time range.
    Only includes folder ID and permission information, not any contained files.

    Args:
        service: The Google Drive service instance
        start: The start time as seconds since Unix epoch (inclusive)
        end: The end time as seconds since Unix epoch (inclusive)

    Returns:
        An iterator yielding folder information including ID and permissions
    """
    # Build query for folders
    query = f"mimeType = '{DRIVE_FOLDER_TYPE}'"
    query += " and trashed = false"
    query += generate_time_range_filter(start, end)

    # Retrieve and yield folders
    for folder in execute_paginated_retrieval(
        retrieval_function=service.files().list,
        list_key="files",
        continue_on_404_or_403=True,
        corpora="allDrives",
        supportsAllDrives=True,
        includeItemsFromAllDrives=True,
        includePermissionsForView="published",
        fields=FOLDER_PERMISSION_FIELDS,
        q=query,
    ):
        yield folder
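A sketch of how get_modified_folders might be driven to find folders whose permissions changed recently. It assumes a drive_service built elsewhere (e.g. via get_drive_service) and real Google credentials, so it is illustrative only:

import time

one_day_ago = time.time() - 24 * 60 * 60
for folder in get_modified_folders(service=drive_service, start=one_day_ago):
    # each result carries only the id/name and permission fields requested
    # via FOLDER_PERMISSION_FIELDS
    print(folder["id"], folder.get("permissionIds", []))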
@@ -1,6 +1,15 @@
from googleapiclient.errors import HttpError  # type: ignore
from pydantic import BaseModel

from ee.onyx.db.external_perm import ExternalUserGroup
from ee.onyx.external_permissions.google_drive.folder_retrieval import (
    get_folder_permissions_by_ids,
)
from ee.onyx.external_permissions.google_drive.folder_retrieval import (
    get_modified_folders,
)
from ee.onyx.external_permissions.google_drive.models import GoogleDrivePermission
from ee.onyx.external_permissions.google_drive.models import PermissionType
from onyx.connectors.google_drive.connector import GoogleDriveConnector
from onyx.connectors.google_utils.google_utils import execute_paginated_retrieval
from onyx.connectors.google_utils.resources import AdminService
@@ -12,6 +21,87 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()


"""
Folder Permission Sync.

Each folder is treated as a group. Each file has all ancestor folders
as groups.
"""


class FolderInfo(BaseModel):
    id: str
    permissions: list[GoogleDrivePermission]


def _get_all_folders(
    google_drive_connector: GoogleDriveConnector, skip_folders_without_permissions: bool
) -> list[FolderInfo]:
    """Have to get all folders since the group syncing system assumes all groups
    are returned every time.

    TODO: tweak things so we can fetch deltas.
    """
    all_folders: list[FolderInfo] = []
    seen_folder_ids: set[str] = set()

    user_emails = google_drive_connector._get_all_user_emails()
    for user_email in user_emails:
        drive_service = get_drive_service(
            google_drive_connector.creds,
            user_email,
        )

        for folder in get_modified_folders(
            service=drive_service,
        ):
            folder_id = folder["id"]
            if folder_id in seen_folder_ids:
                logger.debug(f"Folder {folder_id} has already been seen. Skipping.")
                continue

            seen_folder_ids.add(folder_id)

            # Check if the folder has permission IDs but no permissions
            permission_ids = folder.get("permissionIds", [])
            raw_permissions = folder.get("permissions", [])

            if not raw_permissions and permission_ids:
                # Fetch permissions using the IDs
                permissions = get_folder_permissions_by_ids(
                    drive_service, folder_id, permission_ids
                )
            else:
                permissions = [
                    GoogleDrivePermission.from_drive_permission(permission)
                    for permission in raw_permissions
                ]

            # Don't include inherited permissions, those will be captured
            # by the folder/shared drive itself
            permissions = [
                permission
                for permission in permissions
                if permission.inherited_from is None
            ]

            if not permissions and skip_folders_without_permissions:
                logger.debug(f"Folder {folder_id} has no permissions. Skipping.")
                continue

            all_folders.append(
                FolderInfo(
                    id=folder_id,
                    permissions=permissions,
                )
            )

    return all_folders

"""Individual Shared Drive / My Drive Permission Sync"""
|
||||
|
||||
|
||||
def _get_drive_members(
|
||||
google_drive_connector: GoogleDriveConnector,
|
||||
admin_service: AdminService,
|
||||
@@ -51,15 +141,17 @@ def _get_drive_members(
|
||||
drive_service.permissions().list,
|
||||
list_key="permissions",
|
||||
fileId=drive_id,
|
||||
fields="permissions(emailAddress, type)",
|
||||
fields="permissions(emailAddress, type),nextPageToken",
|
||||
supportsAllDrives=True,
|
||||
# can only set `useDomainAdminAccess` to true if the user
|
||||
# is an admin
|
||||
useDomainAdminAccess=is_admin,
|
||||
):
|
||||
if permission["type"] == "group":
|
||||
# NOTE: don't need to check for PermissionType.ANYONE since
|
||||
# you can't share a drive with the internet
|
||||
if permission["type"] == PermissionType.GROUP:
|
||||
group_emails.add(permission["emailAddress"])
|
||||
elif permission["type"] == "user":
|
||||
elif permission["type"] == PermissionType.USER:
|
||||
user_emails.add(permission["emailAddress"])
|
||||
except HttpError as e:
|
||||
if e.status_code == 404:
|
||||
@@ -87,7 +179,7 @@ def _get_all_groups(
|
||||
admin_service.groups().list,
|
||||
list_key="groups",
|
||||
domain=google_domain,
|
||||
fields="groups(email)",
|
||||
fields="groups(email),nextPageToken",
|
||||
):
|
||||
group_emails.add(group["email"])
|
||||
return group_emails
|
||||
@@ -107,7 +199,7 @@ def _map_group_email_to_member_emails(
|
||||
admin_service.members().list,
|
||||
list_key="members",
|
||||
groupKey=group_email,
|
||||
fields="members(email)",
|
||||
fields="members(email),nextPageToken",
|
||||
):
|
||||
group_member_emails.add(member["email"])
|
||||
|
||||
@@ -118,6 +210,7 @@ def _map_group_email_to_member_emails(
|
||||
def _build_onyx_groups(
|
||||
drive_id_to_members_map: dict[str, tuple[set[str], set[str]]],
|
||||
group_email_to_member_emails_map: dict[str, set[str]],
|
||||
folder_info: list[FolderInfo],
|
||||
) -> list[ExternalUserGroup]:
|
||||
onyx_groups: list[ExternalUserGroup] = []
|
||||
|
||||
@@ -125,13 +218,52 @@ def _build_onyx_groups(
|
||||
# This is because having drive level access means you have
|
||||
# irrevocable access to all the files in the drive.
|
||||
for drive_id, (group_emails, user_emails) in drive_id_to_members_map.items():
|
||||
all_member_emails: set[str] = user_emails
|
||||
drive_member_emails: set[str] = user_emails
|
||||
for group_email in group_emails:
|
||||
all_member_emails.update(group_email_to_member_emails_map[group_email])
|
||||
if group_email not in group_email_to_member_emails_map:
|
||||
logger.warning(
|
||||
f"Group email {group_email} for drive {drive_id} not found in "
|
||||
"group_email_to_member_emails_map"
|
||||
)
|
||||
continue
|
||||
drive_member_emails.update(group_email_to_member_emails_map[group_email])
|
||||
onyx_groups.append(
|
||||
ExternalUserGroup(
|
||||
id=drive_id,
|
||||
user_emails=list(all_member_emails),
|
||||
user_emails=list(drive_member_emails),
|
||||
)
|
||||
)
|
||||
|
||||
# Convert all folder permissions to onyx groups
|
||||
for folder in folder_info:
|
||||
anyone_can_access = False
|
||||
folder_member_emails: set[str] = set()
|
||||
for permission in folder.permissions:
|
||||
if permission.type == PermissionType.USER:
|
||||
if permission.email_address is None:
|
||||
logger.warning(
|
||||
f"User email is None for folder {folder.id} permission {permission}"
|
||||
)
|
||||
continue
|
||||
folder_member_emails.add(permission.email_address)
|
||||
elif permission.type == PermissionType.GROUP:
|
||||
if permission.email_address not in group_email_to_member_emails_map:
|
||||
logger.warning(
|
||||
f"Group email {permission.email_address} for folder {folder.id} "
|
||||
"not found in group_email_to_member_emails_map"
|
||||
)
|
||||
continue
|
||||
folder_member_emails.update(
|
||||
group_email_to_member_emails_map[permission.email_address]
|
||||
)
|
||||
elif permission.type == PermissionType.ANYONE:
|
||||
anyone_can_access = True
|
||||
|
||||
onyx_groups.append(
|
||||
ExternalUserGroup(
|
||||
id=folder.id,
|
||||
user_emails=list(folder_member_emails),
|
||||
gives_anyone_access=anyone_can_access,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -168,6 +300,12 @@ def gdrive_group_sync(
|
||||
admin_service, google_drive_connector.google_domain
|
||||
)
|
||||
|
||||
# Get all folder permissions
|
||||
folder_info = _get_all_folders(
|
||||
google_drive_connector=google_drive_connector,
|
||||
skip_folders_without_permissions=True,
|
||||
)
|
||||
|
||||
# Map group emails to their members
|
||||
group_email_to_member_emails_map = _map_group_email_to_member_emails(
|
||||
admin_service, all_group_emails
|
||||
@@ -177,6 +315,7 @@ def gdrive_group_sync(
|
||||
onyx_groups = _build_onyx_groups(
|
||||
drive_id_to_members_map=drive_id_to_members_map,
|
||||
group_email_to_member_emails_map=group_email_to_member_emails_map,
|
||||
folder_info=folder_info,
|
||||
)
|
||||
|
||||
return onyx_groups
|
||||
|
||||
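To make the folder-as-group model above concrete, a small illustrative sketch; `DocAcl` and `ancestor_folder_ids` are invented for this example and are not part of the diff.

from dataclasses import dataclass, field


@dataclass
class DocAcl:  # hypothetical stand-in for a document's external access entry
    user_emails: set[str] = field(default_factory=set)
    external_group_ids: set[str] = field(default_factory=set)


def acl_for_file(direct_user_emails: set[str], ancestor_folder_ids: list[str]) -> DocAcl:
    # each ancestor folder id doubles as an external group id, mirroring
    # ExternalUserGroup(id=folder.id, ...) built by _build_onyx_groups above
    return DocAcl(
        user_emails=set(direct_user_emails),
        external_group_ids=set(ancestor_folder_ids),
    )


acl = acl_for_file({"alice@example.com"}, ["folder-a", "folder-b"])
assert "folder-a" in acl.external_group_ids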
backend/ee/onyx/external_permissions/google_drive/models.py (new file, 64 lines)
@@ -0,0 +1,64 @@
from enum import Enum
from typing import Any

from pydantic import BaseModel


class PermissionType(str, Enum):
    USER = "user"
    GROUP = "group"
    DOMAIN = "domain"
    ANYONE = "anyone"


class GoogleDrivePermissionDetails(BaseModel):
    # this is "file", "member", etc.
    # different from the `type` field within `GoogleDrivePermission`
    # Sometimes can be not set, although not sure why...
    permission_type: str | None
    # this is "reader", "writer", "owner", etc.
    role: str
    # this is the id of the parent permission
    inherited_from: str | None


class GoogleDrivePermission(BaseModel):
    id: str
    # groups are also represented as email addresses within Drive
    # will be None for domain/global permissions
    email_address: str | None
    type: PermissionType
    domain: str | None  # only applies to domain permissions
    permission_details: GoogleDrivePermissionDetails | None

    @classmethod
    def from_drive_permission(
        cls, drive_permission: dict[str, Any]
    ) -> "GoogleDrivePermission":
        # we seem to only get details for permissions that are inherited
        # we can get multiple details if a permission is inherited from multiple
        # ancestors; take the first entry if present
        permission_details_list = drive_permission.get("permissionDetails", [])
        permission_details: dict[str, Any] | None = (
            permission_details_list[0] if permission_details_list else None
        )
        return cls(
            id=drive_permission["id"],
            email_address=drive_permission.get("emailAddress"),
            type=PermissionType(drive_permission["type"]),
            domain=drive_permission.get("domain"),
            permission_details=(
                GoogleDrivePermissionDetails(
                    permission_type=permission_details.get("type"),
                    role=permission_details.get("role", ""),
                    inherited_from=permission_details.get("inheritedFrom"),
                )
                if permission_details
                else None
            ),
        )

    @property
    def inherited_from(self) -> str | None:
        if self.permission_details:
            return self.permission_details.inherited_from
        return None
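A sketch of feeding a raw `permissions.list` entry through the model above; the dict shape is an assumption based on the Drive v3 API docs, not taken from this diff.

from ee.onyx.external_permissions.google_drive.models import GoogleDrivePermission
from ee.onyx.external_permissions.google_drive.models import PermissionType

sample = {  # assumed Drive v3 permission resource shape
    "id": "perm-1",
    "type": "user",
    "emailAddress": "alice@example.com",
    "permissionDetails": [
        {"type": "member", "role": "reader", "inheritedFrom": "folder-123"}
    ],
}

perm = GoogleDrivePermission.from_drive_permission(sample)
assert perm.type == PermissionType.USER
assert perm.inherited_from == "folder-123"  # surfaced via the property above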
@@ -0,0 +1,62 @@
from retry import retry

from ee.onyx.external_permissions.google_drive.models import GoogleDrivePermission
from onyx.connectors.google_utils.google_utils import execute_paginated_retrieval
from onyx.connectors.google_utils.resources import GoogleDriveService
from onyx.utils.logger import setup_logger

logger = setup_logger()


@retry(tries=3, delay=2, backoff=2)
def get_permissions_by_ids(
    drive_service: GoogleDriveService,
    doc_id: str,
    permission_ids: list[str],
) -> list[GoogleDrivePermission]:
    """
    Fetches permissions for a document based on a list of permission IDs.

    Args:
        drive_service: The Google Drive service instance
        doc_id: The ID of the document to fetch permissions for
        permission_ids: A list of permission IDs to filter by

    Returns:
        A list of GoogleDrivePermission objects matching the provided permission IDs
    """
    if not permission_ids:
        return []

    # Create a set for faster lookup
    permission_id_set = set(permission_ids)

    # Fetch all permissions for the document
    fetched_permissions = execute_paginated_retrieval(
        retrieval_function=drive_service.permissions().list,
        list_key="permissions",
        fileId=doc_id,
        fields="permissions(id, emailAddress, type, domain, permissionDetails),nextPageToken",
        supportsAllDrives=True,
        continue_on_404_or_403=True,
    )

    # Filter permissions by ID and convert to GoogleDrivePermission objects
    filtered_permissions = []
    for permission in fetched_permissions:
        permission_id = permission.get("id")
        if permission_id in permission_id_set:
            google_drive_permission = GoogleDrivePermission.from_drive_permission(
                permission
            )
            filtered_permissions.append(google_drive_permission)

    # Log if we couldn't find all requested permission IDs
    if len(filtered_permissions) < len(permission_ids):
        missing_ids = permission_id_set - {p.id for p in filtered_permissions if p.id}
        logger.warning(
            f"Could not find all requested permission IDs for document {doc_id}. "
            f"Missing IDs: {missing_ids}"
        )

    return filtered_permissions
backend/ee/onyx/external_permissions/perm_sync_types.py (new file, 49 lines)
@@ -0,0 +1,49 @@
from collections.abc import Callable
from collections.abc import Generator
from typing import Optional
from typing import Protocol
from typing import TYPE_CHECKING

from onyx.context.search.models import InferenceChunk

# Avoid circular imports
if TYPE_CHECKING:
    from ee.onyx.db.external_perm import ExternalUserGroup  # noqa
    from onyx.access.models import DocExternalAccess  # noqa
    from onyx.db.models import ConnectorCredentialPair  # noqa
    from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface  # noqa


class FetchAllDocumentsFunction(Protocol):
    """Protocol for a function that fetches all document IDs for a connector credential pair."""

    def __call__(self) -> list[str]:
        """
        Returns a list of document IDs for a connector credential pair.

        This is typically used to determine which documents should no longer be
        accessible during the document sync process.
        """
        ...


# Defining the input/output types for the sync functions
DocSyncFuncType = Callable[
    [
        "ConnectorCredentialPair",
        FetchAllDocumentsFunction,
        Optional["IndexingHeartbeatInterface"],
    ],
    Generator["DocExternalAccess", None, None],
]

GroupSyncFuncType = Callable[
    [
        str,
        "ConnectorCredentialPair",
    ],
    list["ExternalUserGroup"],
]

# list of chunks to be censored and the user email. returns censored chunks
CensoringFuncType = Callable[[list[InferenceChunk], str], list[InferenceChunk]]
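Since `FetchAllDocumentsFunction` is a zero-argument `Protocol`, any callable with a matching signature satisfies it; a sketch, where `fetch_document_ids_for_cc_pair` is a hypothetical database helper invented for the example.

from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction


def make_fetch_all_docs_fn(cc_pair_id: int) -> FetchAllDocumentsFunction:
    # a closure over the cc_pair id conforms to the protocol
    def _fetch() -> list[str]:
        return fetch_document_ids_for_cc_pair(cc_pair_id)  # hypothetical DB helper

    return _fetch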
@@ -1,9 +1,6 @@
-from collections.abc import Callable
-
 from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs
-from ee.onyx.external_permissions.salesforce.postprocessing import (
-    censor_salesforce_chunks,
-)
+from ee.onyx.external_permissions.sync_params import get_all_censoring_enabled_sources
+from ee.onyx.external_permissions.sync_params import get_source_perm_sync_config
 from onyx.configs.constants import DocumentSource
 from onyx.context.search.pipeline import InferenceChunk
 from onyx.db.engine import get_session_context_manager
@@ -12,32 +9,25 @@ from onyx.utils.logger import setup_logger

 logger = setup_logger()

-DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION: dict[
-    DocumentSource,
-    # list of chunks to be censored and the user email. returns censored chunks
-    Callable[[list[InferenceChunk], str], list[InferenceChunk]],
-] = {
-    DocumentSource.SALESFORCE: censor_salesforce_chunks,
-}
-

 def _get_all_censoring_enabled_sources() -> set[DocumentSource]:
     """
     Returns the set of sources that have censoring enabled.
     This is based on if the access_type is set to sync and the connector
-    source is included in DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION.
+    source has a censoring config.

     NOTE: This means if a source has a single cc_pair that is sync,
     all chunks for that source will be censored, even if the connector that
     indexed that chunk is not sync. This was done to avoid getting the cc_pair
     for every single chunk.
     """
+    all_censoring_enabled_sources = get_all_censoring_enabled_sources()
     with get_session_context_manager() as db_session:
         enabled_sync_connectors = get_all_auto_sync_cc_pairs(db_session)
         return {
             cc_pair.connector.source
             for cc_pair in enabled_sync_connectors
-            if cc_pair.connector.source in DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION
+            if cc_pair.connector.source in all_censoring_enabled_sources
         }

@@ -70,7 +60,11 @@ def _post_query_chunk_censoring(
     # check function for that source
     # TODO: Use a threadpool/multiprocessing to process the sources in parallel
     for source, chunks_for_source in chunks_to_process.items():
-        censor_chunks_for_source = DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION[source]
+        sync_config = get_source_perm_sync_config(source)
+        if sync_config is None or sync_config.censoring_config is None:
+            raise ValueError(f"No sync config found for {source}")
+
+        censor_chunks_for_source = sync_config.censoring_config.chunk_censoring_func
         try:
             censored_chunks = censor_chunks_for_source(chunks_for_source, user.email)
         except Exception as e:
backend/ee/onyx/external_permissions/slack/channel_access.py (new file, 63 lines)
@@ -0,0 +1,63 @@
from slack_sdk import WebClient

from onyx.access.models import ExternalAccess
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.slack.connector import ChannelType
from onyx.connectors.slack.utils import expert_info_from_slack_id
from onyx.connectors.slack.utils import make_paginated_slack_api_call


def get_channel_access(
    client: WebClient,
    channel: ChannelType,
    user_cache: dict[str, BasicExpertInfo | None],
) -> ExternalAccess:
    """
    Get channel access permissions for a Slack channel.

    Args:
        client: Slack WebClient instance
        channel: Slack channel object containing channel info
        user_cache: Cache of user IDs to BasicExpertInfo objects. May be updated in place.

    Returns:
        ExternalAccess object for the channel.
    """
    channel_is_public = not channel["is_private"]
    if channel_is_public:
        return ExternalAccess(
            external_user_emails=set(),
            external_user_group_ids=set(),
            is_public=True,
        )

    channel_id = channel["id"]

    # Get all member IDs for the channel
    member_ids = []
    for result in make_paginated_slack_api_call(
        client.conversations_members,
        channel=channel_id,
    ):
        member_ids.extend(result.get("members", []))

    member_emails = set()
    for member_id in member_ids:
        # Try to get user info from cache or fetch it
        user_info = expert_info_from_slack_id(
            user_id=member_id,
            client=client,
            user_cache=user_cache,
        )

        # If we have user info and an email, add it to the set
        if user_info and user_info.email:
            member_emails.add(user_info.email)

    return ExternalAccess(
        external_user_emails=member_emails,
        # NOTE: groups are not used, since adding a group to a channel just adds all
        # users that are in the group.
        external_user_group_ids=set(),
        is_public=False,
    )
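Hypothetical usage of `get_channel_access`; the token and the channel dict shape are assumptions (a `conversations.info`-style payload), not taken from this diff.

from slack_sdk import WebClient

from ee.onyx.external_permissions.slack.channel_access import get_channel_access
from onyx.connectors.models import BasicExpertInfo

client = WebClient(token="xoxb-...")  # placeholder token
user_cache: dict[str, BasicExpertInfo | None] = {}  # filled in as users are resolved

channel = {"id": "C0123456789", "is_private": True}  # assumed channel payload shape
access = get_channel_access(client=client, channel=channel, user_cache=user_cache)
print(access.is_public, sorted(access.external_user_emails))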
@@ -2,15 +2,17 @@ from collections.abc import Generator

 from slack_sdk import WebClient

+from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
 from ee.onyx.external_permissions.slack.utils import fetch_user_id_to_email_map
 from onyx.access.models import DocExternalAccess
 from onyx.access.models import ExternalAccess
+from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
 from onyx.connectors.slack.connector import get_channels
-from onyx.connectors.slack.connector import make_paginated_slack_api_call_w_retries
+from onyx.connectors.slack.connector import make_paginated_slack_api_call
 from onyx.connectors.slack.connector import SlackConnector
 from onyx.db.models import ConnectorCredentialPair
 from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
+from onyx.redis.redis_pool import get_redis_client
 from onyx.utils.logger import setup_logger
 from shared_configs.contextvars import get_current_tenant_id

@@ -28,7 +30,7 @@ def _fetch_workspace_permissions(
         external_user_emails=user_emails,
         # No group<->document mapping for slack
         external_user_group_ids=set(),
-        # No way to determine if slack is invite only without enterprise liscense
+        # No way to determine if slack is invite only without enterprise license
         is_public=False,
     )

@@ -62,7 +64,7 @@ def _fetch_channel_permissions(
     for channel_id in private_channel_ids:
         # Collect all member ids for the channel pagination calls
         member_ids = []
-        for result in make_paginated_slack_api_call_w_retries(
+        for result in make_paginated_slack_api_call(
            slack_client.conversations_members,
            channel=channel_id,
        ):
@@ -90,7 +92,7 @@ def _fetch_channel_permissions(
             external_user_emails=member_emails,
             # No group<->document mapping for slack
             external_user_group_ids=set(),
-            # No way to determine if slack is invite only without enterprise liscense
+            # No way to determine if slack is invite only without enterprise license
             is_public=False,
         )

@@ -98,27 +100,23 @@ def _fetch_channel_permissions(


 def _get_slack_document_access(
-    cc_pair: ConnectorCredentialPair,
+    slack_connector: SlackConnector,
     channel_permissions: dict[str, ExternalAccess],
     callback: IndexingHeartbeatInterface | None,
 ) -> Generator[DocExternalAccess, None, None]:
-    slack_connector = SlackConnector(**cc_pair.connector.connector_specific_config)
-
-    # Use credentials provider instead of directly loading credentials
-    provider = OnyxDBCredentialsProvider(
-        get_current_tenant_id(), "slack", cc_pair.credential.id
-    )
-    slack_connector.set_credentials_provider(provider)
-
     slim_doc_generator = slack_connector.retrieve_all_slim_documents(callback=callback)

     for doc_metadata_batch in slim_doc_generator:
         for doc_metadata in doc_metadata_batch:
-            if doc_metadata.perm_sync_data is None:
-                continue
-            channel_id = doc_metadata.perm_sync_data["channel_id"]
+            if doc_metadata.external_access is None:
+                raise ValueError(
+                    f"No external access for document {doc_metadata.id}. "
+                    "Please check to make sure that your Slack bot token has the "
+                    "`channels:read` scope"
+                )

             yield DocExternalAccess(
-                external_access=channel_permissions[channel_id],
+                external_access=doc_metadata.external_access,
                 doc_id=doc_metadata.id,
             )

@@ -131,6 +129,7 @@ def _get_slack_document_access(

 def slack_doc_sync(
     cc_pair: ConnectorCredentialPair,
+    fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
     callback: IndexingHeartbeatInterface | None,
 ) -> Generator[DocExternalAccess, None, None]:
     """
@@ -139,9 +138,18 @@ def slack_doc_sync(
     it in postgres so that when it gets created later, the permissions are
     already populated
     """
-    slack_client = WebClient(
-        token=cc_pair.credential.credential_json["slack_bot_token"]
+    # Use credentials provider instead of directly loading credentials
+    tenant_id = get_current_tenant_id()
+    provider = OnyxDBCredentialsProvider(tenant_id, "slack", cc_pair.credential.id)
+    r = get_redis_client(tenant_id=tenant_id)
+    slack_client = SlackConnector.make_slack_web_client(
+        provider.get_provider_key(),
+        cc_pair.credential.credential_json["slack_bot_token"],
+        SlackConnector.MAX_RETRIES,
+        r,
     )

     user_id_to_email_map = fetch_user_id_to_email_map(slack_client)
     if not user_id_to_email_map:
         raise ValueError(
@@ -158,8 +166,11 @@ def slack_doc_sync(
         user_id_to_email_map=user_id_to_email_map,
     )

+    slack_connector = SlackConnector(**cc_pair.connector.connector_specific_config)
+    slack_connector.set_credentials_provider(provider)
+
     yield from _get_slack_document_access(
-        cc_pair=cc_pair,
+        slack_connector,
         channel_permissions=channel_permissions,
         callback=callback,
     )

@@ -9,8 +9,11 @@ from slack_sdk import WebClient

 from ee.onyx.db.external_perm import ExternalUserGroup
 from ee.onyx.external_permissions.slack.utils import fetch_user_id_to_email_map
-from onyx.connectors.slack.connector import make_paginated_slack_api_call_w_retries
+from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
+from onyx.connectors.slack.connector import SlackConnector
+from onyx.connectors.slack.utils import make_paginated_slack_api_call
 from onyx.db.models import ConnectorCredentialPair
+from onyx.redis.redis_pool import get_redis_client
 from onyx.utils.logger import setup_logger

 logger = setup_logger()
@@ -20,7 +23,7 @@ def _get_slack_group_ids(
     slack_client: WebClient,
 ) -> list[str]:
     group_ids = []
-    for result in make_paginated_slack_api_call_w_retries(slack_client.usergroups_list):
+    for result in make_paginated_slack_api_call(slack_client.usergroups_list):
         for group in result.get("usergroups", []):
             group_ids.append(group.get("id"))
     return group_ids
@@ -32,7 +35,7 @@ def _get_slack_group_members_email(
     user_id_to_email_map: dict[str, str],
 ) -> list[str]:
     group_member_emails = []
-    for result in make_paginated_slack_api_call_w_retries(
+    for result in make_paginated_slack_api_call(
         slack_client.usergroups_users_list, usergroup=group_name
     ):
         for member_id in result.get("users", []):
@@ -55,9 +58,18 @@ def slack_group_sync(
     tenant_id: str,
     cc_pair: ConnectorCredentialPair,
 ) -> list[ExternalUserGroup]:
-    slack_client = WebClient(
-        token=cc_pair.credential.credential_json["slack_bot_token"]
+    """NOTE: not used atm. All channel access is done at the
+    individual user level. Leaving in for now in case we need it later."""
+
+    provider = OnyxDBCredentialsProvider(tenant_id, "slack", cc_pair.credential.id)
+    r = get_redis_client(tenant_id=tenant_id)
+    slack_client = SlackConnector.make_slack_web_client(
+        provider.get_provider_key(),
+        cc_pair.credential.credential_json["slack_bot_token"],
+        SlackConnector.MAX_RETRIES,
+        r,
     )

     user_id_to_email_map = fetch_user_id_to_email_map(slack_client)

     onyx_groups: list[ExternalUserGroup] = []

@@ -1,13 +1,13 @@
 from slack_sdk import WebClient

-from onyx.connectors.slack.connector import make_paginated_slack_api_call_w_retries
+from onyx.connectors.slack.utils import make_paginated_slack_api_call


 def fetch_user_id_to_email_map(
     slack_client: WebClient,
 ) -> dict[str, str]:
     user_id_to_email_map = {}
-    for user_info in make_paginated_slack_api_call_w_retries(
+    for user_info in make_paginated_slack_api_call(
         slack_client.users_list,
     ):
         for user in user_info.get("members", []):

@@ -1,88 +1,188 @@
 from collections.abc import Callable
 from collections.abc import Generator
+from typing import Optional
+from typing import TYPE_CHECKING
+
+from pydantic import BaseModel

 from ee.onyx.configs.app_configs import CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY
 from ee.onyx.configs.app_configs import CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY
+from ee.onyx.configs.app_configs import DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY
 from ee.onyx.configs.app_configs import GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY
 from ee.onyx.configs.app_configs import SLACK_PERMISSION_DOC_SYNC_FREQUENCY
-from ee.onyx.db.external_perm import ExternalUserGroup
 from ee.onyx.external_permissions.confluence.doc_sync import confluence_doc_sync
 from ee.onyx.external_permissions.confluence.group_sync import confluence_group_sync
 from ee.onyx.external_permissions.gmail.doc_sync import gmail_doc_sync
 from ee.onyx.external_permissions.google_drive.doc_sync import gdrive_doc_sync
 from ee.onyx.external_permissions.google_drive.group_sync import gdrive_group_sync
-from ee.onyx.external_permissions.post_query_censoring import (
-    DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION,
+from ee.onyx.external_permissions.perm_sync_types import CensoringFuncType
+from ee.onyx.external_permissions.perm_sync_types import DocSyncFuncType
+from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
+from ee.onyx.external_permissions.perm_sync_types import GroupSyncFuncType
+from ee.onyx.external_permissions.salesforce.postprocessing import (
+    censor_salesforce_chunks,
 )
 from ee.onyx.external_permissions.slack.doc_sync import slack_doc_sync
 from ee.onyx.external_permissions.slack.group_sync import slack_group_sync
-from onyx.access.models import DocExternalAccess
 from onyx.configs.constants import DocumentSource
-from onyx.db.models import ConnectorCredentialPair
-from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface

-# Defining the input/output types for the sync functions
-DocSyncFuncType = Callable[
-    [
-        ConnectorCredentialPair,
-        IndexingHeartbeatInterface | None,
-    ],
-    Generator[DocExternalAccess, None, None],
-]
+if TYPE_CHECKING:
+    from onyx.access.models import DocExternalAccess  # noqa
+    from onyx.db.models import ConnectorCredentialPair  # noqa
+    from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface  # noqa

-GroupSyncFuncType = Callable[
-    [
-        str,
-        ConnectorCredentialPair,
-    ],
-    list[ExternalUserGroup],
-]
-
-# These functions update:
-# - the user_email <-> document mapping
-# - the external_user_group_id <-> document mapping
-# in postgres without committing
-# THIS ONE IS NECESSARY FOR AUTO SYNC TO WORK
-DOC_PERMISSIONS_FUNC_MAP: dict[DocumentSource, DocSyncFuncType] = {
-    DocumentSource.GOOGLE_DRIVE: gdrive_doc_sync,
-    DocumentSource.CONFLUENCE: confluence_doc_sync,
-    DocumentSource.SLACK: slack_doc_sync,
-    DocumentSource.GMAIL: gmail_doc_sync,
-}
+class DocSyncConfig(BaseModel):
+    doc_sync_frequency: int
+    doc_sync_func: DocSyncFuncType
+    initial_index_should_sync: bool

-# These functions update:
-# - the user_email <-> external_user_group_id mapping
-# in postgres without committing
-# THIS ONE IS OPTIONAL ON AN APP BY APP BASIS
-GROUP_PERMISSIONS_FUNC_MAP: dict[DocumentSource, GroupSyncFuncType] = {
-    DocumentSource.GOOGLE_DRIVE: gdrive_group_sync,
-    DocumentSource.CONFLUENCE: confluence_group_sync,
-    DocumentSource.SLACK: slack_group_sync,
-}

+class GroupSyncConfig(BaseModel):
+    group_sync_frequency: int
+    group_sync_func: GroupSyncFuncType
+    group_sync_is_cc_pair_agnostic: bool
+
+
+class CensoringConfig(BaseModel):
+    chunk_censoring_func: CensoringFuncType
+
+
+class SyncConfig(BaseModel):
+    # None means we don't perform a doc_sync
+    doc_sync_config: DocSyncConfig | None = None
+    # None means we don't perform a group_sync
+    group_sync_config: GroupSyncConfig | None = None
+    # None means we don't perform a chunk_censoring
+    censoring_config: CensoringConfig | None = None
+
+
+# Mock doc sync function for testing (no-op)
+def mock_doc_sync(
+    cc_pair: "ConnectorCredentialPair",
+    fetch_all_docs_fn: FetchAllDocumentsFunction,
+    callback: Optional["IndexingHeartbeatInterface"],
+) -> Generator["DocExternalAccess", None, None]:
+    """Mock doc sync function for testing - returns empty list since permissions are fetched during indexing"""
+    yield from []
+
+
+_SOURCE_TO_SYNC_CONFIG: dict[DocumentSource, SyncConfig] = {
+    DocumentSource.GOOGLE_DRIVE: SyncConfig(
+        doc_sync_config=DocSyncConfig(
+            doc_sync_frequency=DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY,
+            doc_sync_func=gdrive_doc_sync,
+            initial_index_should_sync=True,
+        ),
+        group_sync_config=GroupSyncConfig(
+            group_sync_frequency=GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY,
+            group_sync_func=gdrive_group_sync,
+            group_sync_is_cc_pair_agnostic=False,
+        ),
+    ),
+    DocumentSource.CONFLUENCE: SyncConfig(
+        doc_sync_config=DocSyncConfig(
+            doc_sync_frequency=CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY,
+            doc_sync_func=confluence_doc_sync,
+            initial_index_should_sync=False,
+        ),
+        group_sync_config=GroupSyncConfig(
+            group_sync_frequency=CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY,
+            group_sync_func=confluence_group_sync,
+            group_sync_is_cc_pair_agnostic=True,
+        ),
+    ),
+    DocumentSource.SLACK: SyncConfig(
+        doc_sync_config=DocSyncConfig(
+            doc_sync_frequency=SLACK_PERMISSION_DOC_SYNC_FREQUENCY,
+            doc_sync_func=slack_doc_sync,
+            initial_index_should_sync=True,
+        ),
+        # groups are not needed for Slack. All channel access is done at the
+        # individual user level
+        group_sync_config=None,
+    ),
+    DocumentSource.GMAIL: SyncConfig(
+        doc_sync_config=DocSyncConfig(
+            doc_sync_frequency=DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY,
+            doc_sync_func=gmail_doc_sync,
+            initial_index_should_sync=False,
+        ),
+    ),
+    DocumentSource.SALESFORCE: SyncConfig(
+        censoring_config=CensoringConfig(
+            chunk_censoring_func=censor_salesforce_chunks,
+        ),
+    ),
+    DocumentSource.MOCK_CONNECTOR: SyncConfig(
+        doc_sync_config=DocSyncConfig(
+            doc_sync_frequency=DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY,
+            doc_sync_func=mock_doc_sync,
+            initial_index_should_sync=True,
+        ),
+    ),
+}

-GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC: set[DocumentSource] = {
-    DocumentSource.CONFLUENCE,
-}
+def source_requires_doc_sync(source: DocumentSource) -> bool:
+    """Checks if the given DocumentSource requires doc syncing."""
+    if source not in _SOURCE_TO_SYNC_CONFIG:
+        return False
+    return _SOURCE_TO_SYNC_CONFIG[source].doc_sync_config is not None

-# If nothing is specified here, we run the doc_sync every time the celery beat runs
-DOC_PERMISSION_SYNC_PERIODS: dict[DocumentSource, int] = {
-    # Polling is not supported so we fetch all doc permissions every 5 minutes
-    DocumentSource.CONFLUENCE: CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY,
-    DocumentSource.SLACK: SLACK_PERMISSION_DOC_SYNC_FREQUENCY,
-}
+def source_requires_external_group_sync(source: DocumentSource) -> bool:
+    """Checks if the given DocumentSource requires external group syncing."""
+    if source not in _SOURCE_TO_SYNC_CONFIG:
+        return False
+    return _SOURCE_TO_SYNC_CONFIG[source].group_sync_config is not None

-# If nothing is specified here, we run the doc_sync every time the celery beat runs
-EXTERNAL_GROUP_SYNC_PERIODS: dict[DocumentSource, int] = {
-    # Polling is not supported so we fetch all group permissions every 30 minutes
-    DocumentSource.GOOGLE_DRIVE: GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY,
-    DocumentSource.CONFLUENCE: CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY,
-}
+def get_source_perm_sync_config(source: DocumentSource) -> SyncConfig | None:
+    """Returns the permission sync config for the given DocumentSource."""
+    return _SOURCE_TO_SYNC_CONFIG.get(source)
+
+
+def source_group_sync_is_cc_pair_agnostic(source: DocumentSource) -> bool:
+    """Checks if the given DocumentSource's external group sync is cc_pair agnostic."""
+    if source not in _SOURCE_TO_SYNC_CONFIG:
+        return False
+
+    group_sync_config = _SOURCE_TO_SYNC_CONFIG[source].group_sync_config
+    if group_sync_config is None:
+        return False
+
+    return group_sync_config.group_sync_is_cc_pair_agnostic
+
+
+def get_all_cc_pair_agnostic_group_sync_sources() -> set[DocumentSource]:
+    """Returns the set of sources that have external group syncing that is cc_pair agnostic."""
+    return {
+        source
+        for source, sync_config in _SOURCE_TO_SYNC_CONFIG.items()
+        if sync_config.group_sync_config is not None
+        and sync_config.group_sync_config.group_sync_is_cc_pair_agnostic
+    }


 def check_if_valid_sync_source(source_type: DocumentSource) -> bool:
-    return (
-        source_type in DOC_PERMISSIONS_FUNC_MAP
-        or source_type in DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION
-    )
+    return source_type in _SOURCE_TO_SYNC_CONFIG
+
+
+def get_all_censoring_enabled_sources() -> set[DocumentSource]:
+    """Returns the set of sources that have censoring enabled."""
+    return {
+        source
+        for source, sync_config in _SOURCE_TO_SYNC_CONFIG.items()
+        if sync_config.censoring_config is not None
+    }
+
+
+def source_should_fetch_permissions_during_indexing(source: DocumentSource) -> bool:
+    """Returns True if the given DocumentSource requires permissions to be fetched during indexing."""
+    if source not in _SOURCE_TO_SYNC_CONFIG:
+        return False
+
+    doc_sync_config = _SOURCE_TO_SYNC_CONFIG[source].doc_sync_config
+    if doc_sync_config is None:
+        return False
+
+    return doc_sync_config.initial_index_should_sync

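A sketch of how a caller would consume the registry introduced above, using only the accessors defined in this diff; the frequency units are not specified here, so the value is printed as-is.

from ee.onyx.external_permissions.sync_params import get_source_perm_sync_config
from ee.onyx.external_permissions.sync_params import source_requires_doc_sync
from ee.onyx.external_permissions.sync_params import source_requires_external_group_sync
from onyx.configs.constants import DocumentSource

source = DocumentSource.GOOGLE_DRIVE

if source_requires_doc_sync(source):
    config = get_source_perm_sync_config(source)
    assert config is not None and config.doc_sync_config is not None
    print(f"{source}: doc sync frequency {config.doc_sync_config.doc_sync_frequency}")

if source_requires_external_group_sync(source):
    print(f"{source}: external group sync enabled")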
@@ -10,6 +10,7 @@ from ee.onyx.configs.app_configs import OIDC_SCOPE_OVERRIDE
 from ee.onyx.configs.app_configs import OPENID_CONFIG_URL
 from ee.onyx.server.analytics.api import router as analytics_router
 from ee.onyx.server.auth_check import check_ee_router_auth
+from ee.onyx.server.documents.cc_pair import router as ee_document_cc_pair_router
 from ee.onyx.server.enterprise_settings.api import (
     admin_router as enterprise_settings_admin_router,
 )
@@ -50,6 +51,7 @@ from onyx.main import get_application as get_application_base
 from onyx.main import include_auth_router_with_prefix
 from onyx.main import include_router_with_global_prefix_prepended
 from onyx.main import lifespan as lifespan_base
+from onyx.main import use_route_function_names_as_operation_ids
 from onyx.utils.logger import setup_logger
 from onyx.utils.variable_functionality import global_version
 from shared_configs.configs import MULTI_TENANT
@@ -167,6 +169,7 @@ def get_application() -> FastAPI:
     include_router_with_global_prefix_prepended(application, chat_router)
     include_router_with_global_prefix_prepended(application, standard_answer_router)
     include_router_with_global_prefix_prepended(application, ee_oauth_router)
+    include_router_with_global_prefix_prepended(application, ee_document_cc_pair_router)

     # Enterprise-only global settings
     include_router_with_global_prefix_prepended(
@@ -190,4 +193,6 @@ def get_application() -> FastAPI:
     # for route in application.router.routes:
     #     print(f"Path: {route.path}, Methods: {route.methods}")

+    use_route_function_names_as_operation_ids(application)
+
     return application

@@ -7,7 +7,6 @@ from sqlalchemy.orm import Session

 from ee.onyx.db.standard_answer import fetch_standard_answer_categories_by_names
 from ee.onyx.db.standard_answer import find_matching_standard_answers
-from ee.onyx.server.manage.models import StandardAnswer as PydanticStandardAnswer
 from onyx.configs.constants import MessageType
 from onyx.configs.onyxbot_configs import DANSWER_REACT_EMOJI
 from onyx.db.chat import create_chat_session
@@ -24,6 +23,7 @@ from onyx.onyxbot.slack.handlers.utils import send_team_member_message
 from onyx.onyxbot.slack.models import SlackMessageInfo
 from onyx.onyxbot.slack.utils import respond_in_thread_or_channel
 from onyx.onyxbot.slack.utils import update_emote_react
+from onyx.server.manage.models import StandardAnswer as PydanticStandardAnswer
 from onyx.utils.logger import OnyxLoggingAdapter
 from onyx.utils.logger import setup_logger

backend/ee/onyx/server/documents/cc_pair.py (new file, 177 lines)
@@ -0,0 +1,177 @@
from datetime import datetime
from http import HTTPStatus

from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session

from ee.onyx.background.celery.tasks.doc_permission_syncing.tasks import (
    try_creating_permissions_sync_task,
)
from ee.onyx.background.celery.tasks.external_group_syncing.tasks import (
    try_creating_external_group_sync_task,
)
from onyx.auth.users import current_curator_or_admin_user
from onyx.background.celery.versioned_apps.client import app as client_app
from onyx.db.connector_credential_pair import (
    get_connector_credential_pair_from_id_for_user,
)
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_pool import get_redis_client
from onyx.server.models import StatusResponse
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id

logger = setup_logger()
router = APIRouter(prefix="/manage")


@router.get("/admin/cc-pair/{cc_pair_id}/sync-permissions")
def get_cc_pair_latest_sync(
    cc_pair_id: int,
    user: User = Depends(current_curator_or_admin_user),
    db_session: Session = Depends(get_session),
) -> datetime | None:
    cc_pair = get_connector_credential_pair_from_id_for_user(
        cc_pair_id=cc_pair_id,
        db_session=db_session,
        user=user,
        get_editable=False,
    )
    if not cc_pair:
        raise HTTPException(
            status_code=400,
            detail="cc_pair not found for current user's permissions",
        )

    return cc_pair.last_time_perm_sync


@router.post("/admin/cc-pair/{cc_pair_id}/sync-permissions")
def sync_cc_pair(
    cc_pair_id: int,
    user: User = Depends(current_curator_or_admin_user),
    db_session: Session = Depends(get_session),
) -> StatusResponse[None]:
    """Triggers permissions sync on a particular cc_pair immediately"""
    tenant_id = get_current_tenant_id()

    cc_pair = get_connector_credential_pair_from_id_for_user(
        cc_pair_id=cc_pair_id,
        db_session=db_session,
        user=user,
        get_editable=False,
    )
    if not cc_pair:
        raise HTTPException(
            status_code=400,
            detail="Connection not found for current user's permissions",
        )

    r = get_redis_client()

    redis_connector = RedisConnector(tenant_id, cc_pair_id)
    if redis_connector.permissions.fenced:
        raise HTTPException(
            status_code=HTTPStatus.CONFLICT,
            detail="Permissions sync task already in progress.",
        )

    logger.info(
        f"Permissions sync cc_pair={cc_pair_id} "
        f"connector_id={cc_pair.connector_id} "
        f"credential_id={cc_pair.credential_id} "
        f"{cc_pair.connector.name} connector."
    )
    payload_id = try_creating_permissions_sync_task(
        client_app, cc_pair_id, r, tenant_id
    )
    if not payload_id:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
            detail="Permissions sync task creation failed.",
        )

    logger.info(f"Permissions sync queued: cc_pair={cc_pair_id} id={payload_id}")

    return StatusResponse(
        success=True,
        message="Successfully created the permissions sync task.",
    )


@router.get("/admin/cc-pair/{cc_pair_id}/sync-groups")
def get_cc_pair_latest_group_sync(
    cc_pair_id: int,
    user: User = Depends(current_curator_or_admin_user),
    db_session: Session = Depends(get_session),
) -> datetime | None:
    cc_pair = get_connector_credential_pair_from_id_for_user(
        cc_pair_id=cc_pair_id,
        db_session=db_session,
        user=user,
        get_editable=False,
    )
    if not cc_pair:
        raise HTTPException(
            status_code=400,
            detail="cc_pair not found for current user's permissions",
        )

    return cc_pair.last_time_external_group_sync


@router.post("/admin/cc-pair/{cc_pair_id}/sync-groups")
def sync_cc_pair_groups(
    cc_pair_id: int,
    user: User = Depends(current_curator_or_admin_user),
    db_session: Session = Depends(get_session),
) -> StatusResponse[None]:
    """Triggers group sync on a particular cc_pair immediately"""
    tenant_id = get_current_tenant_id()

    cc_pair = get_connector_credential_pair_from_id_for_user(
        cc_pair_id=cc_pair_id,
        db_session=db_session,
        user=user,
        get_editable=False,
    )
    if not cc_pair:
        raise HTTPException(
            status_code=400,
            detail="Connection not found for current user's permissions",
        )

    r = get_redis_client()

    redis_connector = RedisConnector(tenant_id, cc_pair_id)
    if redis_connector.external_group_sync.fenced:
        raise HTTPException(
            status_code=HTTPStatus.CONFLICT,
            detail="External group sync task already in progress.",
        )

    logger.info(
        f"External group sync cc_pair={cc_pair_id} "
        f"connector_id={cc_pair.connector_id} "
        f"credential_id={cc_pair.credential_id} "
        f"{cc_pair.connector.name} connector."
    )
    payload_id = try_creating_external_group_sync_task(
        client_app, cc_pair_id, r, tenant_id
    )
    if not payload_id:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
            detail="External group sync task creation failed.",
        )

    logger.info(f"External group sync queued: cc_pair={cc_pair_id} id={payload_id}")

    return StatusResponse(
        success=True,
        message="Successfully created the external group sync task.",
    )
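A hypothetical client call against the new endpoints; the `/api` prefix, host, and auth header are assumptions and not specified in this diff.

import requests  # illustrative HTTP client

BASE = "https://onyx.example.com/api/manage/admin/cc-pair/42"  # hypothetical host/prefix
HEADERS = {"Authorization": "Bearer <api-key>"}  # placeholder credential

# POST triggers an immediate permission sync (409 if one is already fenced)
resp = requests.post(f"{BASE}/sync-permissions", headers=HEADERS, timeout=30)
resp.raise_for_status()

# GET returns the timestamp of the last permission sync (or null)
last_sync = requests.get(f"{BASE}/sync-permissions", headers=HEADERS, timeout=30).json()
print("last permission sync:", last_sync)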
@@ -29,7 +29,11 @@ from onyx.auth.users import UserManager
 from onyx.db.engine import get_session
 from onyx.db.models import User
 from onyx.file_store.file_store import PostgresBackedFileStore
+from onyx.server.utils import BasicAuthenticationError
 from onyx.utils.logger import setup_logger
+from shared_configs.configs import MULTI_TENANT
+from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
+from shared_configs.contextvars import get_current_tenant_id

 admin_router = APIRouter(prefix="/admin/enterprise-settings")
 basic_router = APIRouter(prefix="/enterprise-settings")
@@ -110,14 +114,19 @@ async def refresh_access_token(


 @admin_router.put("")
-def put_settings(
+def admin_ee_put_settings(
     settings: EnterpriseSettings, _: User | None = Depends(current_admin_user)
 ) -> None:
     store_settings(settings)


 @basic_router.get("")
-def fetch_settings() -> EnterpriseSettings:
+def ee_fetch_settings() -> EnterpriseSettings:
+    if MULTI_TENANT:
+        tenant_id = get_current_tenant_id()
+        if not tenant_id or tenant_id == POSTGRES_DEFAULT_SCHEMA:
+            raise BasicAuthenticationError(detail="User must authenticate")
+
     return load_settings()

@@ -1,98 +0,0 @@
import re
from typing import Any

from pydantic import BaseModel
from pydantic import field_validator
from pydantic import model_validator

from onyx.db.models import StandardAnswer as StandardAnswerModel
from onyx.db.models import StandardAnswerCategory as StandardAnswerCategoryModel


class StandardAnswerCategoryCreationRequest(BaseModel):
    name: str


class StandardAnswerCategory(BaseModel):
    id: int
    name: str

    @classmethod
    def from_model(
        cls, standard_answer_category: StandardAnswerCategoryModel
    ) -> "StandardAnswerCategory":
        return cls(
            id=standard_answer_category.id,
            name=standard_answer_category.name,
        )


class StandardAnswer(BaseModel):
    id: int
    keyword: str
    answer: str
    categories: list[StandardAnswerCategory]
    match_regex: bool
    match_any_keywords: bool

    @classmethod
    def from_model(cls, standard_answer_model: StandardAnswerModel) -> "StandardAnswer":
        return cls(
            id=standard_answer_model.id,
            keyword=standard_answer_model.keyword,
            answer=standard_answer_model.answer,
            match_regex=standard_answer_model.match_regex,
            match_any_keywords=standard_answer_model.match_any_keywords,
            categories=[
                StandardAnswerCategory.from_model(standard_answer_category_model)
                for standard_answer_category_model in standard_answer_model.categories
            ],
        )


class StandardAnswerCreationRequest(BaseModel):
    keyword: str
    answer: str
    categories: list[int]
    match_regex: bool
    match_any_keywords: bool

    @field_validator("categories", mode="before")
    @classmethod
    def validate_categories(cls, value: list[int]) -> list[int]:
        if len(value) < 1:
            raise ValueError(
                "At least one category must be attached to a standard answer"
            )
        return value

    @model_validator(mode="after")
    def validate_only_match_any_if_not_regex(self) -> Any:
        if self.match_regex and self.match_any_keywords:
            raise ValueError(
                "Can only match any keywords in keyword mode, not regex mode"
            )

        return self

    @model_validator(mode="after")
    def validate_keyword_if_regex(self) -> Any:
        if not self.match_regex:
            # no validation for keywords
            return self

        try:
            re.compile(self.keyword)
            return self
        except re.error as err:
            if isinstance(err.pattern, bytes):
                raise ValueError(
                    f'invalid regex pattern r"{err.pattern.decode()}" in `keyword`: {err.msg}'
                )
            else:
                pattern = f'r"{err.pattern}"' if err.pattern is not None else ""
                raise ValueError(
                    " ".join(
                        ["invalid regex pattern", pattern, f"in `keyword`: {err.msg}"]
                    )
                )
@@ -12,13 +12,13 @@ from ee.onyx.db.standard_answer import insert_standard_answer_category
 from ee.onyx.db.standard_answer import remove_standard_answer
 from ee.onyx.db.standard_answer import update_standard_answer
 from ee.onyx.db.standard_answer import update_standard_answer_category
-from ee.onyx.server.manage.models import StandardAnswer
-from ee.onyx.server.manage.models import StandardAnswerCategory
-from ee.onyx.server.manage.models import StandardAnswerCategoryCreationRequest
-from ee.onyx.server.manage.models import StandardAnswerCreationRequest
 from onyx.auth.users import current_admin_user
 from onyx.db.engine import get_session
 from onyx.db.models import User
+from onyx.server.manage.models import StandardAnswer
+from onyx.server.manage.models import StandardAnswerCategory
+from onyx.server.manage.models import StandardAnswerCategoryCreationRequest
+from onyx.server.manage.models import StandardAnswerCreationRequest

 router = APIRouter(prefix="/manage")

@@ -8,8 +8,8 @@ from fastapi import Request
 from fastapi import Response

 from ee.onyx.auth.users import decode_anonymous_user_jwt_token
-from ee.onyx.configs.app_configs import ANONYMOUS_USER_COOKIE_NAME
 from onyx.auth.api_key import extract_tenant_from_api_key_header
+from onyx.configs.constants import ANONYMOUS_USER_COOKIE_NAME
 from onyx.configs.constants import TENANT_ID_COOKIE_NAME
 from onyx.db.engine import is_valid_schema_name
 from onyx.redis.redis_pool import retrieve_auth_token_data_from_redis

@@ -14,11 +14,11 @@ from pydantic import BaseModel
 from pydantic import ValidationError
 from sqlalchemy.orm import Session

-from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_ID
-from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET
 from ee.onyx.server.oauth.api_router import router
 from onyx.auth.users import current_admin_user
 from onyx.configs.app_configs import DEV_MODE
+from onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_ID
+from onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET
 from onyx.configs.app_configs import WEB_DOMAIN
 from onyx.configs.constants import DocumentSource
 from onyx.connectors.confluence.utils import CONFLUENCE_OAUTH_TOKEN_URL

@@ -11,11 +11,11 @@ from fastapi.responses import JSONResponse
 from pydantic import BaseModel
 from sqlalchemy.orm import Session

-from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_ID
-from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
 from ee.onyx.server.oauth.api_router import router
 from onyx.auth.users import current_admin_user
 from onyx.configs.app_configs import DEV_MODE
+from onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_ID
+from onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
 from onyx.configs.app_configs import WEB_DOMAIN
 from onyx.configs.constants import DocumentSource
 from onyx.connectors.google_utils.google_auth import get_google_oauth_creds

@@ -9,11 +9,11 @@ from fastapi.responses import JSONResponse
 from pydantic import BaseModel
 from sqlalchemy.orm import Session

-from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_ID
-from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_SECRET
 from ee.onyx.server.oauth.api_router import router
 from onyx.auth.users import current_admin_user
 from onyx.configs.app_configs import DEV_MODE
+from onyx.configs.app_configs import OAUTH_SLACK_CLIENT_ID
+from onyx.configs.app_configs import OAUTH_SLACK_CLIENT_SECRET
 from onyx.configs.app_configs import WEB_DOMAIN
 from onyx.configs.constants import DocumentSource
 from onyx.db.credentials import create_credential

@@ -43,7 +43,6 @@ from onyx.db.chat import get_or_create_root_message
 from onyx.db.engine import get_session
 from onyx.db.models import User
 from onyx.llm.factory import get_llms_for_persona
-from onyx.llm.utils import get_max_input_tokens
 from onyx.natural_language_processing.utils import get_tokenizer
 from onyx.secondary_llm_flows.query_expansion import thread_based_query_rephrase
 from onyx.server.query_and_chat.models import ChatMessageDetail
@@ -339,10 +338,7 @@ def handle_send_message_simple_with_history(
         provider_type=llm.config.model_provider,
     )

-    input_tokens = get_max_input_tokens(
-        model_name=llm.config.model_name, model_provider=llm.config.model_provider
-    )
-    max_history_tokens = int(input_tokens * CHAT_TARGET_CHUNK_PERCENTAGE)
+    max_history_tokens = int(llm.config.max_input_tokens * CHAT_TARGET_CHUNK_PERCENTAGE)

     # Every chat Session begins with an empty root message
     root_message = get_or_create_root_message(

@@ -6,7 +6,6 @@ from pydantic import BaseModel
 from pydantic import Field
 from pydantic import model_validator

-from ee.onyx.server.manage.models import StandardAnswer
 from onyx.chat.models import CitationInfo
 from onyx.chat.models import PersonaOverrideConfig
 from onyx.chat.models import QADocsResponse
@@ -19,6 +18,7 @@ from onyx.context.search.models import ChunkContext
 from onyx.context.search.models import RerankingDetails
 from onyx.context.search.models import RetrievalDetails
 from onyx.context.search.models import SavedSearchDoc
+from onyx.server.manage.models import StandardAnswer


 class StandardAnswerRequest(BaseModel):

@@ -38,7 +38,6 @@ from onyx.db.persona import get_persona_by_id
 from onyx.llm.factory import get_default_llms
 from onyx.llm.factory import get_llms_for_persona
 from onyx.llm.factory import get_main_llm_from_tuple
-from onyx.llm.utils import get_max_input_tokens
 from onyx.natural_language_processing.utils import get_tokenizer
 from onyx.server.utils import get_json_line
 from onyx.utils.logger import setup_logger
@@ -177,10 +176,9 @@ def get_answer_stream(
         provider_type=llm.config.model_provider,
     )

-    input_tokens = get_max_input_tokens(
-        model_name=llm.config.model_name, model_provider=llm.config.model_provider
-    )
-    max_history_tokens = int(input_tokens * MAX_THREAD_CONTEXT_PERCENTAGE)
+    max_history_tokens = int(
+        llm.config.max_input_tokens * MAX_THREAD_CONTEXT_PERCENTAGE
+    )

     combined_message = combine_message_thread(
         messages=query_request.messages,

@@ -1,5 +1,5 @@
import csv
import io
import uuid
from collections.abc import Generator
from datetime import datetime
from datetime import timezone
from http import HTTPStatus
@@ -12,70 +12,108 @@ from fastapi import Query
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session

from ee.onyx.db.query_history import fetch_chat_sessions_eagerly_by_time
from ee.onyx.background.task_name_builders import query_history_task_name
from ee.onyx.db.query_history import get_all_query_history_export_tasks
from ee.onyx.db.query_history import get_page_of_chat_sessions
from ee.onyx.db.query_history import get_total_filtered_chat_sessions_count
from ee.onyx.server.query_history.models import ChatSessionMinimal
from ee.onyx.server.query_history.models import ChatSessionSnapshot
from ee.onyx.server.query_history.models import MessageSnapshot
from ee.onyx.server.query_history.models import QuestionAnswerPairSnapshot
from ee.onyx.server.query_history.models import QueryHistoryExport
from onyx.auth.users import current_admin_user
from onyx.auth.users import get_display_email
from onyx.background.celery.versioned_apps.client import app as client_app
from onyx.background.task_utils import construct_query_history_report_name
from onyx.chat.chat_utils import create_chat_chain
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
from onyx.configs.constants import FileOrigin
from onyx.configs.constants import FileType
from onyx.configs.constants import MessageType
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import QAFeedbackType
from onyx.configs.constants import QueryHistoryType
from onyx.configs.constants import SessionType
from onyx.db.chat import get_chat_session_by_id
from onyx.db.chat import get_chat_sessions_by_user
from onyx.db.engine import get_session
from onyx.db.enums import TaskStatus
from onyx.db.models import ChatSession
from onyx.db.models import User
from onyx.db.pg_file_store import get_query_history_export_files
from onyx.db.tasks import get_task_with_id
from onyx.db.tasks import register_task
from onyx.file_store.file_store import get_default_file_store
from onyx.server.documents.models import PaginatedReturn
from onyx.server.query_and_chat.models import ChatSessionDetails
from onyx.server.query_and_chat.models import ChatSessionsResponse
from onyx.utils.threadpool_concurrency import parallel_yield

router = APIRouter()

ONYX_ANONYMIZED_EMAIL = "anonymous@anonymous.invalid"


def ensure_query_history_is_enabled(
    disallowed: list[QueryHistoryType],
) -> None:
    if ONYX_QUERY_HISTORY_TYPE in disallowed:
        raise HTTPException(
            status_code=HTTPStatus.FORBIDDEN,
            detail="Query history has been disabled by the administrator.",
        )


def yield_snapshot_from_chat_session(
    chat_session: ChatSession,
    db_session: Session,
) -> Generator[ChatSessionSnapshot | None]:
    yield snapshot_from_chat_session(chat_session=chat_session, db_session=db_session)


def fetch_and_process_chat_session_history(
    db_session: Session,
    start: datetime,
    end: datetime,
    feedback_type: QAFeedbackType | None,
    limit: int | None = 500,
) -> list[ChatSessionSnapshot]:
    # observed to be slow at a scale of 8192 sessions and 4 messages per session
) -> Generator[ChatSessionSnapshot]:
    PAGE_SIZE = 100

    # this is a little slow (5 seconds)
    chat_sessions = fetch_chat_sessions_eagerly_by_time(
        start=start, end=end, db_session=db_session, limit=limit
    )
    page = 0
    while True:
        paged_chat_sessions = get_page_of_chat_sessions(
            start_time=start,
            end_time=end,
            db_session=db_session,
            page_num=page,
            page_size=PAGE_SIZE,
        )

    # this is VERY slow (80 seconds) due to create_chat_chain being called
    # for each session. Needs optimizing.
    chat_session_snapshots = [
        snapshot_from_chat_session(chat_session=chat_session, db_session=db_session)
        for chat_session in chat_sessions
    ]
        if not paged_chat_sessions:
            break

    valid_snapshots = [
        snapshot for snapshot in chat_session_snapshots if snapshot is not None
    ]
        paged_snapshots = parallel_yield(
            [
                yield_snapshot_from_chat_session(
                    db_session=db_session,
                    chat_session=chat_session,
                )
                for chat_session in paged_chat_sessions
            ]
        )

    if feedback_type:
        valid_snapshots = [
            snapshot
            for snapshot in valid_snapshots
            if any(
                message.feedback_type == feedback_type for message in snapshot.messages
            )
        ]
        for snapshot in paged_snapshots:
            if snapshot:
                yield snapshot

    return valid_snapshots
        # If we've fetched *less* than a `PAGE_SIZE` worth
        # of data, we have reached the end of the
        # pagination sequence; break.
        if len(paged_chat_sessions) < PAGE_SIZE:
            break

        page += 1

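The rewrite above turns a one-shot, memory-heavy fetch into a paged generator: each page of sessions is snapshotted (concurrently, via `parallel_yield`) and streamed out, and iteration stops on an empty or short page. A minimal sketch of the same pagination pattern, with hypothetical `fetch_page`/`to_snapshot` stand-ins for the real DB helpers and the parallelism replaced by a plain loop for clarity:

from collections.abc import Callable, Generator
from typing import TypeVar

Row = TypeVar("Row")
Snap = TypeVar("Snap")

def paged_snapshots(
    fetch_page: Callable[[int, int], list[Row]],  # hypothetical stand-in for get_page_of_chat_sessions
    to_snapshot: Callable[[Row], Snap | None],    # hypothetical stand-in for snapshot_from_chat_session
    page_size: int = 100,
) -> Generator[Snap, None, None]:
    page = 0
    while True:
        rows = fetch_page(page, page_size)
        if not rows:
            break
        for snapshot in map(to_snapshot, rows):
            if snapshot is not None:
                yield snapshot
        # A short page means we've reached the end of the pagination sequence.
        if len(rows) < page_size:
            break
        page += 1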
def snapshot_from_chat_session(
@@ -112,21 +150,19 @@ def snapshot_from_chat_session(


@router.get("/admin/chat-sessions")
def get_user_chat_sessions(
def admin_get_chat_sessions(
    user_id: UUID,
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> ChatSessionsResponse:
    # we specifically don't allow this endpoint if "anonymized" since
    # this is a direct query on the user id
    if ONYX_QUERY_HISTORY_TYPE in [
        QueryHistoryType.DISABLED,
        QueryHistoryType.ANONYMIZED,
    ]:
        raise HTTPException(
            status_code=HTTPStatus.FORBIDDEN,
            detail="Per user query history has been disabled by the administrator.",
        )
    ensure_query_history_is_enabled(
        [
            QueryHistoryType.DISABLED,
            QueryHistoryType.ANONYMIZED,
        ]
    )

    try:
        chat_sessions = get_chat_sessions_by_user(
@@ -163,11 +199,7 @@ def get_chat_session_history(
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> PaginatedReturn[ChatSessionMinimal]:
    if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.DISABLED:
        raise HTTPException(
            status_code=HTTPStatus.FORBIDDEN,
            detail="Query history has been disabled by the administrator.",
        )
    ensure_query_history_is_enabled(disallowed=[QueryHistoryType.DISABLED])

    page_of_chat_sessions = get_page_of_chat_sessions(
        page_num=page_num,
@@ -205,11 +237,7 @@ def get_chat_session_admin(
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> ChatSessionSnapshot:
    if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.DISABLED:
        raise HTTPException(
            status_code=HTTPStatus.FORBIDDEN,
            detail="Query history has been disabled by the administrator.",
        )
    ensure_query_history_is_enabled(disallowed=[QueryHistoryType.DISABLED])

    try:
        chat_session = get_chat_session_by_id(
@@ -220,7 +248,8 @@ def get_chat_session_admin(
        )
    except ValueError:
        raise HTTPException(
            400, f"Chat session with id '{chat_session_id}' does not exist."
            HTTPStatus.BAD_REQUEST,
            f"Chat session with id '{chat_session_id}' does not exist.",
        )
    snapshot = snapshot_from_chat_session(
        chat_session=chat_session, db_session=db_session
@@ -228,7 +257,7 @@ def get_chat_session_admin(

    if snapshot is None:
        raise HTTPException(
            400,
            HTTPStatus.BAD_REQUEST,
            f"Could not create snapshot for chat session with id '{chat_session_id}'",
        )

@@ -238,52 +267,165 @@ def get_chat_session_admin(
    return snapshot


@router.get("/admin/query-history-csv")
def get_query_history_as_csv(
@router.get("/admin/query-history/list")
def list_all_query_history_exports(
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> list[QueryHistoryExport]:
    ensure_query_history_is_enabled(disallowed=[QueryHistoryType.DISABLED])
    try:
        pending_tasks = [
            QueryHistoryExport.from_task(task)
            for task in get_all_query_history_export_tasks(db_session=db_session)
        ]
        generated_files = [
            QueryHistoryExport.from_file(file)
            for file in get_query_history_export_files(db_session=db_session)
        ]
        merged = pending_tasks + generated_files

        # We sort based off of the start-time of the task.
        # We also return it in reverse order, since viewing generated reports from most recent to least recent is most common.
        merged.sort(key=lambda task: task.start_time, reverse=True)

        return merged
    except Exception as e:
        raise HTTPException(
            HTTPStatus.INTERNAL_SERVER_ERROR, f"Failed to get all tasks: {e}"
        )


@router.post("/admin/query-history/start-export")
def start_query_history_export(
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
    start: datetime | None = None,
    end: datetime | None = None,
) -> dict[str, str]:
    ensure_query_history_is_enabled(disallowed=[QueryHistoryType.DISABLED])

    start = start or datetime.fromtimestamp(0, tz=timezone.utc)
    end = end or datetime.now(tz=timezone.utc)

    if start >= end:
        raise HTTPException(
            HTTPStatus.BAD_REQUEST,
            f"Start time must come before end time; got {start=} {end=}",
        )

    task_id_uuid = uuid.uuid4()
    task_id = str(task_id_uuid)
    start_time = datetime.now(tz=timezone.utc)

    register_task(
        db_session=db_session,
        task_name=query_history_task_name(start=start, end=end),
        task_id=task_id,
        status=TaskStatus.PENDING,
        start_time=start_time,
    )

    client_app.send_task(
        OnyxCeleryTask.EXPORT_QUERY_HISTORY_TASK,
        task_id=task_id,
        priority=OnyxCeleryPriority.MEDIUM,
        queue=OnyxCeleryQueues.CSV_GENERATION,
        kwargs={
            "start": start,
            "end": end,
            "start_time": start_time,
        },
    )

    return {"request_id": task_id}


@router.get("/admin/query-history/export-status")
def get_query_history_export_status(
    request_id: str,
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> dict[str, str]:
    ensure_query_history_is_enabled(disallowed=[QueryHistoryType.DISABLED])

    task = get_task_with_id(db_session=db_session, task_id=request_id)

    if task:
        return {"status": task.status}

    # If task is None, then it's possible that the task has already finished processing.
    # Therefore, we should check if the export file has already been stored in the file store.
    # If that *also* doesn't exist, then we can return a 404.
    file_store = get_default_file_store(db_session)

    report_name = construct_query_history_report_name(request_id)
    has_file = file_store.has_file(
        file_name=report_name,
        file_origin=FileOrigin.QUERY_HISTORY_CSV,
        file_type=FileType.CSV,
    )

    if not has_file:
        raise HTTPException(
            HTTPStatus.NOT_FOUND,
            f"No task with {request_id=} was found",
        )

    return {"status": TaskStatus.SUCCESS}


@router.get("/admin/query-history/download")
def download_query_history_csv(
    request_id: str,
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> StreamingResponse:
    if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.DISABLED:
    ensure_query_history_is_enabled(disallowed=[QueryHistoryType.DISABLED])

    report_name = construct_query_history_report_name(request_id)
    file_store = get_default_file_store(db_session)
    has_file = file_store.has_file(
        file_name=report_name,
        file_origin=FileOrigin.QUERY_HISTORY_CSV,
        file_type=FileType.CSV,
    )

    if has_file:
        try:
            csv_stream = file_store.read_file(report_name)
        except Exception as e:
            raise HTTPException(
                HTTPStatus.INTERNAL_SERVER_ERROR,
                f"Failed to read query history file: {str(e)}",
            )
        csv_stream.seek(0)
        return StreamingResponse(
            iter(csv_stream),
            media_type=FileType.CSV,
            headers={"Content-Disposition": f"attachment;filename={report_name}"},
        )

    # If the file doesn't exist yet, it may still be processing.
    # Therefore, we check the task queue to determine its status, if there is any.
    task = get_task_with_id(db_session=db_session, task_id=request_id)
    if not task:
        raise HTTPException(
            status_code=HTTPStatus.FORBIDDEN,
            detail="Query history has been disabled by the administrator.",
            HTTPStatus.NOT_FOUND,
            f"No task with {request_id=} was found",
        )

    # this call is very expensive and times out at the endpoint
    # TODO: optimize call and/or generate via background task
    complete_chat_session_history = fetch_and_process_chat_session_history(
        db_session=db_session,
        start=start or datetime.fromtimestamp(0, tz=timezone.utc),
        end=end or datetime.now(tz=timezone.utc),
        feedback_type=None,
        limit=None,
    )

    question_answer_pairs: list[QuestionAnswerPairSnapshot] = []
    for chat_session_snapshot in complete_chat_session_history:
        if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED:
            chat_session_snapshot.user_email = ONYX_ANONYMIZED_EMAIL

        question_answer_pairs.extend(
            QuestionAnswerPairSnapshot.from_chat_session_snapshot(chat_session_snapshot)
    if task.status in [TaskStatus.STARTED, TaskStatus.PENDING]:
        raise HTTPException(
            HTTPStatus.ACCEPTED, f"Task with {request_id=} is still being worked on"
        )

    # Create an in-memory text stream
    stream = io.StringIO()
    writer = csv.DictWriter(
        stream, fieldnames=list(QuestionAnswerPairSnapshot.model_fields.keys())
    )
    writer.writeheader()
    for row in question_answer_pairs:
        writer.writerow(row.to_json())

    # Reset the stream's position to the start
    stream.seek(0)

    return StreamingResponse(
        iter([stream.getvalue()]),
        media_type="text/csv",
        headers={"Content-Disposition": "attachment;filename=onyx_query_history.csv"},
    )
    elif task.status == TaskStatus.FAILURE:
        raise HTTPException(
            HTTPStatus.INTERNAL_SERVER_ERROR,
            f"Task with {request_id=} failed to be processed",
        )
    else:
        # This is the final case in which `task.status == SUCCESS`
        raise RuntimeError(
            "The task was marked as success, but the file was not found in the file store; this is an internal error..."
        )

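Taken together, the three new endpoints replace the old synchronous CSV build with an asynchronous export: `start-export` registers a task row and enqueues the Celery job, `export-status` reports task state (falling back to a file-store check once the task row is gone), and `download` streams the stored CSV. A client-side sketch of the intended flow; the base URL, auth handling, serialized status literal, and polling interval are assumptions, not part of the diff:

import time

import requests  # any HTTP client works; requests is assumed here

BASE = "http://localhost:8080/api"  # assumed deployment URL; auth omitted

def export_query_history() -> bytes:
    # Kick off the export; the server responds immediately with a request id.
    request_id = requests.post(f"{BASE}/admin/query-history/start-export").json()["request_id"]

    # Poll until the background task (or the stored file) reports success.
    while True:
        status = requests.get(
            f"{BASE}/admin/query-history/export-status",
            params={"request_id": request_id},
        ).json()["status"]
        if status == "SUCCESS":  # assumes TaskStatus serializes to this literal
            break
        time.sleep(2)  # assumed polling interval

    # Stream down the generated CSV.
    resp = requests.get(
        f"{BASE}/admin/query-history/download", params={"request_id": request_id}
    )
    resp.raise_for_status()
    return resp.content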
@@ -3,12 +3,17 @@ from uuid import UUID

from pydantic import BaseModel

from ee.onyx.background.task_name_builders import QUERY_HISTORY_TASK_NAME_PREFIX
from onyx.auth.users import get_display_email
from onyx.background.task_utils import extract_task_id_from_query_history_report_name
from onyx.configs.constants import MessageType
from onyx.configs.constants import QAFeedbackType
from onyx.configs.constants import SessionType
from onyx.db.enums import TaskStatus
from onyx.db.models import ChatMessage
from onyx.db.models import ChatSession
from onyx.db.models import PGFileStore
from onyx.db.models import TaskQueueState


class AbridgedSearchDoc(BaseModel):
@@ -216,3 +221,59 @@ class QuestionAnswerPairSnapshot(BaseModel):
            "time_created": str(self.time_created),
            "flow_type": self.flow_type,
        }


class QueryHistoryExport(BaseModel):
    task_id: str
    status: TaskStatus
    start: datetime
    end: datetime
    start_time: datetime

    @classmethod
    def from_task(
        cls,
        task_queue_state: TaskQueueState,
    ) -> "QueryHistoryExport":
        start_end = task_queue_state.task_name.removeprefix(
            f"{QUERY_HISTORY_TASK_NAME_PREFIX}_"
        )
        start, end = start_end.split("_")

        if not task_queue_state.start_time:
            raise RuntimeError("The start time of the task must always be present")

        return cls(
            task_id=task_queue_state.task_id,
            status=task_queue_state.status,
            start=datetime.fromisoformat(start),
            end=datetime.fromisoformat(end),
            start_time=task_queue_state.start_time,
        )

    @classmethod
    def from_file(
        cls,
        file: PGFileStore,
    ) -> "QueryHistoryExport":
        if not file.file_metadata or not isinstance(file.file_metadata, dict):
            raise RuntimeError(
                "The file metadata must be non-null, and must be of type `dict[str, str]`"
            )

        metadata = QueryHistoryFileMetadata.model_validate(dict(file.file_metadata))
        task_id = extract_task_id_from_query_history_report_name(file.file_name)

        return cls(
            task_id=task_id,
            status=TaskStatus.SUCCESS,
            start=metadata.start,
            end=metadata.end,
            start_time=metadata.start_time,
        )


class QueryHistoryFileMetadata(BaseModel):
    start: datetime
    end: datetime
    start_time: datetime

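`from_task` recovers the export window by stripping the task-name prefix and splitting the remainder on `_`, which only works if the builder joins the prefix with exactly two ISO-8601 timestamps (ISO timestamps contain no underscores). A round-trip sketch of the naming convention this parsing implies; the prefix value and builder body are assumptions inferred from the parser, not the repo's actual `query_history_task_name`:

from datetime import datetime, timezone

QUERY_HISTORY_TASK_NAME_PREFIX = "query_history"  # assumed value

def query_history_task_name(start: datetime, end: datetime) -> str:
    # Assumed shape: "<prefix>_<start isoformat>_<end isoformat>"
    return f"{QUERY_HISTORY_TASK_NAME_PREFIX}_{start.isoformat()}_{end.isoformat()}"

name = query_history_task_name(
    datetime(2024, 1, 1, tzinfo=timezone.utc),
    datetime(2024, 2, 1, tzinfo=timezone.utc),
)
# Mirrors QueryHistoryExport.from_task: strip the prefix, split, re-parse.
start_s, end_s = name.removeprefix(f"{QUERY_HISTORY_TASK_NAME_PREFIX}_").split("_")
assert datetime.fromisoformat(start_s).year == 2024
assert datetime.fromisoformat(end_s).month == 2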
@@ -7,6 +7,8 @@ from datetime import timedelta
from datetime import timezone

from fastapi_users_db_sqlalchemy import UUID_ID
from sqlalchemy import cast
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import Session

from ee.onyx.db.usage_export import get_all_empty_chat_message_entries
@@ -14,6 +16,7 @@ from ee.onyx.db.usage_export import write_usage_report
from ee.onyx.server.reporting.usage_export_models import UsageReportMetadata
from ee.onyx.server.reporting.usage_export_models import UserSkeleton
from onyx.configs.constants import FileOrigin
from onyx.db.models import User
from onyx.db.users import get_all_users
from onyx.file_store.constants import MAX_IN_MEMORY_SIZE
from onyx.file_store.file_store import FileStore
@@ -153,11 +156,19 @@ def create_new_usage_report(
    # add report after zip file is written
    new_report = write_usage_report(db_session, report_name, user_id, period)

    # get user email
    requestor_user = (
        db_session.query(User)
        .filter(cast(User.id, UUID) == new_report.requestor_user_id)
        .one_or_none()
        if new_report.requestor_user_id
        else None
    )
    requestor_email = requestor_user.email if requestor_user else None

    return UsageReportMetadata(
        report_name=new_report.report_name,
        requestor=(
            str(new_report.requestor_user_id) if new_report.requestor_user_id else None
        ),
        requestor=requestor_email,
        time_created=new_report.time_created,
        period_from=new_report.period_from,
        period_to=new_report.period_to,

@@ -1,5 +1,6 @@
import contextlib
import secrets
import string
from typing import Any

from fastapi import APIRouter
@@ -9,7 +10,6 @@ from fastapi import Request
from fastapi import Response
from fastapi import status
from fastapi_users import exceptions
from fastapi_users.password import PasswordHelper
from onelogin.saml2.auth import OneLogin_Saml2_Auth  # type: ignore
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
@@ -28,6 +28,7 @@ from onyx.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
from onyx.db.auth import get_user_count
from onyx.db.auth import get_user_db
from onyx.db.engine import get_async_session
from onyx.db.engine import get_async_session_context_manager
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.utils.logger import setup_logger
@@ -38,14 +39,21 @@ router = APIRouter(prefix="/auth/saml")


async def upsert_saml_user(email: str) -> User:
    """
    Creates or updates a user account for SAML authentication.

    For new users or users with non-web-login roles:
    1. Generates a secure random password that meets validation criteria
    2. Creates the user with appropriate role and verified status

    SAML users never use this password directly as they authenticate via their
    Identity Provider, but we need a valid password to satisfy system requirements.
    """
    logger.debug(f"Attempting to upsert SAML user with email: {email}")
    get_async_session_context = contextlib.asynccontextmanager(
        get_async_session
    )  # type:ignore
    get_user_db_context = contextlib.asynccontextmanager(get_user_db)
    get_user_manager_context = contextlib.asynccontextmanager(get_user_manager)

    async with get_async_session_context() as session:
    async with get_async_session_context_manager() as session:
        async with get_user_db_context(session) as user_db:
            async with get_user_manager_context(user_db) as user_manager:
                try:
@@ -60,15 +68,41 @@ async def upsert_saml_user(email: str) -> User:
                    user_count = await get_user_count()
                    role = UserRole.ADMIN if user_count == 0 else UserRole.BASIC

                    fastapi_users_pw_helper = PasswordHelper()
                    password = fastapi_users_pw_helper.generate()
                    hashed_pass = fastapi_users_pw_helper.hash(password)
                    # Generate a secure random password meeting validation requirements
                    # We use a secure random password since we never need to know what it is
                    # (SAML users authenticate via their IdP)
                    secure_random_password = "".join(
                        [
                            # Ensure minimum requirements are met
                            secrets.choice(
                                string.ascii_uppercase
                            ),  # at least one uppercase
                            secrets.choice(
                                string.ascii_lowercase
                            ),  # at least one lowercase
                            secrets.choice(string.digits),  # at least one digit
                            secrets.choice(
                                "!@#$%^&*()-_=+[]{}|;:,.<>?"
                            ),  # at least one special
                            # Fill remaining length with random chars (mix of all types)
                            "".join(
                                secrets.choice(
                                    string.ascii_letters
                                    + string.digits
                                    + "!@#$%^&*()-_=+[]{}|;:,.<>?"
                                )
                                for _ in range(12)
                            ),
                        ]
                    )

                    # Create the user with SAML-appropriate settings
                    user = await user_manager.create(
                        UserCreate(
                            email=email,
                            password=hashed_pass,
                            password=secure_random_password,  # Pass raw password, not hash
                            role=role,
                            is_verified=True,  # SAML users are pre-verified by their IdP
                        )
                    )

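One property of the generator above worth noting: because the pieces are concatenated in order, the four guaranteed character classes always occupy the first four positions. A variant sketch (not the code in the diff) that keeps the same guarantees but shuffles the result using the `secrets`-backed `SystemRandom`:

import secrets
import string

SPECIALS = "!@#$%^&*()-_=+[]{}|;:,.<>?"

def generate_saml_password(extra_len: int = 12) -> str:
    # Same requirements as the diff: at least one uppercase, lowercase,
    # digit, and special character.
    chars = [
        secrets.choice(string.ascii_uppercase),
        secrets.choice(string.ascii_lowercase),
        secrets.choice(string.digits),
        secrets.choice(SPECIALS),
    ]
    alphabet = string.ascii_letters + string.digits + SPECIALS
    chars += [secrets.choice(alphabet) for _ in range(extra_len)]
    # Shuffle so the required classes don't sit at predictable positions
    # (a hardening step beyond the diff; uses a CSPRNG-backed shuffle).
    secrets.SystemRandom().shuffle(chars)
    return "".join(chars)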
@@ -2,6 +2,7 @@ from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Response
from fastapi_users import exceptions

from ee.onyx.auth.users import current_cloud_superuser
from ee.onyx.server.tenants.models import ImpersonateRequest
@@ -24,14 +25,24 @@ async def impersonate_user(
    _: User = Depends(current_cloud_superuser),
) -> Response:
    """Allows a cloud superuser to impersonate another user by generating an impersonation JWT token"""
    tenant_id = get_tenant_id_for_email(impersonate_request.email)
    try:
        tenant_id = get_tenant_id_for_email(impersonate_request.email)
    except exceptions.UserNotExists:
        detail = f"User has no tenant mapping: {impersonate_request.email=}"
        logger.warning(detail)
        raise HTTPException(status_code=422, detail=detail)

    with get_session_with_tenant(tenant_id=tenant_id) as tenant_session:
        user_to_impersonate = get_user_by_email(
            impersonate_request.email, tenant_session
        )
        if user_to_impersonate is None:
            raise HTTPException(status_code=404, detail="User not found")
            detail = (
                f"User not found in tenant: {impersonate_request.email=} {tenant_id=}"
            )
            logger.warning(detail)
            raise HTTPException(status_code=422, detail=detail)

    token = await get_redis_strategy().write_token(user_to_impersonate)

    response = await auth_backend.transport.get_login_response(token)

@@ -5,7 +5,6 @@ from fastapi import Response
from sqlalchemy.exc import IntegrityError

from ee.onyx.auth.users import generate_anonymous_user_jwt_token
from ee.onyx.configs.app_configs import ANONYMOUS_USER_COOKIE_NAME
from ee.onyx.server.tenants.anonymous_user_path import get_anonymous_user_path
from ee.onyx.server.tenants.anonymous_user_path import (
    get_tenant_id_for_anonymous_user_path,
@@ -17,6 +16,7 @@ from onyx.auth.users import anonymous_user_enabled
from onyx.auth.users import current_admin_user
from onyx.auth.users import optional_user
from onyx.auth.users import User
from onyx.configs.constants import ANONYMOUS_USER_COOKIE_NAME
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
from onyx.db.engine import get_session_with_shared_schema
from onyx.utils.logger import setup_logger

@@ -7,7 +7,7 @@ from onyx.redis.redis_pool import get_redis_replica_client
from onyx.server.settings.models import ApplicationStatus
from onyx.server.settings.store import load_settings
from onyx.server.settings.store import store_settings
from onyx.setup import setup_logger
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR

logger = setup_logger()

@@ -39,10 +39,13 @@ from onyx.db.models import SearchSettings
from onyx.db.models import UserTenantMapping
from onyx.llm.llm_provider_options import ANTHROPIC_MODEL_NAMES
from onyx.llm.llm_provider_options import ANTHROPIC_PROVIDER_NAME
from onyx.llm.llm_provider_options import ANTHROPIC_VISIBLE_MODEL_NAMES
from onyx.llm.llm_provider_options import OPEN_AI_MODEL_NAMES
from onyx.llm.llm_provider_options import OPEN_AI_VISIBLE_MODEL_NAMES
from onyx.llm.llm_provider_options import OPENAI_PROVIDER_NAME
from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
from onyx.setup import setup_onyx
from onyx.utils.telemetry import create_milestone_and_report
from shared_configs.configs import MULTI_TENANT
@@ -269,8 +272,14 @@ def configure_default_api_keys(db_session: Session) -> None:
        api_key=ANTHROPIC_DEFAULT_API_KEY,
        default_model_name="claude-3-7-sonnet-20250219",
        fast_default_model_name="claude-3-5-sonnet-20241022",
        model_names=ANTHROPIC_MODEL_NAMES,
        display_model_names=["claude-3-5-sonnet-20241022"],
        model_configurations=[
            ModelConfigurationUpsertRequest(
                name=name,
                is_visible=name in ANTHROPIC_VISIBLE_MODEL_NAMES,
                max_input_tokens=None,
            )
            for name in ANTHROPIC_MODEL_NAMES
        ],
        api_key_changed=True,
    )
    try:
@@ -290,8 +299,14 @@ def configure_default_api_keys(db_session: Session) -> None:
        api_key=OPENAI_DEFAULT_API_KEY,
        default_model_name="gpt-4o",
        fast_default_model_name="gpt-4o-mini",
        model_names=OPEN_AI_MODEL_NAMES,
        display_model_names=["o1", "o3-mini", "gpt-4o", "gpt-4o-mini"],
        model_configurations=[
            ModelConfigurationUpsertRequest(
                name=model_name,
                is_visible=model_name in OPEN_AI_VISIBLE_MODEL_NAMES,
                max_input_tokens=None,
            )
            for model_name in OPEN_AI_MODEL_NAMES
        ],
        api_key_changed=True,
    )
    try:
@@ -406,7 +421,6 @@ async def delete_user_from_control_plane(tenant_id: str, email: str) -> None:
        headers=headers,
        json=payload.model_dump(),
    ) as response:
        print(response)
        if response.status != 200:
            error_text = await response.text()
            logger.error(f"Control plane tenant creation failed: {error_text}")

@@ -9,7 +9,7 @@ from onyx.db.engine import get_session_with_shared_schema
from onyx.db.engine import get_session_with_tenant
from onyx.db.models import UserTenantMapping
from onyx.server.manage.models import TenantSnapshot
from onyx.setup import setup_logger
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
@@ -47,10 +47,10 @@ def get_tenant_id_for_email(email: str) -> str:
            mapping.active = True
            db_session.commit()
            tenant_id = mapping.tenant_id

    except Exception as e:
        logger.exception(f"Error getting tenant id for email {email}: {e}")
        raise exceptions.UserNotExists()

    if tenant_id is None:
        raise exceptions.UserNotExists()
    return tenant_id

backend/generated/README.md (new file, +2)
@@ -0,0 +1,2 @@
- Generated Files
* Generated files live here. This directory should be git ignored.
@@ -1,3 +1,5 @@
from typing import cast

import numpy as np
import torch
import torch.nn.functional as F
@@ -39,10 +41,10 @@ logger = setup_logger()

router = APIRouter(prefix="/custom")

_CONNECTOR_CLASSIFIER_TOKENIZER: AutoTokenizer | None = None
_CONNECTOR_CLASSIFIER_TOKENIZER: PreTrainedTokenizer | None = None
_CONNECTOR_CLASSIFIER_MODEL: ConnectorClassifier | None = None

_INTENT_TOKENIZER: AutoTokenizer | None = None
_INTENT_TOKENIZER: PreTrainedTokenizer | None = None
_INTENT_MODEL: HybridClassifier | None = None

_INFORMATION_CONTENT_MODEL: SetFitModel | None = None
@@ -50,13 +52,14 @@ _INFORMATION_CONTENT_MODEL: SetFitModel | None = None
_INFORMATION_CONTENT_MODEL_PROMPT_PREFIX: str = ""  # specific to the model version!


def get_connector_classifier_tokenizer() -> AutoTokenizer:
def get_connector_classifier_tokenizer() -> PreTrainedTokenizer:
    global _CONNECTOR_CLASSIFIER_TOKENIZER
    if _CONNECTOR_CLASSIFIER_TOKENIZER is None:
        # The tokenizer details are not uploaded to the HF hub since it's just the
        # unmodified distilbert tokenizer.
        _CONNECTOR_CLASSIFIER_TOKENIZER = AutoTokenizer.from_pretrained(
            "distilbert-base-uncased"
        _CONNECTOR_CLASSIFIER_TOKENIZER = cast(
            PreTrainedTokenizer,
            AutoTokenizer.from_pretrained("distilbert-base-uncased"),
        )
    return _CONNECTOR_CLASSIFIER_TOKENIZER

@@ -92,12 +95,15 @@ def get_local_connector_classifier(
    return _CONNECTOR_CLASSIFIER_MODEL


def get_intent_model_tokenizer() -> AutoTokenizer:
def get_intent_model_tokenizer() -> PreTrainedTokenizer:
    global _INTENT_TOKENIZER
    if _INTENT_TOKENIZER is None:
        # The tokenizer details are not uploaded to the HF hub since it's just the
        # unmodified distilbert tokenizer.
        _INTENT_TOKENIZER = AutoTokenizer.from_pretrained("distilbert-base-uncased")
        _INTENT_TOKENIZER = cast(
            PreTrainedTokenizer,
            AutoTokenizer.from_pretrained("distilbert-base-uncased"),
        )
    return _INTENT_TOKENIZER


@@ -395,9 +401,9 @@ def run_content_classification_inference(


def map_keywords(
    input_ids: torch.Tensor, tokenizer: AutoTokenizer, is_keyword: list[bool]
    input_ids: torch.Tensor, tokenizer: PreTrainedTokenizer, is_keyword: list[bool]
) -> list[str]:
    tokens = tokenizer.convert_ids_to_tokens(input_ids)
    tokens = tokenizer.convert_ids_to_tokens(input_ids)  # type: ignore

    if not len(tokens) == len(is_keyword):
        raise ValueError("Length of tokens and keyword predictions must match")

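The typing change in this last file follows from `AutoTokenizer` being a factory: `from_pretrained` returns one of several tokenizer classes, so `AutoTokenizer` itself is not meaningful as a variable annotation. The diff annotates with `PreTrainedTokenizer` and narrows with `cast`, which has no runtime effect. A condensed sketch of the pattern, assuming `transformers` is installed:

from typing import cast

from transformers import AutoTokenizer, PreTrainedTokenizer

_TOKENIZER: PreTrainedTokenizer | None = None  # a concrete, annotatable type

def get_tokenizer() -> PreTrainedTokenizer:
    global _TOKENIZER
    if _TOKENIZER is None:
        # from_pretrained may build a fast or slow tokenizer; cast() only
        # narrows the static type for the checker, with no runtime check.
        _TOKENIZER = cast(
            PreTrainedTokenizer,
            AutoTokenizer.from_pretrained("distilbert-base-uncased"),
        )
    return _TOKENIZER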
Some files were not shown because too many files have changed in this diff.