Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-02-20 01:05:46 +00:00)
Compare commits: colours...projects-r (267 commits)
Commit list (267 commits, ec85100c29 through 227dfc4a05)
@@ -8,9 +8,9 @@ on:
env:
  REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-backend-cloud' || 'onyxdotapp/onyx-backend' }}
  DEPLOYMENT: ${{ contains(github.ref_name, 'cloud') && 'cloud' || 'standalone' }}

  # don't tag cloud images with "latest"
  LATEST_TAG: ${{ contains(github.ref_name, 'latest') && !contains(github.ref_name, 'cloud') }}

  # tag nightly builds with "edge"
  EDGE_TAG: ${{ startsWith(github.ref_name, 'nightly-latest') }}

jobs:
  build-and-push:
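The whole build is driven off the pushed ref name. As a rough illustration of how those expressions resolve (the ref names below are made-up examples, not real release tags), the same mapping in Python:

```python
def docker_env(ref_name: str) -> dict[str, object]:
    """Rough Python equivalent of the workflow's env expressions."""
    is_cloud = "cloud" in ref_name
    return {
        "REGISTRY_IMAGE": "onyxdotapp/onyx-backend-cloud" if is_cloud else "onyxdotapp/onyx-backend",
        "DEPLOYMENT": "cloud" if is_cloud else "standalone",
        # cloud images are never tagged "latest"
        "LATEST_TAG": "latest" in ref_name and not is_cloud,
        # nightly builds get the "edge" tag
        "EDGE_TAG": ref_name.startswith("nightly-latest"),
    }

print(docker_env("v1.2.3"))                  # standalone image, no extra tags
print(docker_env("v1.2.3-cloud"))            # cloud image, never "latest"
print(docker_env("nightly-latest-20250101")) # standalone image, "edge" tag
```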
@@ -33,7 +33,16 @@ jobs:
      run: |
        platform=${{ matrix.platform }}
        echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV

    - name: Check if stable release version
      id: check_version
      run: |
        if [[ "${{ github.ref_name }}" =~ ^v[0-9]+\.[0-9]+\.[0-9]+$ ]] && [[ "${{ github.ref_name }}" != *"cloud"* ]]; then
          echo "is_stable=true" >> $GITHUB_OUTPUT
        else
          echo "is_stable=false" >> $GITHUB_OUTPUT
        fi

    - name: Checkout code
      uses: actions/checkout@v4
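The new check_version step is what now gates the latest tag: only plain vX.Y.Z tags that do not mention "cloud" count as stable. A minimal Python sketch of the same predicate, using invented ref names:

```python
import re

SEMVER = re.compile(r"^v[0-9]+\.[0-9]+\.[0-9]+$")

def is_stable_release(ref_name: str) -> bool:
    # Same two conditions as the workflow step.
    return bool(SEMVER.match(ref_name)) and "cloud" not in ref_name

assert is_stable_release("v1.2.3")
assert not is_stable_release("v1.2.3-beta.1")   # pre-release
assert not is_stable_release("nightly-latest")  # nightly build
assert not is_stable_release("v1.2.3-cloud")    # cloud build
```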
@@ -46,7 +55,8 @@ jobs:
|
||||
latest=false
|
||||
tags: |
|
||||
type=raw,value=${{ github.ref_name }}
|
||||
type=raw,value=${{ env.LATEST_TAG == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ steps.check_version.outputs.is_stable == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ env.EDGE_TAG == 'true' && 'edge' || '' }}
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
@@ -119,7 +129,8 @@ jobs:
|
||||
latest=false
|
||||
tags: |
|
||||
type=raw,value=${{ github.ref_name }}
|
||||
type=raw,value=${{ env.LATEST_TAG == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ steps.check_version.outputs.is_stable == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ env.EDGE_TAG == 'true' && 'edge' || '' }}
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
|
||||
@@ -11,8 +11,8 @@ env:
|
||||
BUILDKIT_PROGRESS: plain
|
||||
DEPLOYMENT: ${{ contains(github.ref_name, 'cloud') && 'cloud' || 'standalone' }}
|
||||
|
||||
# don't tag cloud images with "latest"
|
||||
LATEST_TAG: ${{ contains(github.ref_name, 'latest') && !contains(github.ref_name, 'cloud') }}
|
||||
# tag nightly builds with "edge"
|
||||
EDGE_TAG: ${{ startsWith(github.ref_name, 'nightly-latest') }}
|
||||
|
||||
jobs:
|
||||
|
||||
@@ -145,6 +145,15 @@ jobs:
|
||||
if: needs.check_model_server_changes.outputs.changed == 'true'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Check if stable release version
|
||||
id: check_version
|
||||
run: |
|
||||
if [[ "${{ github.ref_name }}" =~ ^v[0-9]+\.[0-9]+\.[0-9]+$ ]] && [[ "${{ github.ref_name }}" != *"cloud"* ]]; then
|
||||
echo "is_stable=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "is_stable=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
@@ -157,11 +166,16 @@ jobs:
|
||||
docker buildx imagetools create -t ${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }} \
|
||||
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-amd64 \
|
||||
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-arm64
|
||||
if [[ "${{ env.LATEST_TAG }}" == "true" ]]; then
|
||||
if [[ "${{ steps.check_version.outputs.is_stable }}" == "true" ]]; then
|
||||
docker buildx imagetools create -t ${{ env.REGISTRY_IMAGE }}:latest \
|
||||
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-amd64 \
|
||||
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-arm64
|
||||
fi
|
||||
if [[ "${{ env.EDGE_TAG }}" == "true" ]]; then
|
||||
docker buildx imagetools create -t ${{ env.REGISTRY_IMAGE }}:edge \
|
||||
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-amd64 \
|
||||
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-arm64
|
||||
fi
|
||||
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: nick-fields/retry@v3
|
||||
|
||||
@@ -7,7 +7,10 @@ on:
|
||||
|
||||
env:
|
||||
REGISTRY_IMAGE: onyxdotapp/onyx-web-server
|
||||
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
|
||||
|
||||
# tag nightly builds with "edge"
|
||||
EDGE_TAG: ${{ startsWith(github.ref_name, 'nightly-latest') }}
|
||||
|
||||
DEPLOYMENT: standalone
|
||||
|
||||
jobs:
|
||||
@@ -45,6 +48,15 @@ jobs:
|
||||
platform=${{ matrix.platform }}
|
||||
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
|
||||
|
||||
- name: Check if stable release version
|
||||
id: check_version
|
||||
run: |
|
||||
if [[ "${{ github.ref_name }}" =~ ^v[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
|
||||
echo "is_stable=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "is_stable=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
@@ -57,7 +69,8 @@ jobs:
|
||||
latest=false
|
||||
tags: |
|
||||
type=raw,value=${{ github.ref_name }}
|
||||
type=raw,value=${{ env.LATEST_TAG == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ steps.check_version.outputs.is_stable == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ env.EDGE_TAG == 'true' && 'edge' || '' }}
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
@@ -126,7 +139,8 @@ jobs:
|
||||
latest=false
|
||||
tags: |
|
||||
type=raw,value=${{ github.ref_name }}
|
||||
type=raw,value=${{ env.LATEST_TAG == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ steps.check_version.outputs.is_stable == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ env.EDGE_TAG == 'true' && 'edge' || '' }}
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
|
||||
.github/workflows/helm-chart-releases.yml (6 changes, vendored)
@@ -25,9 +25,11 @@ jobs:
|
||||
|
||||
- name: Add required Helm repositories
|
||||
run: |
|
||||
helm repo add bitnami https://charts.bitnami.com/bitnami
|
||||
helm repo add ingress-nginx https://kubernetes.github.io/ingress-nginx
|
||||
helm repo add onyx-vespa https://onyx-dot-app.github.io/vespa-helm-charts
|
||||
helm repo add keda https://kedacore.github.io/charts
|
||||
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
|
||||
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
|
||||
helm repo add minio https://charts.min.io/
|
||||
helm repo update
|
||||
|
||||
- name: Build chart dependencies
|
||||
|
||||
.github/workflows/pr-backport-autotrigger.yml (124 changes, vendored)
@@ -1,124 +0,0 @@
|
||||
name: Backport on Merge
|
||||
|
||||
# Note this workflow does not trigger the builds, be sure to manually tag the branches to trigger the builds
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [closed] # Later we check for merge so only PRs that go in can get backported
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
actions: write
|
||||
|
||||
jobs:
|
||||
backport:
|
||||
if: github.event.pull_request.merged == true
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.YUHONG_GH_ACTIONS }}
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ssh-key: "${{ secrets.RKUO_DEPLOY_KEY }}"
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Git user
|
||||
run: |
|
||||
git config user.name "Richard Kuo [bot]"
|
||||
git config user.email "rkuo[bot]@onyx.app"
|
||||
git fetch --prune
|
||||
|
||||
- name: Check for Backport Checkbox
|
||||
id: checkbox-check
|
||||
run: |
|
||||
PR_BODY="${{ github.event.pull_request.body }}"
|
||||
if [[ "$PR_BODY" == *"[x] This PR should be backported"* ]]; then
|
||||
echo "backport=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "backport=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- name: List and sort release branches
|
||||
id: list-branches
|
||||
run: |
|
||||
git fetch --all --tags
|
||||
BRANCHES=$(git for-each-ref --format='%(refname:short)' refs/remotes/origin/release/* | sed 's|origin/release/||' | sort -Vr)
|
||||
BETA=$(echo "$BRANCHES" | head -n 1)
|
||||
STABLE=$(echo "$BRANCHES" | head -n 2 | tail -n 1)
|
||||
echo "beta=release/$BETA" >> $GITHUB_OUTPUT
|
||||
echo "stable=release/$STABLE" >> $GITHUB_OUTPUT
|
||||
# Fetch latest tags for beta and stable
|
||||
LATEST_BETA_TAG=$(git tag -l "v[0-9]*.[0-9]*.[0-9]*-beta.[0-9]*" | grep -E "^v[0-9]+\.[0-9]+\.[0-9]+-beta\.[0-9]+$" | grep -v -- "-cloud" | sort -Vr | head -n 1)
|
||||
LATEST_STABLE_TAG=$(git tag -l "v[0-9]*.[0-9]*.[0-9]*" | grep -E "^v[0-9]+\.[0-9]+\.[0-9]+$" | sort -Vr | head -n 1)
|
||||
|
||||
# Handle case where no beta tags exist
|
||||
if [[ -z "$LATEST_BETA_TAG" ]]; then
|
||||
NEW_BETA_TAG="v1.0.0-beta.1"
|
||||
else
|
||||
NEW_BETA_TAG=$(echo $LATEST_BETA_TAG | awk -F '[.-]' '{print $1 "." $2 "." $3 "-beta." ($NF+1)}')
|
||||
fi
|
||||
|
||||
# Increment latest stable tag
|
||||
NEW_STABLE_TAG=$(echo $LATEST_STABLE_TAG | awk -F '.' '{print $1 "." $2 "." ($3+1)}')
|
||||
echo "latest_beta_tag=$LATEST_BETA_TAG" >> $GITHUB_OUTPUT
|
||||
echo "latest_stable_tag=$LATEST_STABLE_TAG" >> $GITHUB_OUTPUT
|
||||
echo "new_beta_tag=$NEW_BETA_TAG" >> $GITHUB_OUTPUT
|
||||
echo "new_stable_tag=$NEW_STABLE_TAG" >> $GITHUB_OUTPUT
|
||||
|
||||
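For reference, the removed workflow computed the next tags with awk one-liners: bump the trailing -beta.N counter for the beta branch and the patch number for stable. The same arithmetic, sketched in Python with hypothetical tag values:

```python
def next_beta(tag: str | None) -> str:
    # v1.2.3-beta.4 -> v1.2.3-beta.5; no beta tags yet -> v1.0.0-beta.1
    if not tag:
        return "v1.0.0-beta.1"
    base, n = tag.rsplit(".", 1)   # "v1.2.3-beta", "4"
    return f"{base}.{int(n) + 1}"

def next_stable(tag: str) -> str:
    # v1.2.3 -> v1.2.4 (patch bump only)
    major, minor, patch = tag.split(".")
    return f"{major}.{minor}.{int(patch) + 1}"

assert next_beta("v1.2.3-beta.4") == "v1.2.3-beta.5"
assert next_beta(None) == "v1.0.0-beta.1"
assert next_stable("v1.2.3") == "v1.2.4"
```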
- name: Echo branch and tag information
|
||||
run: |
|
||||
echo "Beta branch: ${{ steps.list-branches.outputs.beta }}"
|
||||
echo "Stable branch: ${{ steps.list-branches.outputs.stable }}"
|
||||
echo "Latest beta tag: ${{ steps.list-branches.outputs.latest_beta_tag }}"
|
||||
echo "Latest stable tag: ${{ steps.list-branches.outputs.latest_stable_tag }}"
|
||||
echo "New beta tag: ${{ steps.list-branches.outputs.new_beta_tag }}"
|
||||
echo "New stable tag: ${{ steps.list-branches.outputs.new_stable_tag }}"
|
||||
|
||||
- name: Trigger Backport
|
||||
if: steps.checkbox-check.outputs.backport == 'true'
|
||||
run: |
|
||||
set -e
|
||||
echo "Backporting to beta ${{ steps.list-branches.outputs.beta }} and stable ${{ steps.list-branches.outputs.stable }}"
|
||||
|
||||
# Echo the merge commit SHA
|
||||
echo "Merge commit SHA: ${{ github.event.pull_request.merge_commit_sha }}"
|
||||
|
||||
# Fetch all history for all branches and tags
|
||||
git fetch --prune
|
||||
|
||||
# Reset and prepare the beta branch
|
||||
git checkout ${{ steps.list-branches.outputs.beta }}
|
||||
echo "Last 5 commits on beta branch:"
|
||||
git log -n 5 --pretty=format:"%H"
|
||||
echo "" # Newline for formatting
|
||||
|
||||
# Cherry-pick the merge commit from the merged PR
|
||||
git cherry-pick -m 1 ${{ github.event.pull_request.merge_commit_sha }} || {
|
||||
echo "Cherry-pick to beta failed due to conflicts."
|
||||
exit 1
|
||||
}
|
||||
|
||||
# Create new beta branch/tag
|
||||
git tag ${{ steps.list-branches.outputs.new_beta_tag }}
|
||||
# Push the changes and tag to the beta branch using PAT
|
||||
git push origin ${{ steps.list-branches.outputs.beta }}
|
||||
git push origin ${{ steps.list-branches.outputs.new_beta_tag }}
|
||||
|
||||
# Reset and prepare the stable branch
|
||||
git checkout ${{ steps.list-branches.outputs.stable }}
|
||||
echo "Last 5 commits on stable branch:"
|
||||
git log -n 5 --pretty=format:"%H"
|
||||
echo "" # Newline for formatting
|
||||
|
||||
# Cherry-pick the merge commit from the merged PR
|
||||
git cherry-pick -m 1 ${{ github.event.pull_request.merge_commit_sha }} || {
|
||||
echo "Cherry-pick to stable failed due to conflicts."
|
||||
exit 1
|
||||
}
|
||||
|
||||
# Create new stable branch/tag
|
||||
git tag ${{ steps.list-branches.outputs.new_stable_tag }}
|
||||
# Push the changes and tag to the stable branch using PAT
|
||||
git push origin ${{ steps.list-branches.outputs.stable }}
|
||||
git push origin ${{ steps.list-branches.outputs.new_stable_tag }}
|
||||
@@ -20,6 +20,7 @@ env:
|
||||
CONFLUENCE_IS_CLOUD: ${{ secrets.CONFLUENCE_IS_CLOUD }}
|
||||
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
|
||||
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
|
||||
CONFLUENCE_ACCESS_TOKEN_SCOPED: ${{ secrets.CONFLUENCE_ACCESS_TOKEN_SCOPED }}
|
||||
|
||||
# LLMs
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
|
||||
.github/workflows/pr-helm-chart-testing.yml (71 changes, vendored)
@@ -65,35 +65,45 @@ jobs:
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Adding Helm repositories ==="
|
||||
helm repo add bitnami https://charts.bitnami.com/bitnami
|
||||
helm repo add ingress-nginx https://kubernetes.github.io/ingress-nginx
|
||||
helm repo add vespa https://onyx-dot-app.github.io/vespa-helm-charts
|
||||
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
|
||||
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
|
||||
helm repo add minio https://charts.min.io/
|
||||
helm repo update
|
||||
|
||||
- name: Pre-pull critical images
|
||||
- name: Install Redis operator
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
shell: bash
|
||||
run: |
|
||||
echo "=== Installing redis-operator CRDs ==="
|
||||
helm upgrade --install redis-operator ot-container-kit/redis-operator \
|
||||
--namespace redis-operator --create-namespace --wait --timeout 300s
|
||||
|
||||
- name: Pre-pull required images
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Pre-pulling critical images to avoid timeout ==="
|
||||
# Get kind cluster name
|
||||
echo "=== Pre-pulling required images to avoid timeout ==="
|
||||
KIND_CLUSTER=$(kubectl config current-context | sed 's/kind-//')
|
||||
echo "Kind cluster: $KIND_CLUSTER"
|
||||
|
||||
# Pre-pull images that are likely to be used
|
||||
echo "Pre-pulling PostgreSQL image..."
|
||||
docker pull postgres:15-alpine || echo "Failed to pull postgres:15-alpine"
|
||||
kind load docker-image postgres:15-alpine --name $KIND_CLUSTER || echo "Failed to load postgres image"
|
||||
|
||||
echo "Pre-pulling Redis image..."
|
||||
docker pull redis:7-alpine || echo "Failed to pull redis:7-alpine"
|
||||
kind load docker-image redis:7-alpine --name $KIND_CLUSTER || echo "Failed to load redis image"
|
||||
|
||||
echo "Pre-pulling Onyx images..."
|
||||
docker pull docker.io/onyxdotapp/onyx-web-server:latest || echo "Failed to pull onyx web server"
|
||||
docker pull docker.io/onyxdotapp/onyx-backend:latest || echo "Failed to pull onyx backend"
|
||||
kind load docker-image docker.io/onyxdotapp/onyx-web-server:latest --name $KIND_CLUSTER || echo "Failed to load onyx web server"
|
||||
kind load docker-image docker.io/onyxdotapp/onyx-backend:latest --name $KIND_CLUSTER || echo "Failed to load onyx backend"
|
||||
|
||||
|
||||
IMAGES=(
|
||||
"ghcr.io/cloudnative-pg/cloudnative-pg:1.27.0"
|
||||
"quay.io/opstree/redis:v7.0.15"
|
||||
"docker.io/onyxdotapp/onyx-web-server:latest"
|
||||
)
|
||||
|
||||
for image in "${IMAGES[@]}"; do
|
||||
echo "Pre-pulling $image"
|
||||
if docker pull "$image"; then
|
||||
kind load docker-image "$image" --name "$KIND_CLUSTER" || echo "Failed to load $image into kind"
|
||||
else
|
||||
echo "Failed to pull $image"
|
||||
fi
|
||||
done
|
||||
|
||||
echo "=== Images loaded into Kind cluster ==="
|
||||
docker exec $KIND_CLUSTER-control-plane crictl images | grep -E "(postgres|redis|onyx)" || echo "Some images may still be loading..."
|
||||
docker exec "$KIND_CLUSTER"-control-plane crictl images | grep -E "(cloudnative-pg|redis|onyx)" || echo "Some images may still be loading..."
|
||||
|
||||
- name: Validate chart dependencies
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
@@ -149,6 +159,7 @@ jobs:
|
||||
|
||||
# Run the actual installation with detailed logging
|
||||
echo "=== Starting ct install ==="
|
||||
set +e
|
||||
ct install --all \
|
||||
--helm-extra-set-args="\
|
||||
--set=nginx.enabled=false \
|
||||
@@ -156,8 +167,10 @@ jobs:
|
||||
--set=vespa.enabled=false \
|
||||
--set=slackbot.enabled=false \
|
||||
--set=postgresql.enabled=true \
|
||||
--set=postgresql.primary.persistence.enabled=false \
|
||||
--set=postgresql.nameOverride=cloudnative-pg \
|
||||
--set=postgresql.cluster.storage.storageClass=standard \
|
||||
--set=redis.enabled=true \
|
||||
--set=redis.storageSpec.volumeClaimTemplate.spec.storageClassName=standard \
|
||||
--set=webserver.replicaCount=1 \
|
||||
--set=api.replicaCount=0 \
|
||||
--set=inferenceCapability.replicaCount=0 \
|
||||
@@ -173,8 +186,16 @@ jobs:
|
||||
--set=celery_worker_user_files_indexing.replicaCount=0" \
|
||||
--helm-extra-args="--timeout 900s --debug" \
|
||||
--debug --config ct.yaml
|
||||
|
||||
echo "=== Installation completed successfully ==="
|
||||
CT_EXIT=$?
|
||||
set -e
|
||||
|
||||
if [[ $CT_EXIT -ne 0 ]]; then
|
||||
echo "ct install failed with exit code $CT_EXIT"
|
||||
exit $CT_EXIT
|
||||
else
|
||||
echo "=== Installation completed successfully ==="
|
||||
fi
|
||||
|
||||
kubectl get pods --all-namespaces
|
||||
|
||||
- name: Post-install verification
|
||||
@@ -199,7 +220,7 @@ jobs:
|
||||
|
||||
echo "=== Recent logs for debugging ==="
|
||||
kubectl logs --all-namespaces --tail=50 | grep -i "error\|timeout\|failed\|pull" || echo "No error logs found"
|
||||
|
||||
|
||||
echo "=== Helm releases ==="
|
||||
helm list --all-namespaces
|
||||
# the following would install only changed charts, but we only have one chart so
|
||||
|
||||
.github/workflows/pr-integration-tests.yml (7 changes, vendored)
@@ -22,9 +22,11 @@ env:
|
||||
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
|
||||
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
|
||||
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
|
||||
CONFLUENCE_ACCESS_TOKEN_SCOPED: ${{ secrets.CONFLUENCE_ACCESS_TOKEN_SCOPED }}
|
||||
JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }}
|
||||
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
|
||||
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
|
||||
JIRA_API_TOKEN_SCOPED: ${{ secrets.JIRA_API_TOKEN_SCOPED }}
|
||||
PERM_SYNC_SHAREPOINT_CLIENT_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_CLIENT_ID }}
|
||||
PERM_SYNC_SHAREPOINT_PRIVATE_KEY: ${{ secrets.PERM_SYNC_SHAREPOINT_PRIVATE_KEY }}
|
||||
PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD: ${{ secrets.PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD }}
|
||||
@@ -131,6 +133,7 @@ jobs:
|
||||
tags: ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-backend:test-${{ github.run_id }}
|
||||
push: true
|
||||
outputs: type=registry
|
||||
no-cache: true
|
||||
|
||||
build-model-server-image:
|
||||
runs-on: blacksmith-16vcpu-ubuntu-2404-arm
|
||||
@@ -158,6 +161,7 @@ jobs:
|
||||
push: true
|
||||
outputs: type=registry
|
||||
provenance: false
|
||||
no-cache: true
|
||||
|
||||
build-integration-image:
|
||||
needs: prepare-build
|
||||
@@ -191,6 +195,7 @@ jobs:
|
||||
tags: ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-integration:test-${{ github.run_id }}
|
||||
push: true
|
||||
outputs: type=registry
|
||||
no-cache: true
|
||||
|
||||
integration-tests:
|
||||
needs:
|
||||
@@ -337,9 +342,11 @@ jobs:
|
||||
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
|
||||
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
|
||||
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
|
||||
-e CONFLUENCE_ACCESS_TOKEN_SCOPED=${CONFLUENCE_ACCESS_TOKEN_SCOPED} \
|
||||
-e JIRA_BASE_URL=${JIRA_BASE_URL} \
|
||||
-e JIRA_USER_EMAIL=${JIRA_USER_EMAIL} \
|
||||
-e JIRA_API_TOKEN=${JIRA_API_TOKEN} \
|
||||
-e JIRA_API_TOKEN_SCOPED=${JIRA_API_TOKEN_SCOPED} \
|
||||
-e PERM_SYNC_SHAREPOINT_CLIENT_ID=${PERM_SYNC_SHAREPOINT_CLIENT_ID} \
|
||||
-e PERM_SYNC_SHAREPOINT_PRIVATE_KEY="${PERM_SYNC_SHAREPOINT_PRIVATE_KEY}" \
|
||||
-e PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD=${PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD} \
|
||||
|
||||
@@ -19,9 +19,11 @@ env:
|
||||
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
|
||||
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
|
||||
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
|
||||
CONFLUENCE_ACCESS_TOKEN_SCOPED: ${{ secrets.CONFLUENCE_ACCESS_TOKEN_SCOPED }}
|
||||
JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }}
|
||||
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
|
||||
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
|
||||
JIRA_API_TOKEN_SCOPED: ${{ secrets.JIRA_API_TOKEN_SCOPED }}
|
||||
PERM_SYNC_SHAREPOINT_CLIENT_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_CLIENT_ID }}
|
||||
PERM_SYNC_SHAREPOINT_PRIVATE_KEY: ${{ secrets.PERM_SYNC_SHAREPOINT_PRIVATE_KEY }}
|
||||
PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD: ${{ secrets.PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD }}
|
||||
@@ -128,6 +130,7 @@ jobs:
|
||||
tags: ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-backend:test-${{ github.run_id }}
|
||||
push: true
|
||||
outputs: type=registry
|
||||
no-cache: true
|
||||
|
||||
build-model-server-image:
|
||||
runs-on: blacksmith-16vcpu-ubuntu-2404-arm
|
||||
@@ -155,6 +158,7 @@ jobs:
|
||||
push: true
|
||||
outputs: type=registry
|
||||
provenance: false
|
||||
no-cache: true
|
||||
|
||||
build-integration-image:
|
||||
needs: prepare-build
|
||||
@@ -188,6 +192,7 @@ jobs:
|
||||
tags: ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-integration:test-${{ github.run_id }}
|
||||
push: true
|
||||
outputs: type=registry
|
||||
no-cache: true
|
||||
|
||||
integration-tests-mit:
|
||||
needs:
|
||||
@@ -334,9 +339,11 @@ jobs:
|
||||
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
|
||||
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
|
||||
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
|
||||
-e CONFLUENCE_ACCESS_TOKEN_SCOPED=${CONFLUENCE_ACCESS_TOKEN_SCOPED} \
|
||||
-e JIRA_BASE_URL=${JIRA_BASE_URL} \
|
||||
-e JIRA_USER_EMAIL=${JIRA_USER_EMAIL} \
|
||||
-e JIRA_API_TOKEN=${JIRA_API_TOKEN} \
|
||||
-e JIRA_API_TOKEN_SCOPED=${JIRA_API_TOKEN_SCOPED} \
|
||||
-e PERM_SYNC_SHAREPOINT_CLIENT_ID=${PERM_SYNC_SHAREPOINT_CLIENT_ID} \
|
||||
-e PERM_SYNC_SHAREPOINT_PRIVATE_KEY="${PERM_SYNC_SHAREPOINT_PRIVATE_KEY}" \
|
||||
-e PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD=${PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD} \
|
||||
|
||||
.github/workflows/pr-playwright-tests.yml (35 changes, vendored)
@@ -56,6 +56,8 @@ jobs:
|
||||
provenance: false
|
||||
sbom: false
|
||||
push: true
|
||||
outputs: type=registry
|
||||
# no-cache: true
|
||||
|
||||
build-backend-image:
|
||||
runs-on: blacksmith-8vcpu-ubuntu-2404-arm
|
||||
@@ -87,6 +89,8 @@ jobs:
|
||||
provenance: false
|
||||
sbom: false
|
||||
push: true
|
||||
outputs: type=registry
|
||||
# no-cache: true
|
||||
|
||||
build-model-server-image:
|
||||
runs-on: blacksmith-8vcpu-ubuntu-2404-arm
|
||||
@@ -118,6 +122,8 @@ jobs:
|
||||
provenance: false
|
||||
sbom: false
|
||||
push: true
|
||||
outputs: type=registry
|
||||
# no-cache: true
|
||||
|
||||
playwright-tests:
|
||||
needs: [build-web-image, build-backend-image, build-model-server-image]
|
||||
@@ -179,16 +185,21 @@ jobs:
|
||||
working-directory: ./web
|
||||
run: npx playwright install --with-deps
|
||||
|
||||
- name: Create .env file for Docker Compose
|
||||
run: |
|
||||
cat <<EOF > deployment/docker_compose/.env
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
|
||||
AUTH_TYPE=basic
|
||||
GEN_AI_API_KEY=${{ env.OPENAI_API_KEY }}
|
||||
EXA_API_KEY=${{ env.EXA_API_KEY }}
|
||||
REQUIRE_EMAIL_VERIFICATION=false
|
||||
DISABLE_TELEMETRY=true
|
||||
IMAGE_TAG=test
|
||||
EOF
|
||||
|
||||
- name: Start Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
|
||||
AUTH_TYPE=basic \
|
||||
GEN_AI_API_KEY=${{ env.OPENAI_API_KEY }} \
|
||||
EXA_API_KEY=${{ env.EXA_API_KEY }} \
|
||||
REQUIRE_EMAIL_VERIFICATION=false \
|
||||
DISABLE_TELEMETRY=true \
|
||||
IMAGE_TAG=test \
|
||||
docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d
|
||||
id: start_docker
|
||||
|
||||
@@ -228,14 +239,16 @@ jobs:
|
||||
|
||||
- name: Run Playwright tests
|
||||
working-directory: ./web
|
||||
run: npx playwright test
|
||||
run: |
|
||||
# Create test-results directory to ensure it exists for artifact upload
|
||||
mkdir -p test-results
|
||||
npx playwright test
|
||||
|
||||
- uses: actions/upload-artifact@v4
|
||||
if: always()
|
||||
with:
|
||||
# Chromatic automatically defaults to the test-results directory.
|
||||
# Replace with the path to your custom directory and adjust the CHROMATIC_ARCHIVE_LOCATION environment variable accordingly.
|
||||
name: test-results
|
||||
# Includes test results and debug screenshots
|
||||
name: playwright-test-results-${{ github.run_id }}
|
||||
path: ./web/test-results
|
||||
retention-days: 30
|
||||
|
||||
|
||||
@@ -20,11 +20,13 @@ env:
|
||||
CONFLUENCE_IS_CLOUD: ${{ secrets.CONFLUENCE_IS_CLOUD }}
|
||||
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
|
||||
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
|
||||
CONFLUENCE_ACCESS_TOKEN_SCOPED: ${{ secrets.CONFLUENCE_ACCESS_TOKEN_SCOPED }}
|
||||
|
||||
# Jira
|
||||
JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }}
|
||||
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
|
||||
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
|
||||
JIRA_API_TOKEN_SCOPED: ${{ secrets.JIRA_API_TOKEN_SCOPED }}
|
||||
|
||||
# Gong
|
||||
GONG_ACCESS_KEY: ${{ secrets.GONG_ACCESS_KEY }}
|
||||
|
||||
.vscode/launch.template.jsonc (505 changes, vendored)
@@ -13,6 +13,50 @@
|
||||
"presentation": {
|
||||
"group": "1"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Run All Onyx Services",
|
||||
"configurations": [
|
||||
"Web Server",
|
||||
"Model Server",
|
||||
"API Server",
|
||||
"Slack Bot",
|
||||
"Celery primary",
|
||||
"Celery light",
|
||||
"Celery heavy",
|
||||
"Celery docfetching",
|
||||
"Celery docprocessing",
|
||||
"Celery beat",
|
||||
"Celery monitoring",
|
||||
"Celery user file processing"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "1"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Web / Model / API",
|
||||
"configurations": ["Web Server", "Model Server", "API Server"],
|
||||
"presentation": {
|
||||
"group": "1"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Celery (all)",
|
||||
"configurations": [
|
||||
"Celery primary",
|
||||
"Celery light",
|
||||
"Celery heavy",
|
||||
"Celery docfetching",
|
||||
"Celery docprocessing",
|
||||
"Celery beat",
|
||||
"Celery monitoring",
|
||||
"Celery user file processing"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "1"
|
||||
},
|
||||
"stopAll": true
|
||||
}
|
||||
],
|
||||
"configurations": [
|
||||
@@ -214,250 +258,6 @@
|
||||
"consoleTitle": "Celery docfetching Console",
|
||||
"justMyCode": false
|
||||
},
|
||||
{
|
||||
"name": "Run All Onyx Services",
|
||||
"configurations": [
|
||||
"Web Server",
|
||||
"Model Server",
|
||||
"API Server",
|
||||
"Slack Bot",
|
||||
"Celery primary",
|
||||
"Celery light",
|
||||
"Celery heavy",
|
||||
"Celery docfetching",
|
||||
"Celery docprocessing",
|
||||
"Celery beat",
|
||||
"Celery monitoring",
|
||||
"Celery user file processing"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "1"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Web / Model / API",
|
||||
"configurations": ["Web Server", "Model Server", "API Server"],
|
||||
"presentation": {
|
||||
"group": "1"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Celery (all)",
|
||||
"configurations": [
|
||||
"Celery primary",
|
||||
"Celery light",
|
||||
"Celery heavy",
|
||||
"Celery docfetching",
|
||||
"Celery docprocessing",
|
||||
"Celery beat",
|
||||
"Celery monitoring",
|
||||
"Celery user file processing"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "1"
|
||||
},
|
||||
"stopAll": true
|
||||
}
|
||||
],
|
||||
"configurations": [
|
||||
{
|
||||
// Dummy entry used to label the group
|
||||
"name": "--- Individual ---",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
"order": 0
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Web Server",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"cwd": "${workspaceRoot}/web",
|
||||
"runtimeExecutable": "npm",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"runtimeArgs": ["run", "dev"],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"console": "integratedTerminal",
|
||||
"consoleTitle": "Web Server Console"
|
||||
},
|
||||
{
|
||||
"name": "Model Server",
|
||||
"consoleName": "Model Server",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "uvicorn",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1"
|
||||
},
|
||||
"args": ["model_server.main:app", "--reload", "--port", "9000"],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Model Server Console"
|
||||
},
|
||||
{
|
||||
"name": "API Server",
|
||||
"consoleName": "API Server",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "uvicorn",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1"
|
||||
},
|
||||
"args": ["onyx.main:app", "--reload", "--port", "8080"],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "API Server Console"
|
||||
},
|
||||
// For the listener to access the Slack API,
|
||||
// DANSWER_BOT_SLACK_APP_TOKEN & DANSWER_BOT_SLACK_BOT_TOKEN need to be set in .env file located in the root of the project
|
||||
{
|
||||
"name": "Slack Bot",
|
||||
"consoleName": "Slack Bot",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "onyx/onyxbot/slack/listener.py",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Slack Bot Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery primary",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "INFO",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.primary",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=4",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=primary@%n",
|
||||
"-Q",
|
||||
"celery"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery primary Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery light",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "INFO",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.light",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=64",
|
||||
"--prefetch-multiplier=8",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=light@%n",
|
||||
"-Q",
|
||||
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert,index_attempt_cleanup"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery light Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery heavy",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "INFO",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.heavy",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=4",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=heavy@%n",
|
||||
"-Q",
|
||||
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery heavy Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery docfetching",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.docfetching",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=1",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=docfetching@%n",
|
||||
"-Q",
|
||||
"connector_doc_fetching,user_files_indexing"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery docfetching Console",
|
||||
"justMyCode": false
|
||||
},
|
||||
{
|
||||
"name": "Celery docprocessing",
|
||||
"type": "debugpy",
|
||||
@@ -486,8 +286,83 @@
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery docprocessing Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery beat",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.beat",
|
||||
"beat",
|
||||
"--loglevel=INFO"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery beat Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery monitoring",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.monitoring",
|
||||
"worker",
|
||||
"--pool=solo",
|
||||
"--concurrency=1",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=monitoring@%n",
|
||||
"-Q",
|
||||
"monitoring"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery monitoring Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery user file processing",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.user_file_processing",
|
||||
"worker",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=user_file_processing@%n",
|
||||
"--pool=threads",
|
||||
"-Q",
|
||||
"user_file_processing,user_file_project_sync"
|
||||
],
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery user file processing Console"
|
||||
},
|
||||
{
|
||||
"name": "Pytest",
|
||||
"consoleName": "Pytest",
|
||||
@@ -503,8 +378,8 @@
|
||||
},
|
||||
"args": [
|
||||
"-v"
|
||||
// Specify a sepcific module/test to run or provide nothing to run all tests
|
||||
//"tests/unit/onyx/llm/answering/test_prune_and_merge.py"
|
||||
// Specify a specific module/test to run or provide nothing to run all tests
|
||||
// "tests/unit/onyx/llm/answering/test_prune_and_merge.py"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
@@ -588,144 +463,6 @@
|
||||
"group": "3"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Celery monitoring",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.monitoring",
|
||||
"worker",
|
||||
"--pool=solo",
|
||||
"--concurrency=1",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=monitoring@%n",
|
||||
"-Q",
|
||||
"monitoring"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery monitoring Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery beat",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.beat",
|
||||
"beat",
|
||||
"--loglevel=INFO"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery beat Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery user file processing",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.user_file_processing",
|
||||
"worker",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=user_file_processing@%n",
|
||||
"--pool=threads",
|
||||
"-Q",
|
||||
"user_file_processing,user_file_project_sync"
|
||||
],
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery user file processing Console"
|
||||
},
|
||||
{
|
||||
"name": "Pytest",
|
||||
"consoleName": "Pytest",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "pytest",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-v"
|
||||
// Specify a sepcific module/test to run or provide nothing to run all tests
|
||||
//"tests/unit/onyx/llm/answering/test_prune_and_merge.py"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Pytest Console"
|
||||
},
|
||||
{
|
||||
// Dummy entry used to label the group
|
||||
"name": "--- Tasks ---",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"presentation": {
|
||||
"group": "3",
|
||||
"order": 0
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Clear and Restart External Volumes and Containers",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"runtimeExecutable": "bash",
|
||||
"runtimeArgs": [
|
||||
"${workspaceFolder}/backend/scripts/restart_containers.sh"
|
||||
],
|
||||
"cwd": "${workspaceFolder}",
|
||||
"console": "integratedTerminal",
|
||||
"stopOnEntry": true,
|
||||
"presentation": {
|
||||
"group": "3"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Install Python Requirements",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"runtimeExecutable": "bash",
|
||||
"runtimeArgs": [
|
||||
"-c",
|
||||
"pip install -r backend/requirements/default.txt && pip install -r backend/requirements/dev.txt && pip install -r backend/requirements/ee.txt && pip install -r backend/requirements/model_server.txt"
|
||||
],
|
||||
"cwd": "${workspaceFolder}",
|
||||
"console": "integratedTerminal",
|
||||
"presentation": {
|
||||
"group": "3"
|
||||
}
|
||||
},
|
||||
{
|
||||
// script to generate the openapi schema
|
||||
"name": "Onyx OpenAPI Schema Generator",
|
||||
|
||||
@@ -105,6 +105,11 @@ pip install -r backend/requirements/ee.txt
|
||||
pip install -r backend/requirements/model_server.txt
|
||||
```
|
||||
|
||||
Fix vscode/cursor auto-imports:
|
||||
```bash
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
Install Playwright for Python (headless browser required by the Web Connector)
|
||||
|
||||
In the activated Python virtualenv, install Playwright for Python by running:
|
||||
|
||||
@@ -128,7 +128,7 @@ def upgrade() -> None:
|
||||
AND a.attname = 'cc_pair_id'
|
||||
)
|
||||
) LOOP
|
||||
EXECUTE format('ALTER TABLE user_file DROP CONSTRAINT %I', r.conname);
|
||||
EXECUTE format('ALTER TABLE user_file DROP CONSTRAINT IF EXISTS %I', r.conname);
|
||||
END LOOP;
|
||||
END$$;
|
||||
"""
|
||||
|
||||
@@ -167,7 +167,10 @@ def upgrade() -> None:
|
||||
)
|
||||
|
||||
# Delete related records
|
||||
# Clean child tables first to satisfy foreign key constraints,
|
||||
# then the parent tables
|
||||
tables_to_clean = [
|
||||
("index_attempt_errors", "connector_credential_pair_id"),
|
||||
("index_attempt", "connector_credential_pair_id"),
|
||||
("background_error", "cc_pair_id"),
|
||||
("document_set__connector_credential_pair", "connector_credential_pair_id"),
|
||||
@@ -242,7 +245,7 @@ def upgrade() -> None:
|
||||
AND t.relname = 'user_file'
|
||||
AND ft.relname = 'connector_credential_pair'
|
||||
) LOOP
|
||||
EXECUTE format('ALTER TABLE user_file DROP CONSTRAINT %I', r.conname);
|
||||
EXECUTE format('ALTER TABLE user_file DROP CONSTRAINT IF EXISTS %I', r.conname);
|
||||
END LOOP;
|
||||
END$$;
|
||||
"""
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
"""Add image input support to model config
|
||||
|
||||
Revision ID: 64bd5677aeb6
|
||||
Revises: b30353be4eec
|
||||
Create Date: 2025-09-28 15:48:12.003612
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "64bd5677aeb6"
|
||||
down_revision = "b30353be4eec"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"model_configuration",
|
||||
sa.Column("supports_image_input", sa.Boolean(), nullable=True),
|
||||
)
|
||||
|
||||
# Seems to be left over from when model visibility was introduced and a nullable field.
|
||||
# Set any null is_visible values to False
|
||||
connection = op.get_bind()
|
||||
connection.execute(
|
||||
sa.text(
|
||||
"UPDATE model_configuration SET is_visible = false WHERE is_visible IS NULL"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("model_configuration", "supports_image_input")
|
||||
backend/alembic/versions/b30353be4eec_add_mcp_auth_performer.py (123 lines, new file)
@@ -0,0 +1,123 @@
|
||||
"""add_mcp_auth_performer
|
||||
|
||||
Revision ID: b30353be4eec
|
||||
Revises: 2b75d0a8ffcb
|
||||
Create Date: 2025-09-13 14:58:08.413534
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from onyx.db.enums import MCPAuthenticationPerformer, MCPTransport
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "b30353be4eec"
|
||||
down_revision = "2b75d0a8ffcb"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""moving to a better way of handling auth performer and transport"""
|
||||
# Add nullable column first for backward compatibility
|
||||
op.add_column(
|
||||
"mcp_server",
|
||||
sa.Column(
|
||||
"auth_performer",
|
||||
sa.Enum(MCPAuthenticationPerformer, native_enum=False),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
|
||||
op.add_column(
|
||||
"mcp_server",
|
||||
sa.Column(
|
||||
"transport",
|
||||
sa.Enum(MCPTransport, native_enum=False),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
|
||||
# # Backfill values using existing data and inference rules
|
||||
bind = op.get_bind()
|
||||
|
||||
# 1) OAUTH servers are always PER_USER
|
||||
bind.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE mcp_server
|
||||
SET auth_performer = 'PER_USER'
|
||||
WHERE auth_type = 'OAUTH'
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# 2) If there is no admin connection config, mark as ADMIN (and not set yet)
|
||||
bind.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE mcp_server
|
||||
SET auth_performer = 'ADMIN'
|
||||
WHERE admin_connection_config_id IS NULL
|
||||
AND auth_performer IS NULL
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# 3) If there exists any user-specific connection config (user_email != ''), mark as PER_USER
|
||||
bind.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE mcp_server AS ms
|
||||
SET auth_performer = 'PER_USER'
|
||||
FROM mcp_connection_config AS mcc
|
||||
WHERE mcc.mcp_server_id = ms.id
|
||||
AND COALESCE(mcc.user_email, '') <> ''
|
||||
AND ms.auth_performer IS NULL
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# 4) Default any remaining nulls to ADMIN (covers API_TOKEN admin-managed and NONE)
|
||||
bind.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE mcp_server
|
||||
SET auth_performer = 'ADMIN'
|
||||
WHERE auth_performer IS NULL
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# Finally, make the column non-nullable
|
||||
op.alter_column(
|
||||
"mcp_server",
|
||||
"auth_performer",
|
||||
existing_type=sa.Enum(MCPAuthenticationPerformer, native_enum=False),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Backfill transport for existing rows to STREAMABLE_HTTP, then make non-nullable
|
||||
bind.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE mcp_server
|
||||
SET transport = 'STREAMABLE_HTTP'
|
||||
WHERE transport IS NULL
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
op.alter_column(
|
||||
"mcp_server",
|
||||
"transport",
|
||||
existing_type=sa.Enum(MCPTransport, native_enum=False),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""remove cols"""
|
||||
op.drop_column("mcp_server", "transport")
|
||||
op.drop_column("mcp_server", "auth_performer")
|
||||
@@ -124,9 +124,9 @@ def get_space_permission(
|
||||
and not space_permissions.external_user_group_ids
|
||||
):
|
||||
logger.warning(
|
||||
f"No permissions found for space '{space_key}'. This is very unlikely"
|
||||
"to be correct and is more likely caused by an access token with"
|
||||
"insufficient permissions. Make sure that the access token has Admin"
|
||||
f"No permissions found for space '{space_key}'. This is very unlikely "
|
||||
"to be correct and is more likely caused by an access token with "
|
||||
"insufficient permissions. Make sure that the access token has Admin "
|
||||
f"permissions for space '{space_key}'"
|
||||
)
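The underlying bug is Python's implicit string-literal concatenation: without a trailing space on each fragment, adjacent words run together ("unlikelyto", "withinsufficient"). A minimal illustration:

```python
# Adjacent string literals are concatenated with no separator.
broken = (
    "No permissions found. This is very unlikely"
    "to be correct."
)
fixed = (
    "No permissions found. This is very unlikely "
    "to be correct."
)
assert broken == "No permissions found. This is very unlikelyto be correct."
assert fixed == "No permissions found. This is very unlikely to be correct."
```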
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ def _get_slim_doc_generator(
|
||||
else 0.0
|
||||
)
|
||||
|
||||
return gmail_connector.retrieve_all_slim_documents(
|
||||
return gmail_connector.retrieve_all_slim_docs_perm_sync(
|
||||
start=start_time,
|
||||
end=current_time.timestamp(),
|
||||
callback=callback,
|
||||
|
||||
@@ -34,7 +34,7 @@ def _get_slim_doc_generator(
|
||||
else 0.0
|
||||
)
|
||||
|
||||
return google_drive_connector.retrieve_all_slim_documents(
|
||||
return google_drive_connector.retrieve_all_slim_docs_perm_sync(
|
||||
start=start_time,
|
||||
end=current_time.timestamp(),
|
||||
callback=callback,
|
||||
|
||||
@@ -59,7 +59,7 @@ def _build_holder_map(permissions: list[dict]) -> dict[str, list[Holder]]:
|
||||
|
||||
for raw_perm in permissions:
|
||||
if not hasattr(raw_perm, "raw"):
|
||||
logger.warn(f"Expected a 'raw' field, but none was found: {raw_perm=}")
|
||||
logger.warning(f"Expected a 'raw' field, but none was found: {raw_perm=}")
|
||||
continue
|
||||
|
||||
permission = Permission(**raw_perm.raw)
|
||||
@@ -71,14 +71,14 @@ def _build_holder_map(permissions: list[dict]) -> dict[str, list[Holder]]:
|
||||
# In order to associate this permission to some Atlassian entity, we need the "Holder".
|
||||
# If this doesn't exist, then we cannot associate this permission to anyone; just skip.
|
||||
if not permission.holder:
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
f"Expected to find a permission holder, but none was found: {permission=}"
|
||||
)
|
||||
continue
|
||||
|
||||
type = permission.holder.get("type")
|
||||
if not type:
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
f"Expected to find the type of permission holder, but none was found: {permission=}"
|
||||
)
|
||||
continue
|
||||
|
||||
@@ -105,7 +105,9 @@ def _get_slack_document_access(
|
||||
channel_permissions: dict[str, ExternalAccess],
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
) -> Generator[DocExternalAccess, None, None]:
|
||||
slim_doc_generator = slack_connector.retrieve_all_slim_documents(callback=callback)
|
||||
slim_doc_generator = slack_connector.retrieve_all_slim_docs_perm_sync(
|
||||
callback=callback
|
||||
)
|
||||
|
||||
for doc_metadata_batch in slim_doc_generator:
|
||||
for doc_metadata in doc_metadata_batch:
|
||||
|
||||
@@ -4,7 +4,7 @@ from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFun
|
||||
from onyx.access.models import DocExternalAccess
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -17,7 +17,7 @@ def generic_doc_sync(
|
||||
fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction,
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
doc_source: DocumentSource,
|
||||
slim_connector: SlimConnector,
|
||||
slim_connector: SlimConnectorWithPermSync,
|
||||
label: str,
|
||||
) -> Generator[DocExternalAccess, None, None]:
|
||||
"""
|
||||
@@ -40,7 +40,7 @@ def generic_doc_sync(
|
||||
newly_fetched_doc_ids: set[str] = set()
|
||||
|
||||
logger.info(f"Fetching all slim documents from {doc_source}")
|
||||
for doc_batch in slim_connector.retrieve_all_slim_documents(callback=callback):
|
||||
for doc_batch in slim_connector.retrieve_all_slim_docs_perm_sync(callback=callback):
|
||||
logger.info(f"Got {len(doc_batch)} slim documents from {doc_source}")
|
||||
|
||||
if callback:
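generic_doc_sync now expects a SlimConnectorWithPermSync and calls retrieve_all_slim_docs_perm_sync. Judging only from the call sites in this diff, the interface looks roughly like the Protocol below; the parameter defaults and the batch element type are assumptions, not copied from the repository:

```python
from collections.abc import Generator
from typing import Any, Protocol


class SlimConnectorWithPermSync(Protocol):
    """Assumed shape of the renamed slim-connector interface."""

    def retrieve_all_slim_docs_perm_sync(
        self,
        start: float | None = None,
        end: float | None = None,
        callback: Any | None = None,  # IndexingHeartbeatInterface in the real code
    ) -> Generator[list[Any], None, None]:  # yields batches of slim documents
        ...
```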
|
||||
|
||||
@@ -16,6 +16,7 @@ EE_PUBLIC_ENDPOINT_SPECS = PUBLIC_ENDPOINT_SPECS + [
|
||||
# saml
|
||||
("/auth/saml/authorize", {"GET"}),
|
||||
("/auth/saml/callback", {"POST"}),
|
||||
("/auth/saml/callback", {"GET"}),
|
||||
("/auth/saml/logout", {"POST"}),
|
||||
]
|
||||
|
||||
|
||||
@@ -110,7 +110,6 @@ async def upsert_saml_user(email: str) -> User:
|
||||
|
||||
|
||||
async def prepare_from_fastapi_request(request: Request) -> dict[str, Any]:
|
||||
form_data = await request.form()
|
||||
if request.client is None:
|
||||
raise ValueError("Invalid request for SAML")
|
||||
|
||||
@@ -125,14 +124,27 @@ async def prepare_from_fastapi_request(request: Request) -> dict[str, Any]:
|
||||
"post_data": {},
|
||||
"get_data": {},
|
||||
}
|
||||
|
||||
# Handle query parameters (for GET requests)
|
||||
if request.query_params:
|
||||
rv["get_data"] = (request.query_params,)
|
||||
if "SAMLResponse" in form_data:
|
||||
SAMLResponse = form_data["SAMLResponse"]
|
||||
rv["post_data"]["SAMLResponse"] = SAMLResponse
|
||||
if "RelayState" in form_data:
|
||||
RelayState = form_data["RelayState"]
|
||||
rv["post_data"]["RelayState"] = RelayState
|
||||
rv["get_data"] = dict(request.query_params)
|
||||
|
||||
# Handle form data (for POST requests)
|
||||
if request.method == "POST":
|
||||
form_data = await request.form()
|
||||
if "SAMLResponse" in form_data:
|
||||
SAMLResponse = form_data["SAMLResponse"]
|
||||
rv["post_data"]["SAMLResponse"] = SAMLResponse
|
||||
if "RelayState" in form_data:
|
||||
RelayState = form_data["RelayState"]
|
||||
rv["post_data"]["RelayState"] = RelayState
|
||||
else:
|
||||
# For GET requests, check if SAMLResponse is in query params
|
||||
if "SAMLResponse" in request.query_params:
|
||||
rv["get_data"]["SAMLResponse"] = request.query_params["SAMLResponse"]
|
||||
if "RelayState" in request.query_params:
|
||||
rv["get_data"]["RelayState"] = request.query_params["RelayState"]
|
||||
|
||||
return rv
|
||||
|
||||
|
||||
@@ -148,10 +160,27 @@ async def saml_login(request: Request) -> SAMLAuthorizeResponse:
|
||||
return SAMLAuthorizeResponse(authorization_url=callback_url)
|
||||
|
||||
|
||||
@router.get("/callback")
|
||||
async def saml_login_callback_get(
|
||||
request: Request,
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> Response:
|
||||
"""Handle SAML callback via HTTP-Redirect binding (GET request)"""
|
||||
return await _process_saml_callback(request, db_session)
|
||||
|
||||
|
||||
@router.post("/callback")
|
||||
async def saml_login_callback(
|
||||
request: Request,
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> Response:
|
||||
"""Handle SAML callback via HTTP-POST binding (POST request)"""
|
||||
return await _process_saml_callback(request, db_session)
|
||||
|
||||
|
||||
async def _process_saml_callback(
|
||||
request: Request,
|
||||
db_session: Session,
|
||||
) -> Response:
|
||||
req = await prepare_from_fastapi_request(request)
|
||||
auth = OneLogin_Saml2_Auth(req, custom_base_path=SAML_CONF_DIR)
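A hypothetical way to exercise the new HTTP-Redirect binding locally; the host, port, and payload below are illustrative, only the /auth/saml/callback path comes from this change:

```python
import requests

# An IdP-initiated redirect now lands on GET /auth/saml/callback with the
# SAMLResponse and RelayState carried in the query string instead of form data.
resp = requests.get(
    "http://localhost:8080/auth/saml/callback",  # assumed local API server address
    params={"SAMLResponse": "<base64-encoded-assertion>", "RelayState": "/chat"},
    timeout=10,
)
print(resp.status_code)
```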
@@ -6,7 +6,6 @@ from typing import Optional
|
||||
from fastapi import APIRouter
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
from litellm.exceptions import RateLimitError
|
||||
from sentence_transformers import CrossEncoder # type: ignore
|
||||
from sentence_transformers import SentenceTransformer # type: ignore
|
||||
|
||||
@@ -207,6 +206,8 @@ async def route_bi_encoder_embed(
|
||||
async def process_embed_request(
|
||||
embed_request: EmbedRequest, gpu_type: str = "UNKNOWN"
|
||||
) -> EmbedResponse:
|
||||
from litellm.exceptions import RateLimitError
|
||||
|
||||
# Only local models should use this endpoint - API providers should make direct API calls
|
||||
if embed_request.provider_type is not None:
|
||||
raise ValueError(
|
||||
|
||||
@@ -24,6 +24,8 @@ def decision_router(state: MainState) -> list[Send | Hashable] | DRPath | str:
|
||||
return END
|
||||
elif next_tool_name == DRPath.LOGGER.value:
|
||||
return DRPath.LOGGER
|
||||
elif next_tool_name == DRPath.CLOSER.value:
|
||||
return DRPath.CLOSER
|
||||
else:
|
||||
return DRPath.ORCHESTRATOR
|
||||
|
||||
|
||||
@@ -643,9 +643,16 @@ def clarifier(
|
||||
datetime_aware=True,
|
||||
)
|
||||
|
||||
system_prompt_to_use = build_citations_system_message(
|
||||
system_prompt_to_use_content = build_citations_system_message(
|
||||
prompt_config
|
||||
).content
|
||||
system_prompt_to_use: str = cast(str, system_prompt_to_use_content)
|
||||
if graph_config.inputs.project_instructions:
|
||||
system_prompt_to_use = (
|
||||
system_prompt_to_use
|
||||
+ PROJECT_INSTRUCTIONS_SEPARATOR
|
||||
+ graph_config.inputs.project_instructions
|
||||
)
|
||||
user_prompt_to_use = build_citations_user_message(
|
||||
user_query=original_question,
|
||||
files=[],
|
||||
|
||||
@@ -181,6 +181,15 @@ def orchestrator(
|
||||
remaining_time_budget = DR_TIME_BUDGET_BY_TYPE[research_type]
|
||||
|
||||
elif remaining_time_budget <= 0:
|
||||
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
SectionEnd(),
|
||||
writer,
|
||||
)
|
||||
|
||||
current_step_nr += 1
|
||||
|
||||
return OrchestrationUpdate(
|
||||
tools_used=[DRPath.CLOSER.value],
|
||||
current_step_nr=current_step_nr,
|
||||
|
||||
@@ -0,0 +1,147 @@
|
||||
import json
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import requests
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
InternetContent,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
InternetSearchProvider,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
InternetSearchResult,
|
||||
)
|
||||
from onyx.configs.chat_configs import SERPER_API_KEY
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
|
||||
from onyx.utils.retry_wrapper import retry_builder
|
||||
|
||||
SERPER_SEARCH_URL = "https://google.serper.dev/search"
|
||||
SERPER_CONTENTS_URL = "https://scrape.serper.dev"
|
||||
|
||||
|
||||
class SerperClient(InternetSearchProvider):
|
||||
def __init__(self, api_key: str | None = SERPER_API_KEY) -> None:
|
||||
self.headers = {
|
||||
"X-API-KEY": api_key,
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
@retry_builder(tries=3, delay=1, backoff=2)
|
||||
def search(self, query: str) -> list[InternetSearchResult]:
|
||||
payload = {
|
||||
"q": query,
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
SERPER_SEARCH_URL,
|
||||
headers=self.headers,
|
||||
data=json.dumps(payload),
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
results = response.json()
|
||||
organic_results = results["organic"]
|
||||
|
||||
return [
|
||||
InternetSearchResult(
|
||||
title=result["title"],
|
||||
link=result["link"],
|
||||
snippet=result["snippet"],
|
||||
author=None,
|
||||
published_date=None,
|
||||
)
|
||||
for result in organic_results
|
||||
]
|
||||
|
||||
def contents(self, urls: list[str]) -> list[InternetContent]:
|
||||
if not urls:
|
||||
return []
|
||||
|
||||
# Serper can respond with 500s regularly. We want to retry,
# but in the event of failure, return an unsuccessful scrape.
|
||||
def safe_get_webpage_content(url: str) -> InternetContent:
|
||||
try:
|
||||
return self._get_webpage_content(url)
|
||||
except Exception:
|
||||
return InternetContent(
|
||||
title="",
|
||||
link=url,
|
||||
full_content="",
|
||||
published_date=None,
|
||||
scrape_successful=False,
|
||||
)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=min(8, len(urls))) as e:
|
||||
return list(e.map(safe_get_webpage_content, urls))
|
||||
|
||||
@retry_builder(tries=3, delay=1, backoff=2)
|
||||
def _get_webpage_content(self, url: str) -> InternetContent:
|
||||
payload = {
|
||||
"url": url,
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
SERPER_CONTENTS_URL,
|
||||
headers=self.headers,
|
||||
data=json.dumps(payload),
|
||||
)
|
||||
|
||||
# 400 returned when serper cannot scrape
|
||||
if response.status_code == 400:
|
||||
return InternetContent(
|
||||
title="",
|
||||
link=url,
|
||||
full_content="",
|
||||
published_date=None,
|
||||
scrape_successful=False,
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
response_json = response.json()
|
||||
|
||||
# Response only guarantees text
|
||||
text = response_json["text"]
|
||||
|
||||
# metadata & jsonld is not guaranteed to be present
|
||||
metadata = response_json.get("metadata", {})
|
||||
jsonld = response_json.get("jsonld", {})
|
||||
|
||||
title = extract_title_from_metadata(metadata)
|
||||
|
||||
# Serper does not provide a reliable mechanism to extract the url
|
||||
response_url = url
|
||||
published_date_str = extract_published_date_from_jsonld(jsonld)
|
||||
published_date = None
|
||||
|
||||
if published_date_str:
|
||||
try:
|
||||
published_date = time_str_to_utc(published_date_str)
|
||||
except Exception:
|
||||
published_date = None
|
||||
|
||||
return InternetContent(
|
||||
title=title or "",
|
||||
link=response_url,
|
||||
full_content=text or "",
|
||||
published_date=published_date,
|
||||
)
|
||||
|
||||
|
||||
def extract_title_from_metadata(metadata: dict[str, str]) -> str | None:
|
||||
keys = ["title", "og:title"]
|
||||
return extract_value_from_dict(metadata, keys)
|
||||
|
||||
|
||||
def extract_published_date_from_jsonld(jsonld: dict[str, str]) -> str | None:
|
||||
keys = ["dateModified"]
|
||||
return extract_value_from_dict(jsonld, keys)
|
||||
|
||||
|
||||
def extract_value_from_dict(data: dict[str, str], keys: list[str]) -> str | None:
|
||||
for key in keys:
|
||||
if key in data:
|
||||
return data[key]
|
||||
return None
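A minimal usage sketch of the new client; it assumes SERPER_API_KEY is configured and only uses the methods defined above:

```python
client = SerperClient()

results = client.search("onyx enterprise search")
pages = client.contents([r.link for r in results[:3]])

for page in pages:
    # scrape_successful is False when Serper returned a 400 or kept failing after retries
    print(page.link, page.scrape_successful, len(page.full_content))
```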
@@ -26,6 +26,7 @@ class InternetContent(BaseModel):
|
||||
link: str
|
||||
full_content: str
|
||||
published_date: datetime | None = None
|
||||
scrape_successful: bool = True
|
||||
|
||||
|
||||
class InternetSearchProvider(ABC):
|
||||
|
||||
@@ -1,13 +1,19 @@
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.clients.exa_client import (
|
||||
ExaClient,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.clients.serper_client import (
|
||||
SerperClient,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
InternetSearchProvider,
|
||||
)
|
||||
from onyx.configs.chat_configs import EXA_API_KEY
|
||||
from onyx.configs.chat_configs import SERPER_API_KEY
|
||||
|
||||
|
||||
def get_default_provider() -> InternetSearchProvider | None:
|
||||
if EXA_API_KEY:
|
||||
return ExaClient()
|
||||
if SERPER_API_KEY:
|
||||
return SerperClient()
|
||||
return None
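Note the precedence encoded above: when both keys are configured, Exa is chosen before Serper. A caller-side sketch:

```python
provider = get_default_provider()
if provider is None:
    raise RuntimeError("Web search disabled: set EXA_API_KEY or SERPER_API_KEY")

# .search() / .contents() as implemented by ExaClient and SerperClient
results = provider.search("example query")
```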
@@ -34,7 +34,7 @@ def dummy_inference_section_from_internet_content(
|
||||
boost=1,
|
||||
recency_bias=1.0,
|
||||
score=1.0,
|
||||
hidden=False,
|
||||
hidden=(not result.scrape_successful),
|
||||
metadata={},
|
||||
match_highlights=[],
|
||||
doc_summary=truncated_content,
|
||||
|
||||
@@ -8,8 +8,6 @@ from typing import TypeVar
|
||||
from langchain.schema.language_model import LanguageModelInput
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langgraph.types import StreamWriter
|
||||
from litellm import get_supported_openai_params
|
||||
from litellm import supports_response_schema
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
@@ -147,6 +145,7 @@ def invoke_llm_json(
|
||||
Invoke an LLM, forcing it to respond in a specified JSON format if possible,
|
||||
and return an object of that schema.
|
||||
"""
|
||||
from litellm.utils import get_supported_openai_params, supports_response_schema
|
||||
|
||||
# check if the model supports response_format: json_schema
|
||||
supports_json = "response_format" in (
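The deferred import above pulls in two litellm helpers; a standalone sketch of the capability check they enable (the model name is illustrative):

```python
from litellm.utils import get_supported_openai_params, supports_response_schema

model = "gpt-4o"  # illustrative
params = get_supported_openai_params(model=model) or []
can_force_json = "response_format" in params and supports_response_schema(model=model)
```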
@@ -19,7 +19,9 @@ from onyx.connectors.interfaces import CheckpointedConnector
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.httpx.httpx_pool import HttpxPool
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -30,7 +32,7 @@ PRUNING_CHECKPOINTED_BATCH_SIZE = 32
|
||||
|
||||
|
||||
def document_batch_to_ids(
|
||||
doc_batch: Iterator[list[Document]],
|
||||
doc_batch: Iterator[list[Document]] | Iterator[list[SlimDocument]],
|
||||
) -> Generator[set[str], None, None]:
|
||||
for doc_list in doc_batch:
|
||||
yield {doc.id for doc in doc_list}
|
||||
@@ -41,20 +43,24 @@ def extract_ids_from_runnable_connector(
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> set[str]:
|
||||
"""
|
||||
If the SlimConnector hasnt been implemented for the given connector, just pull
|
||||
If the given connector is neither a SlimConnector nor a SlimConnectorWithPermSync, just pull
|
||||
all docs using the load_from_state and grab out the IDs.
|
||||
|
||||
Optionally, a callback can be passed to handle the length of each document batch.
|
||||
"""
|
||||
all_connector_doc_ids: set[str] = set()
|
||||
|
||||
if isinstance(runnable_connector, SlimConnector):
|
||||
for metadata_batch in runnable_connector.retrieve_all_slim_documents():
|
||||
all_connector_doc_ids.update({doc.id for doc in metadata_batch})
|
||||
|
||||
doc_batch_id_generator = None
|
||||
|
||||
if isinstance(runnable_connector, LoadConnector):
|
||||
if isinstance(runnable_connector, SlimConnector):
|
||||
doc_batch_id_generator = document_batch_to_ids(
|
||||
runnable_connector.retrieve_all_slim_docs()
|
||||
)
|
||||
elif isinstance(runnable_connector, SlimConnectorWithPermSync):
|
||||
doc_batch_id_generator = document_batch_to_ids(
|
||||
runnable_connector.retrieve_all_slim_docs_perm_sync()
|
||||
)
|
||||
# If the connector isn't slim, fall back to running it normally to get ids
|
||||
elif isinstance(runnable_connector, LoadConnector):
|
||||
doc_batch_id_generator = document_batch_to_ids(
|
||||
runnable_connector.load_from_state()
|
||||
)
|
||||
@@ -78,13 +84,14 @@ def extract_ids_from_runnable_connector(
|
||||
raise RuntimeError("Pruning job could not find a valid runnable_connector.")
|
||||
|
||||
# this function is called per batch for rate limiting
|
||||
def doc_batch_processing_func(doc_batch_ids: set[str]) -> set[str]:
|
||||
return doc_batch_ids
|
||||
|
||||
if MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE:
|
||||
doc_batch_processing_func = rate_limit_builder(
|
||||
doc_batch_processing_func = (
|
||||
rate_limit_builder(
|
||||
max_calls=MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE, period=60
|
||||
)(lambda x: x)
|
||||
if MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
|
||||
else lambda x: x
|
||||
)
|
||||
|
||||
for doc_batch_ids in doc_batch_id_generator:
|
||||
if callback:
|
||||
if callback.should_stop():
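For readers unfamiliar with rate_limit_builder, a minimal sliding-window sketch of what such a decorator factory typically does; the real Onyx helper may differ:

```python
import time
from collections import deque
from collections.abc import Callable
from typing import Any


def rate_limit_builder(max_calls: int, period: float) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Allow at most max_calls invocations of the wrapped function per rolling period (seconds)."""
    call_times: deque[float] = deque()

    def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
        def wrapped(*args: Any, **kwargs: Any) -> Any:
            while True:
                now = time.monotonic()
                # evict timestamps that have aged out of the window
                while call_times and now - call_times[0] >= period:
                    call_times.popleft()
                if len(call_times) < max_calls:
                    break
                # sleep until the oldest call leaves the window, then re-check
                time.sleep(period - (now - call_times[0]))
            call_times.append(time.monotonic())
            return fn(*args, **kwargs)

        return wrapped

    return decorator
```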
@@ -41,7 +41,7 @@ beat_task_templates: list[dict] = [
|
||||
"task": OnyxCeleryTask.USER_FILE_DOCID_MIGRATION,
|
||||
"schedule": timedelta(minutes=1),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
"queue": OnyxCeleryQueues.USER_FILE_PROCESSING,
|
||||
},
|
||||
@@ -85,9 +85,9 @@ beat_task_templates: list[dict] = [
|
||||
{
|
||||
"name": "check-for-index-attempt-cleanup",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_INDEX_ATTEMPT_CLEANUP,
|
||||
"schedule": timedelta(hours=1),
|
||||
"schedule": timedelta(minutes=30),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
|
||||
@@ -89,6 +89,7 @@ from onyx.indexing.adapters.document_indexing_adapter import (
|
||||
DocumentIndexingBatchAdapter,
|
||||
)
|
||||
from onyx.indexing.embedder import DefaultIndexingEmbedder
|
||||
from onyx.indexing.indexing_pipeline import run_indexing_pipeline
|
||||
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from onyx.natural_language_processing.search_nlp_models import (
|
||||
InformationContentClassificationModel,
|
||||
@@ -1270,8 +1271,6 @@ def _docprocessing_task(
|
||||
tenant_id: str,
|
||||
batch_num: int,
|
||||
) -> None:
|
||||
from onyx.indexing.indexing_pipeline import run_indexing_pipeline
|
||||
|
||||
start_time = time.monotonic()
|
||||
|
||||
if tenant_id:
|
||||
|
||||
@@ -1,266 +0,0 @@
|
||||
import time
|
||||
from typing import List
|
||||
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
from tenacity import RetryError
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
|
||||
from onyx.background.celery.tasks.shared.tasks import LIGHT_SOFT_TIME_LIMIT
|
||||
from onyx.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT
|
||||
from onyx.background.celery.tasks.shared.tasks import OnyxCeleryTaskCompletionStatus
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_USER_FILE_FOLDER_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.db.connector_credential_pair import (
|
||||
get_connector_credential_pairs_with_user_files,
|
||||
)
|
||||
from onyx.db.document import get_document
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import Document
|
||||
from onyx.db.models import DocumentByConnectorCredentialPair
|
||||
from onyx.db.search_settings import get_active_search_settings
|
||||
from onyx.db.user_documents import fetch_user_files_for_documents
|
||||
from onyx.db.user_documents import fetch_user_folders_for_documents
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.document_index.interfaces import VespaDocumentUserFields
|
||||
from onyx.httpx.httpx_pool import HttpxPool
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CHECK_FOR_USER_FILE_FOLDER_SYNC,
|
||||
ignore_result=True,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
trail=False,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_user_file_folder_sync(self: Task, *, tenant_id: str) -> bool | None:
|
||||
"""Runs periodically to check for documents that need user file folder metadata updates.
|
||||
This task fetches all connector credential pairs with user files, gets the documents
|
||||
associated with them, and updates the user file and folder metadata in Vespa.
|
||||
"""
|
||||
|
||||
time_start = time.monotonic()
|
||||
|
||||
r = get_redis_client()
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_USER_FILE_FOLDER_SYNC_BEAT_LOCK,
|
||||
timeout=CELERY_USER_FILE_FOLDER_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return None
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
# Get all connector credential pairs that have user files
|
||||
cc_pairs = get_connector_credential_pairs_with_user_files(db_session)
|
||||
|
||||
if not cc_pairs:
|
||||
task_logger.info("No connector credential pairs with user files found")
|
||||
return True
|
||||
|
||||
# Get all documents associated with these cc_pairs
|
||||
document_ids = get_documents_for_cc_pairs(cc_pairs, db_session)
|
||||
|
||||
if not document_ids:
|
||||
task_logger.info(
|
||||
"No documents found for connector credential pairs with user files"
|
||||
)
|
||||
return True
|
||||
|
||||
# Fetch current user file and folder IDs for these documents
|
||||
doc_id_to_user_file_id = fetch_user_files_for_documents(
|
||||
document_ids=document_ids, db_session=db_session
|
||||
)
|
||||
doc_id_to_user_folder_id = fetch_user_folders_for_documents(
|
||||
document_ids=document_ids, db_session=db_session
|
||||
)
|
||||
|
||||
# Update Vespa metadata for each document
|
||||
for doc_id in document_ids:
|
||||
user_file_id = doc_id_to_user_file_id.get(doc_id)
|
||||
user_folder_id = doc_id_to_user_folder_id.get(doc_id)
|
||||
|
||||
if user_file_id is not None or user_folder_id is not None:
|
||||
# Schedule a task to update the document metadata
|
||||
update_user_file_folder_metadata.apply_async(
|
||||
args=(doc_id,), # Use tuple instead of list for args
|
||||
kwargs={
|
||||
"tenant_id": tenant_id,
|
||||
"user_file_id": user_file_id,
|
||||
"user_folder_id": user_folder_id,
|
||||
},
|
||||
queue="vespa_metadata_sync",
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"Scheduled metadata updates for {len(document_ids)} documents. "
|
||||
f"Elapsed time: {time.monotonic() - time_start:.2f}s"
|
||||
)
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
task_logger.exception(f"Error in check_for_user_file_folder_sync: {e}")
|
||||
return False
|
||||
finally:
|
||||
lock_beat.release()
|
||||
|
||||
|
||||
def get_documents_for_cc_pairs(
|
||||
cc_pairs: List[ConnectorCredentialPair], db_session: Session
|
||||
) -> List[str]:
|
||||
"""Get all document IDs associated with the given connector credential pairs."""
|
||||
if not cc_pairs:
|
||||
return []
|
||||
|
||||
cc_pair_ids = [cc_pair.id for cc_pair in cc_pairs]
|
||||
|
||||
# Query to get document IDs from DocumentByConnectorCredentialPair
|
||||
# Note: DocumentByConnectorCredentialPair uses connector_id and credential_id, not cc_pair_id
|
||||
doc_cc_pairs = (
|
||||
db_session.query(Document.id)
|
||||
.join(
|
||||
DocumentByConnectorCredentialPair,
|
||||
Document.id == DocumentByConnectorCredentialPair.id,
|
||||
)
|
||||
.filter(
|
||||
db_session.query(ConnectorCredentialPair)
|
||||
.filter(
|
||||
ConnectorCredentialPair.id.in_(cc_pair_ids),
|
||||
ConnectorCredentialPair.connector_id
|
||||
== DocumentByConnectorCredentialPair.connector_id,
|
||||
ConnectorCredentialPair.credential_id
|
||||
== DocumentByConnectorCredentialPair.credential_id,
|
||||
)
|
||||
.exists()
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
return [doc_id for (doc_id,) in doc_cc_pairs]
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.UPDATE_USER_FILE_FOLDER_METADATA,
|
||||
bind=True,
|
||||
soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
|
||||
time_limit=LIGHT_TIME_LIMIT,
|
||||
max_retries=3,
|
||||
)
|
||||
def update_user_file_folder_metadata(
|
||||
self: Task,
|
||||
document_id: str,
|
||||
*,
|
||||
tenant_id: str,
|
||||
user_file_id: int | None,
|
||||
user_folder_id: int | None,
|
||||
) -> bool:
|
||||
"""Updates the user file and folder metadata for a document in Vespa."""
|
||||
start = time.monotonic()
|
||||
completion_status = OnyxCeleryTaskCompletionStatus.UNDEFINED
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
active_search_settings = get_active_search_settings(db_session)
|
||||
doc_index = get_default_document_index(
|
||||
search_settings=active_search_settings.primary,
|
||||
secondary_search_settings=active_search_settings.secondary,
|
||||
httpx_client=HttpxPool.get("vespa"),
|
||||
)
|
||||
|
||||
retry_index = RetryDocumentIndex(doc_index)
|
||||
|
||||
doc = get_document(document_id, db_session)
|
||||
if not doc:
|
||||
elapsed = time.monotonic() - start
|
||||
task_logger.info(
|
||||
f"doc={document_id} "
|
||||
f"action=no_operation "
|
||||
f"elapsed={elapsed:.2f}"
|
||||
)
|
||||
completion_status = OnyxCeleryTaskCompletionStatus.SKIPPED
|
||||
return False
|
||||
|
||||
# Create user fields object with file and folder IDs
|
||||
user_fields = VespaDocumentUserFields(
|
||||
user_file_id=str(user_file_id) if user_file_id is not None else None,
|
||||
user_folder_id=(
|
||||
str(user_folder_id) if user_folder_id is not None else None
|
||||
),
|
||||
)
|
||||
|
||||
# Update Vespa. OK if doc doesn't exist. Raises exception otherwise.
|
||||
chunks_affected = retry_index.update_single(
|
||||
document_id,
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=doc.chunk_count,
|
||||
fields=None, # We're only updating user fields
|
||||
user_fields=user_fields,
|
||||
)
|
||||
|
||||
elapsed = time.monotonic() - start
|
||||
task_logger.info(
|
||||
f"doc={document_id} "
|
||||
f"action=user_file_folder_sync "
|
||||
f"user_file_id={user_file_id} "
|
||||
f"user_folder_id={user_folder_id} "
|
||||
f"chunks={chunks_affected} "
|
||||
f"elapsed={elapsed:.2f}"
|
||||
)
|
||||
completion_status = OnyxCeleryTaskCompletionStatus.SUCCEEDED
|
||||
return True
|
||||
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")
|
||||
completion_status = OnyxCeleryTaskCompletionStatus.SOFT_TIME_LIMIT
|
||||
except Exception as ex:
|
||||
e: Exception | None = None
|
||||
while True:
|
||||
if isinstance(ex, RetryError):
|
||||
task_logger.warning(
|
||||
f"Tenacity retry failed: num_attempts={ex.last_attempt.attempt_number}"
|
||||
)
|
||||
|
||||
# only set the inner exception if it is of type Exception
|
||||
e_temp = ex.last_attempt.exception()
|
||||
if isinstance(e_temp, Exception):
|
||||
e = e_temp
|
||||
else:
|
||||
e = ex
|
||||
|
||||
task_logger.exception(
|
||||
f"update_user_file_folder_metadata exceptioned: doc={document_id}"
|
||||
)
|
||||
|
||||
completion_status = OnyxCeleryTaskCompletionStatus.RETRYABLE_EXCEPTION
|
||||
if (
|
||||
self.max_retries is not None
|
||||
and self.request.retries >= self.max_retries
|
||||
):
|
||||
completion_status = (
|
||||
OnyxCeleryTaskCompletionStatus.NON_RETRYABLE_EXCEPTION
|
||||
)
|
||||
|
||||
# Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
|
||||
countdown = 2 ** (self.request.retries + 4)
|
||||
self.retry(exc=e, countdown=countdown) # this will raise a celery exception
|
||||
break # we won't hit this, but it looks weird not to have it
|
||||
finally:
|
||||
task_logger.info(
|
||||
f"update_user_file_folder_metadata completed: status={completion_status.value} doc={document_id}"
|
||||
)
|
||||
|
||||
return False
|
||||
@@ -236,7 +236,11 @@ def process_single_user_file(self: Task, *, user_file_id: str, tenant_id: str) -
|
||||
f"process_single_user_file - Indexing pipeline completed ={index_pipeline_result}"
|
||||
)
|
||||
|
||||
if index_pipeline_result.failures:
|
||||
if (
|
||||
index_pipeline_result.failures
|
||||
or index_pipeline_result.total_docs != len(documents)
|
||||
or index_pipeline_result.total_chunks == 0
|
||||
):
|
||||
task_logger.error(
|
||||
f"process_single_user_file - Indexing pipeline failed id={user_file_id}"
|
||||
)
|
||||
@@ -542,39 +546,78 @@ def user_file_docid_migration_task(self: Task, *, tenant_id: str) -> bool:
|
||||
task_logger.warning(
|
||||
f"Tenant={tenant_id} failed Vespa update for doc_id={new_uuid} - {e.__class__.__name__}"
|
||||
)
|
||||
|
||||
# Update search_doc records to refer to the UUID string
|
||||
uf_id_subq = (
|
||||
sa.select(sa.cast(UserFile.id, sa.String))
|
||||
.where(
|
||||
UserFile.document_id.is_not(None),
|
||||
UserFile.document_id_migrated.is_(False),
|
||||
SearchDoc.document_id == UserFile.document_id,
|
||||
# we are not filtering on document_id_migrated = False because if the migration already completed,
# it would not run again and the search_doc records affected by the issue fixed here would never be updated
|
||||
user_files = (
|
||||
db_session.execute(
|
||||
sa.select(UserFile).where(UserFile.document_id.is_not(None))
|
||||
)
|
||||
.correlate(SearchDoc)
|
||||
.scalar_subquery()
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
db_session.execute(
|
||||
sa.update(SearchDoc)
|
||||
.where(
|
||||
sa.exists(
|
||||
sa.select(sa.literal(1)).where(
|
||||
UserFile.document_id.is_not(None),
|
||||
UserFile.document_id_migrated.is_(False),
|
||||
SearchDoc.document_id == UserFile.document_id,
|
||||
)
|
||||
|
||||
# Query all SearchDocs that need updating
|
||||
search_docs = (
|
||||
db_session.execute(
|
||||
sa.select(SearchDoc).where(
|
||||
SearchDoc.document_id.like("%FILE_CONNECTOR__%")
|
||||
)
|
||||
)
|
||||
.values(document_id=uf_id_subq)
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
# Mark all processed user_files as migrated
|
||||
db_session.execute(
|
||||
sa.update(UserFile)
|
||||
.where(
|
||||
UserFile.document_id.is_not(None),
|
||||
UserFile.document_id_migrated.is_(False),
|
||||
|
||||
task_logger.info(f"Found {len(user_files)} user files to update")
|
||||
task_logger.info(f"Found {len(search_docs)} search docs to update")
|
||||
|
||||
# Build a map of normalized doc IDs to SearchDocs
|
||||
search_doc_map: dict[str, list[SearchDoc]] = {}
|
||||
for sd in search_docs:
|
||||
doc_id = sd.document_id
|
||||
if search_doc_map.get(doc_id) is None:
|
||||
search_doc_map[doc_id] = []
|
||||
search_doc_map[doc_id].append(sd)
|
||||
|
||||
task_logger.debug(
|
||||
f"Built search doc map with {len(search_doc_map)} entries"
|
||||
)
|
||||
ids_preview = list(search_doc_map.keys())[:5]
|
||||
task_logger.debug(
|
||||
f"First few search_doc_map ids: {ids_preview if ids_preview else 'No ids found'}"
|
||||
)
|
||||
task_logger.debug(
|
||||
f"search_doc_map total items: {sum(len(docs) for docs in search_doc_map.values())}"
|
||||
)
|
||||
# Process each UserFile and update matching SearchDocs
|
||||
updated_count = 0
|
||||
for uf in user_files:
|
||||
doc_id = uf.document_id
|
||||
if doc_id.startswith("USER_FILE_CONNECTOR__"):
|
||||
doc_id = "FILE_CONNECTOR__" + doc_id[len("USER_FILE_CONNECTOR__") :]
|
||||
|
||||
task_logger.debug(f"Processing user file {uf.id} with doc_id {doc_id}")
|
||||
task_logger.debug(
|
||||
f"doc_id in search_doc_map: {doc_id in search_doc_map}"
|
||||
)
|
||||
.values(document_id_migrated=True)
|
||||
|
||||
if doc_id in search_doc_map:
|
||||
search_docs = search_doc_map[doc_id]
|
||||
task_logger.debug(
|
||||
f"Found {len(search_docs)} search docs to update for user file {uf.id}"
|
||||
)
|
||||
# Update the SearchDoc to use the UserFile's UUID
|
||||
for search_doc in search_docs:
|
||||
search_doc.document_id = str(uf.id)
|
||||
db_session.add(search_doc)
|
||||
|
||||
# Mark UserFile as migrated
|
||||
uf.document_id_migrated = True
|
||||
db_session.add(uf)
|
||||
updated_count += 1
|
||||
|
||||
task_logger.info(
|
||||
f"Updated {updated_count} SearchDoc records with new UUIDs"
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from sqlalchemy.orm import Session
|
||||
from onyx.configs.constants import NUM_DAYS_TO_KEEP_INDEX_ATTEMPTS
|
||||
from onyx.db.engine.time_utils import get_db_current_time
|
||||
from onyx.db.models import IndexAttempt
|
||||
from onyx.db.models import IndexAttemptError
|
||||
|
||||
|
||||
def get_old_index_attempts(
|
||||
@@ -21,6 +22,10 @@ def get_old_index_attempts(
|
||||
|
||||
def cleanup_index_attempts(db_session: Session, index_attempt_ids: list[int]) -> None:
|
||||
"""Clean up multiple index attempts"""
|
||||
db_session.query(IndexAttemptError).filter(
|
||||
IndexAttemptError.index_attempt_id.in_(index_attempt_ids)
|
||||
).delete(synchronize_session=False)
|
||||
|
||||
db_session.query(IndexAttempt).filter(
|
||||
IndexAttempt.id.in_(index_attempt_ids)
|
||||
).delete(synchronize_session=False)
|
||||
|
||||
@@ -28,6 +28,7 @@ from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.connectors.connector_runner import ConnectorRunner
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import UnexpectedValidationError
|
||||
from onyx.connectors.factory import instantiate_connector
|
||||
from onyx.connectors.interfaces import CheckpointedConnector
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import ConnectorStopSignal
|
||||
@@ -101,7 +102,6 @@ def _get_connector_runner(
|
||||
are the complete list of existing documents of the connector. If the task
|
||||
of type LOAD_STATE, the list will be considered complete and otherwise incomplete.
|
||||
"""
|
||||
from onyx.connectors.factory import instantiate_connector
|
||||
|
||||
task = attempt.connector_credential_pair.connector.input_type
|
||||
|
||||
|
||||
@@ -138,23 +138,34 @@ def _build_project_llm_docs(
|
||||
|
||||
project_file_id_set = set(project_file_ids)
|
||||
for f in in_memory_user_files:
|
||||
# Only include files that belong to the project (not ad-hoc uploads)
|
||||
if project_file_id_set and (f.file_id in project_file_id_set):
|
||||
try:
|
||||
text_content = f.content.decode("utf-8", errors="ignore")
|
||||
except Exception:
|
||||
text_content = ""
|
||||
|
||||
# Build a short blurb from the file content for better UI display
|
||||
blurb = (
|
||||
(text_content[:200] + "...")
|
||||
if len(text_content) > 200
|
||||
else text_content
|
||||
)
|
||||
def _strip_nuls(s: str) -> str:
|
||||
return s.replace("\x00", "") if s else s
|
||||
|
||||
cleaned_filename = _strip_nuls(f.filename or str(f.file_id))
|
||||
|
||||
if f.file_type.is_text_file():
|
||||
try:
|
||||
text_content = f.content.decode("utf-8", errors="ignore")
|
||||
text_content = _strip_nuls(text_content)
|
||||
except Exception:
|
||||
text_content = ""
|
||||
|
||||
# Build a short blurb from the file content for better UI display
|
||||
blurb = (
|
||||
(text_content[:200] + "...")
|
||||
if len(text_content) > 200
|
||||
else text_content
|
||||
)
|
||||
else:
|
||||
# Non-text (e.g., images): do not decode bytes; keep empty content but allow citation
|
||||
text_content = ""
|
||||
blurb = f"[{f.file_type.value}] {cleaned_filename}"
|
||||
|
||||
# Provide basic metadata to improve SavedSearchDoc display
|
||||
file_metadata: dict[str, str | list[str]] = {
|
||||
"filename": f.filename or str(f.file_id),
|
||||
"filename": cleaned_filename,
|
||||
"file_type": f.file_type.value,
|
||||
}
|
||||
|
||||
@@ -163,7 +174,7 @@ def _build_project_llm_docs(
|
||||
document_id=str(f.file_id),
|
||||
content=text_content,
|
||||
blurb=blurb,
|
||||
semantic_identifier=f.filename or str(f.file_id),
|
||||
semantic_identifier=cleaned_filename,
|
||||
source_type=DocumentSource.USER_FILE,
|
||||
metadata=file_metadata,
|
||||
updated_at=None,
|
||||
@@ -503,7 +514,8 @@ def stream_chat_message_objects(
|
||||
for fd in new_msg_req.current_message_files:
|
||||
uid = fd.get("user_file_id")
|
||||
if uid is not None:
|
||||
user_file_ids.append(uid)
|
||||
user_file_id = UUID(uid)
|
||||
user_file_ids.append(user_file_id)
|
||||
|
||||
# Load in user files into memory and create search tool override kwargs if needed
|
||||
# if we have enough tokens, we don't need to use search
|
||||
|
||||
@@ -100,12 +100,14 @@ def parse_user_files(
|
||||
persona=persona,
|
||||
actual_user_input=actual_user_input,
|
||||
)
|
||||
uploaded_context_cap = int(available_tokens * 0.5)
|
||||
|
||||
logger.debug(
|
||||
f"Total file tokens: {total_tokens}, Available tokens: {available_tokens}"
|
||||
f"Total file tokens: {total_tokens}, Available tokens: {available_tokens},"
|
||||
f"Allowed uploaded context tokens: {uploaded_context_cap}"
|
||||
)
|
||||
|
||||
have_enough_tokens = total_tokens <= available_tokens
|
||||
have_enough_tokens = total_tokens <= uploaded_context_cap
|
||||
|
||||
# If we have enough tokens, we don't need search
|
||||
# we can just pass them into the prompt directly
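A worked example of the new 50% cap (numbers are illustrative):

```python
available_tokens = 120_000
uploaded_context_cap = int(available_tokens * 0.5)  # 60_000 tokens reserved for uploaded files

total_tokens = 75_000  # tokens across the attached files
have_enough_tokens = total_tokens <= uploaded_context_cap  # False -> fall back to the search tool
```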
@@ -90,6 +90,7 @@ HARD_DELETE_CHATS = os.environ.get("HARD_DELETE_CHATS", "").lower() == "true"
|
||||
|
||||
# Internet Search
|
||||
EXA_API_KEY = os.environ.get("EXA_API_KEY") or None
|
||||
SERPER_API_KEY = os.environ.get("SERPER_API_KEY") or None
|
||||
|
||||
NUM_INTERNET_SEARCH_RESULTS = int(os.environ.get("NUM_INTERNET_SEARCH_RESULTS") or 10)
|
||||
NUM_INTERNET_SEARCH_CHUNKS = int(os.environ.get("NUM_INTERNET_SEARCH_CHUNKS") or 50)
|
||||
|
||||
@@ -41,7 +41,7 @@ All new connectors should have tests added to the `backend/tests/daily/connector
|
||||
|
||||
#### Implementing the new Connector
|
||||
|
||||
The connector must subclass one or more of LoadConnector, PollConnector, SlimConnector, or EventConnector.
|
||||
The connector must subclass one or more of LoadConnector, PollConnector, CheckpointedConnector, or CheckpointedConnectorWithPermSync.
|
||||
|
||||
The `__init__` should take arguments for configuring what documents the connector will fetch and where it finds those
|
||||
documents. For example, if you have a wiki site, it may include the configuration for the team, topic, folder, etc. of
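To make the guidance concrete, a hypothetical (deliberately simplified) connector skeleton; the exact base-class signatures should be taken from onyx/connectors/interfaces.py:

```python
from typing import Any

from onyx.connectors.interfaces import LoadConnector


class WikiConnector(LoadConnector):
    """Toy example: site configuration goes in __init__, secrets in load_credentials."""

    def __init__(self, team: str, topic: str, batch_size: int = 100) -> None:
        self.team = team
        self.topic = topic
        self.batch_size = batch_size
        self.api_token: str | None = None

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
        self.api_token = credentials["wiki_api_token"]  # hypothetical credential key
        return None

    def load_from_state(self):
        # yield lists of Document objects, self.batch_size at a time
        yield []
```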
@@ -25,7 +25,7 @@ from onyx.connectors.exceptions import UnexpectedValidationError
|
||||
from onyx.connectors.interfaces import CheckpointedConnector
|
||||
from onyx.connectors.interfaces import CheckpointOutput
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
@@ -56,7 +56,7 @@ class BitbucketConnectorCheckpoint(ConnectorCheckpoint):
|
||||
|
||||
class BitbucketConnector(
|
||||
CheckpointedConnector[BitbucketConnectorCheckpoint],
|
||||
SlimConnector,
|
||||
SlimConnectorWithPermSync,
|
||||
):
|
||||
"""Connector for indexing Bitbucket Cloud pull requests.
|
||||
|
||||
@@ -266,7 +266,7 @@ class BitbucketConnector(
|
||||
"""Validate and deserialize a checkpoint instance from JSON."""
|
||||
return BitbucketConnectorCheckpoint.model_validate_json(checkpoint_json)
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
|
||||
@@ -5,6 +5,7 @@ from datetime import timezone
|
||||
from typing import Any
|
||||
from urllib.parse import quote
|
||||
|
||||
from atlassian.errors import ApiError # type: ignore
|
||||
from requests.exceptions import HTTPError
|
||||
from typing_extensions import override
|
||||
|
||||
@@ -41,6 +42,7 @@ from onyx.connectors.interfaces import CredentialsProviderInterface
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
@@ -91,6 +93,7 @@ class ConfluenceCheckpoint(ConnectorCheckpoint):
|
||||
class ConfluenceConnector(
|
||||
CheckpointedConnector[ConfluenceCheckpoint],
|
||||
SlimConnector,
|
||||
SlimConnectorWithPermSync,
|
||||
CredentialsConnector,
|
||||
):
|
||||
def __init__(
|
||||
@@ -108,6 +111,7 @@ class ConfluenceConnector(
|
||||
# pages.
|
||||
labels_to_skip: list[str] = CONFLUENCE_CONNECTOR_LABELS_TO_SKIP,
|
||||
timezone_offset: float = CONFLUENCE_TIMEZONE_OFFSET,
|
||||
scoped_token: bool = False,
|
||||
) -> None:
|
||||
self.wiki_base = wiki_base
|
||||
self.is_cloud = is_cloud
|
||||
@@ -118,6 +122,7 @@ class ConfluenceConnector(
|
||||
self.batch_size = batch_size
|
||||
self.labels_to_skip = labels_to_skip
|
||||
self.timezone_offset = timezone_offset
|
||||
self.scoped_token = scoped_token
|
||||
self._confluence_client: OnyxConfluence | None = None
|
||||
self._low_timeout_confluence_client: OnyxConfluence | None = None
|
||||
self._fetched_titles: set[str] = set()
|
||||
@@ -195,6 +200,7 @@ class ConfluenceConnector(
|
||||
is_cloud=self.is_cloud,
|
||||
url=self.wiki_base,
|
||||
credentials_provider=credentials_provider,
|
||||
scoped_token=self.scoped_token,
|
||||
)
|
||||
confluence_client._probe_connection(**self.probe_kwargs)
|
||||
confluence_client._initialize_connection(**self.final_kwargs)
|
||||
@@ -207,6 +213,7 @@ class ConfluenceConnector(
|
||||
url=self.wiki_base,
|
||||
credentials_provider=credentials_provider,
|
||||
timeout=3,
|
||||
scoped_token=self.scoped_token,
|
||||
)
|
||||
low_timeout_confluence_client._probe_connection(**self.probe_kwargs)
|
||||
low_timeout_confluence_client._initialize_connection(**self.final_kwargs)
|
||||
@@ -558,7 +565,21 @@ class ConfluenceConnector(
|
||||
def validate_checkpoint_json(self, checkpoint_json: str) -> ConfluenceCheckpoint:
|
||||
return ConfluenceCheckpoint.model_validate_json(checkpoint_json)
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
@override
|
||||
def retrieve_all_slim_docs(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
return self._retrieve_all_slim_docs(
|
||||
start=start,
|
||||
end=end,
|
||||
callback=callback,
|
||||
include_permissions=False,
|
||||
)
|
||||
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
@@ -568,12 +589,28 @@ class ConfluenceConnector(
|
||||
Return 'slim' docs (IDs + minimal permission data).
|
||||
Does not fetch actual text. Used primarily for incremental permission sync.
|
||||
"""
|
||||
return self._retrieve_all_slim_docs(
|
||||
start=start,
|
||||
end=end,
|
||||
callback=callback,
|
||||
include_permissions=True,
|
||||
)
|
||||
|
||||
def _retrieve_all_slim_docs(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
include_permissions: bool = True,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
doc_metadata_list: list[SlimDocument] = []
|
||||
restrictions_expand = ",".join(_RESTRICTIONS_EXPANSION_FIELDS)
|
||||
|
||||
space_level_access_info = get_all_space_permissions(
|
||||
self.confluence_client, self.is_cloud
|
||||
)
|
||||
space_level_access_info: dict[str, ExternalAccess] = {}
|
||||
if include_permissions:
|
||||
space_level_access_info = get_all_space_permissions(
|
||||
self.confluence_client, self.is_cloud
|
||||
)
|
||||
|
||||
def get_external_access(
|
||||
doc_id: str, restrictions: dict[str, Any], ancestors: list[dict[str, Any]]
|
||||
@@ -600,8 +637,10 @@ class ConfluenceConnector(
|
||||
doc_metadata_list.append(
|
||||
SlimDocument(
|
||||
id=page_id,
|
||||
external_access=get_external_access(
|
||||
page_id, page_restrictions, page_ancestors
|
||||
external_access=(
|
||||
get_external_access(page_id, page_restrictions, page_ancestors)
|
||||
if include_permissions
|
||||
else None
|
||||
),
|
||||
)
|
||||
)
|
||||
@@ -636,8 +675,12 @@ class ConfluenceConnector(
|
||||
doc_metadata_list.append(
|
||||
SlimDocument(
|
||||
id=attachment_id,
|
||||
external_access=get_external_access(
|
||||
attachment_id, attachment_restrictions, []
|
||||
external_access=(
|
||||
get_external_access(
|
||||
attachment_id, attachment_restrictions, []
|
||||
)
|
||||
if include_permissions
|
||||
else None
|
||||
),
|
||||
)
|
||||
)
|
||||
@@ -648,10 +691,10 @@ class ConfluenceConnector(
|
||||
|
||||
if callback and callback.should_stop():
|
||||
raise RuntimeError(
|
||||
"retrieve_all_slim_documents: Stop signal detected"
|
||||
"retrieve_all_slim_docs_perm_sync: Stop signal detected"
|
||||
)
|
||||
if callback:
|
||||
callback.progress("retrieve_all_slim_documents", 1)
|
||||
callback.progress("retrieve_all_slim_docs_perm_sync", 1)
|
||||
|
||||
yield doc_metadata_list
|
||||
|
||||
@@ -676,6 +719,14 @@ class ConfluenceConnector(
|
||||
f"Unexpected error while validating Confluence settings: {e}"
|
||||
)
|
||||
|
||||
if self.space:
|
||||
try:
|
||||
self.low_timeout_confluence_client.get_space(self.space)
|
||||
except ApiError as e:
|
||||
raise ConnectorValidationError(
|
||||
"Invalid Confluence space key provided"
|
||||
) from e
|
||||
|
||||
if not spaces or not spaces.get("results"):
|
||||
raise ConnectorValidationError(
|
||||
"No Confluence spaces found. Either your credentials lack permissions, or "
|
||||
@@ -724,7 +775,7 @@ if __name__ == "__main__":
|
||||
end = datetime.now().timestamp()
|
||||
|
||||
# Fetch all `SlimDocuments`.
|
||||
for slim_doc in confluence_connector.retrieve_all_slim_documents():
|
||||
for slim_doc in confluence_connector.retrieve_all_slim_docs_perm_sync():
|
||||
print(slim_doc)
|
||||
|
||||
# Fetch all `Documents`.
|
||||
|
||||
@@ -41,6 +41,7 @@ from onyx.connectors.confluence.utils import _handle_http_error
|
||||
from onyx.connectors.confluence.utils import confluence_refresh_tokens
|
||||
from onyx.connectors.confluence.utils import get_start_param_from_url
|
||||
from onyx.connectors.confluence.utils import update_param_in_path
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import scoped_url
|
||||
from onyx.connectors.interfaces import CredentialsProviderInterface
|
||||
from onyx.file_processing.html_utils import format_document_soup
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
@@ -87,16 +88,20 @@ class OnyxConfluence:
|
||||
url: str,
|
||||
credentials_provider: CredentialsProviderInterface,
|
||||
timeout: int | None = None,
|
||||
scoped_token: bool = False,
|
||||
# should generally not be passed in, but making it overridable for
|
||||
# easier testing
|
||||
confluence_user_profiles_override: list[dict[str, str]] | None = (
|
||||
CONFLUENCE_CONNECTOR_USER_PROFILES_OVERRIDE
|
||||
),
|
||||
) -> None:
|
||||
self.base_url = url #'/'.join(url.rstrip("/").split("/")[:-1])
|
||||
url = scoped_url(url, "confluence") if scoped_token else url
|
||||
|
||||
self._is_cloud = is_cloud
|
||||
self._url = url.rstrip("/")
|
||||
self._credentials_provider = credentials_provider
|
||||
|
||||
self.scoped_token = scoped_token
|
||||
self.redis_client: Redis | None = None
|
||||
self.static_credentials: dict[str, Any] | None = None
|
||||
if self._credentials_provider.is_dynamic():
|
||||
@@ -218,6 +223,34 @@ class OnyxConfluence:
|
||||
|
||||
with self._credentials_provider:
|
||||
credentials, _ = self._renew_credentials()
|
||||
if self.scoped_token:
|
||||
# v2 endpoint doesn't always work with scoped tokens, use v1
|
||||
token = credentials["confluence_access_token"]
|
||||
probe_url = f"{self.base_url}/rest/api/space?limit=1"
|
||||
import requests
|
||||
|
||||
logger.info(f"First and Last 5 of token: {token[:5]}...{token[-5:]}")
|
||||
|
||||
try:
|
||||
r = requests.get(
|
||||
probe_url,
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
timeout=10,
|
||||
)
|
||||
r.raise_for_status()
|
||||
except HTTPError as e:
|
||||
if e.response.status_code == 403:
|
||||
logger.warning(
|
||||
"scoped token authenticated but not valid for probe endpoint (spaces)"
|
||||
)
|
||||
else:
|
||||
if "WWW-Authenticate" in e.response.headers:
|
||||
logger.warning(
|
||||
f"WWW-Authenticate: {e.response.headers['WWW-Authenticate']}"
|
||||
)
|
||||
logger.warning(f"Full error: {e.response.text}")
|
||||
raise e
|
||||
return
|
||||
|
||||
# probe connection with direct client, no retries
|
||||
if "confluence_refresh_token" in credentials:
|
||||
@@ -236,6 +269,7 @@ class OnyxConfluence:
|
||||
logger.info("Probing Confluence with Personal Access Token.")
|
||||
url = self._url
|
||||
if self._is_cloud:
|
||||
logger.info("running with cloud client")
|
||||
confluence_client_with_minimal_retries = Confluence(
|
||||
url=url,
|
||||
username=credentials["confluence_username"],
|
||||
@@ -304,7 +338,9 @@ class OnyxConfluence:
|
||||
url = f"https://api.atlassian.com/ex/confluence/{credentials['cloud_id']}"
|
||||
confluence = Confluence(url=url, oauth2=oauth2_dict, **kwargs)
|
||||
else:
|
||||
logger.info("Connecting to Confluence with Personal Access Token.")
|
||||
logger.info(
|
||||
f"Connecting to Confluence with Personal Access Token as user: {credentials['confluence_username']}"
|
||||
)
|
||||
if self._is_cloud:
|
||||
confluence = Confluence(
|
||||
url=self._url,
|
||||
|
||||
@@ -5,7 +5,10 @@ from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import TypeVar
|
||||
from urllib.parse import urljoin
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
from dateutil.parser import parse
|
||||
|
||||
from onyx.configs.app_configs import CONNECTOR_LOCALHOST_OVERRIDE
|
||||
@@ -148,3 +151,17 @@ def get_oauth_callback_uri(base_domain: str, connector_id: str) -> str:
|
||||
|
||||
def is_atlassian_date_error(e: Exception) -> bool:
|
||||
return "field 'updated' is invalid" in str(e)
|
||||
|
||||
|
||||
def get_cloudId(base_url: str) -> str:
|
||||
tenant_info_url = urljoin(base_url, "/_edge/tenant_info")
|
||||
response = requests.get(tenant_info_url, timeout=10)
|
||||
response.raise_for_status()
|
||||
return response.json()["cloudId"]
|
||||
|
||||
|
||||
def scoped_url(url: str, product: str) -> str:
|
||||
parsed = urlparse(url)
|
||||
base_url = parsed.scheme + "://" + parsed.netloc
|
||||
cloud_id = get_cloudId(base_url)
|
||||
return f"https://api.atlassian.com/ex/{product}/{cloud_id}{parsed.path}"
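Putting the two helpers together, the rewrite looks roughly like this (site URL and cloud id are illustrative; the call performs a network request):

```python
original = "https://example.atlassian.net/wiki"
api_base = scoped_url(original, "confluence")
# -> "https://api.atlassian.com/ex/confluence/<cloudId>/wiki"
print(api_base)
```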
@@ -1,3 +1,4 @@
|
||||
import importlib
|
||||
from typing import Any
|
||||
from typing import Type
|
||||
|
||||
@@ -6,60 +7,16 @@ from sqlalchemy.orm import Session
|
||||
from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled
|
||||
from onyx.connectors.airtable.airtable_connector import AirtableConnector
|
||||
from onyx.connectors.asana.connector import AsanaConnector
|
||||
from onyx.connectors.axero.connector import AxeroConnector
|
||||
from onyx.connectors.bitbucket.connector import BitbucketConnector
|
||||
from onyx.connectors.blob.connector import BlobStorageConnector
|
||||
from onyx.connectors.bookstack.connector import BookstackConnector
|
||||
from onyx.connectors.clickup.connector import ClickupConnector
|
||||
from onyx.connectors.confluence.connector import ConfluenceConnector
|
||||
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
|
||||
from onyx.connectors.discord.connector import DiscordConnector
|
||||
from onyx.connectors.discourse.connector import DiscourseConnector
|
||||
from onyx.connectors.document360.connector import Document360Connector
|
||||
from onyx.connectors.dropbox.connector import DropboxConnector
|
||||
from onyx.connectors.egnyte.connector import EgnyteConnector
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.file.connector import LocalFileConnector
|
||||
from onyx.connectors.fireflies.connector import FirefliesConnector
|
||||
from onyx.connectors.freshdesk.connector import FreshdeskConnector
|
||||
from onyx.connectors.gitbook.connector import GitbookConnector
|
||||
from onyx.connectors.github.connector import GithubConnector
|
||||
from onyx.connectors.gitlab.connector import GitlabConnector
|
||||
from onyx.connectors.gmail.connector import GmailConnector
|
||||
from onyx.connectors.gong.connector import GongConnector
|
||||
from onyx.connectors.google_drive.connector import GoogleDriveConnector
|
||||
from onyx.connectors.google_site.connector import GoogleSitesConnector
|
||||
from onyx.connectors.guru.connector import GuruConnector
|
||||
from onyx.connectors.highspot.connector import HighspotConnector
|
||||
from onyx.connectors.hubspot.connector import HubSpotConnector
|
||||
from onyx.connectors.imap.connector import ImapConnector
|
||||
from onyx.connectors.interfaces import BaseConnector
|
||||
from onyx.connectors.interfaces import CheckpointedConnector
|
||||
from onyx.connectors.interfaces import CredentialsConnector
|
||||
from onyx.connectors.interfaces import EventConnector
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.jira.connector import JiraConnector
|
||||
from onyx.connectors.linear.connector import LinearConnector
|
||||
from onyx.connectors.loopio.connector import LoopioConnector
|
||||
from onyx.connectors.mediawiki.wiki import MediaWikiConnector
|
||||
from onyx.connectors.mock_connector.connector import MockConnector
|
||||
from onyx.connectors.models import InputType
|
||||
from onyx.connectors.notion.connector import NotionConnector
|
||||
from onyx.connectors.outline.connector import OutlineConnector
|
||||
from onyx.connectors.productboard.connector import ProductboardConnector
|
||||
from onyx.connectors.salesforce.connector import SalesforceConnector
|
||||
from onyx.connectors.sharepoint.connector import SharepointConnector
|
||||
from onyx.connectors.slab.connector import SlabConnector
|
||||
from onyx.connectors.slack.connector import SlackConnector
|
||||
from onyx.connectors.teams.connector import TeamsConnector
|
||||
from onyx.connectors.web.connector import WebConnector
|
||||
from onyx.connectors.wikipedia.connector import WikipediaConnector
|
||||
from onyx.connectors.xenforo.connector import XenforoConnector
|
||||
from onyx.connectors.zendesk.connector import ZendeskConnector
|
||||
from onyx.connectors.zulip.connector import ZulipConnector
|
||||
from onyx.connectors.registry import CONNECTOR_CLASS_MAP
|
||||
from onyx.db.connector import fetch_connector_by_id
|
||||
from onyx.db.credentials import backend_update_credential_json
|
||||
from onyx.db.credentials import fetch_credential_by_id
|
||||
@@ -72,101 +29,75 @@ class ConnectorMissingException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
# Cache for already imported connector classes
|
||||
_connector_cache: dict[DocumentSource, Type[BaseConnector]] = {}
|
||||
|
||||
|
||||
def _load_connector_class(source: DocumentSource) -> Type[BaseConnector]:
|
||||
"""Dynamically load and cache a connector class."""
|
||||
if source in _connector_cache:
|
||||
return _connector_cache[source]
|
||||
|
||||
if source not in CONNECTOR_CLASS_MAP:
|
||||
raise ConnectorMissingException(f"Connector not found for source={source}")
|
||||
|
||||
mapping = CONNECTOR_CLASS_MAP[source]
|
||||
|
||||
try:
|
||||
module = importlib.import_module(mapping.module_path)
|
||||
connector_class = getattr(module, mapping.class_name)
|
||||
_connector_cache[source] = connector_class
|
||||
return connector_class
|
||||
except (ImportError, AttributeError) as e:
|
||||
raise ConnectorMissingException(
|
||||
f"Failed to import {mapping.class_name} from {mapping.module_path}: {e}"
|
||||
)
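The registry entries consumed by _load_connector_class are assumed to look roughly like this; the field names are inferred from mapping.module_path / mapping.class_name above, and the actual onyx.connectors.registry definitions may differ:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ConnectorMapping:
    module_path: str
    class_name: str


CONNECTOR_CLASS_MAP = {
    # e.g. DocumentSource.SLACK: ConnectorMapping("onyx.connectors.slack.connector", "SlackConnector"),
}
```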
|
||||
|
||||
|
||||
def _validate_connector_supports_input_type(
|
||||
connector: Type[BaseConnector],
|
||||
input_type: InputType | None,
|
||||
source: DocumentSource,
|
||||
) -> None:
|
||||
"""Validate that a connector supports the requested input type."""
|
||||
if input_type is None:
|
||||
return
|
||||
|
||||
# Check each input type requirement separately for clarity
|
||||
load_state_unsupported = input_type == InputType.LOAD_STATE and not issubclass(
|
||||
connector, LoadConnector
|
||||
)
|
||||
|
||||
poll_unsupported = (
|
||||
input_type == InputType.POLL
|
||||
# Either poll or checkpoint works for this, in the future
|
||||
# all connectors should be checkpoint connectors
|
||||
and (
|
||||
not issubclass(connector, PollConnector)
|
||||
and not issubclass(connector, CheckpointedConnector)
|
||||
)
|
||||
)
|
||||
|
||||
event_unsupported = input_type == InputType.EVENT and not issubclass(
|
||||
connector, EventConnector
|
||||
)
|
||||
|
||||
if any([load_state_unsupported, poll_unsupported, event_unsupported]):
|
||||
raise ConnectorMissingException(
|
||||
f"Connector for source={source} does not accept input_type={input_type}"
|
||||
)
|
||||
|
||||
|
||||
def identify_connector_class(
|
||||
source: DocumentSource,
|
||||
input_type: InputType | None = None,
|
||||
) -> Type[BaseConnector]:
|
||||
connector_map = {
|
||||
DocumentSource.WEB: WebConnector,
|
||||
DocumentSource.FILE: LocalFileConnector,
|
||||
DocumentSource.SLACK: {
|
||||
InputType.POLL: SlackConnector,
|
||||
InputType.SLIM_RETRIEVAL: SlackConnector,
|
||||
},
|
||||
DocumentSource.GITHUB: GithubConnector,
|
||||
DocumentSource.GMAIL: GmailConnector,
|
||||
DocumentSource.GITLAB: GitlabConnector,
|
||||
DocumentSource.GITBOOK: GitbookConnector,
|
||||
DocumentSource.GOOGLE_DRIVE: GoogleDriveConnector,
|
||||
DocumentSource.BOOKSTACK: BookstackConnector,
|
||||
DocumentSource.OUTLINE: OutlineConnector,
|
||||
DocumentSource.CONFLUENCE: ConfluenceConnector,
|
||||
DocumentSource.JIRA: JiraConnector,
|
||||
DocumentSource.PRODUCTBOARD: ProductboardConnector,
|
||||
DocumentSource.SLAB: SlabConnector,
|
||||
DocumentSource.NOTION: NotionConnector,
|
||||
DocumentSource.ZULIP: ZulipConnector,
|
||||
DocumentSource.GURU: GuruConnector,
|
||||
DocumentSource.LINEAR: LinearConnector,
|
||||
DocumentSource.HUBSPOT: HubSpotConnector,
|
||||
DocumentSource.DOCUMENT360: Document360Connector,
|
||||
DocumentSource.GONG: GongConnector,
|
||||
DocumentSource.GOOGLE_SITES: GoogleSitesConnector,
|
||||
DocumentSource.ZENDESK: ZendeskConnector,
|
||||
DocumentSource.LOOPIO: LoopioConnector,
|
||||
DocumentSource.DROPBOX: DropboxConnector,
|
||||
DocumentSource.SHAREPOINT: SharepointConnector,
|
||||
DocumentSource.TEAMS: TeamsConnector,
|
||||
DocumentSource.SALESFORCE: SalesforceConnector,
|
||||
DocumentSource.DISCOURSE: DiscourseConnector,
|
||||
DocumentSource.AXERO: AxeroConnector,
|
||||
DocumentSource.CLICKUP: ClickupConnector,
|
||||
DocumentSource.MEDIAWIKI: MediaWikiConnector,
|
||||
DocumentSource.WIKIPEDIA: WikipediaConnector,
|
||||
DocumentSource.ASANA: AsanaConnector,
|
||||
DocumentSource.S3: BlobStorageConnector,
|
||||
DocumentSource.R2: BlobStorageConnector,
|
||||
DocumentSource.GOOGLE_CLOUD_STORAGE: BlobStorageConnector,
|
||||
DocumentSource.OCI_STORAGE: BlobStorageConnector,
|
||||
DocumentSource.XENFORO: XenforoConnector,
|
||||
DocumentSource.DISCORD: DiscordConnector,
|
||||
DocumentSource.FRESHDESK: FreshdeskConnector,
|
||||
DocumentSource.FIREFLIES: FirefliesConnector,
|
||||
DocumentSource.EGNYTE: EgnyteConnector,
|
||||
DocumentSource.AIRTABLE: AirtableConnector,
|
||||
DocumentSource.HIGHSPOT: HighspotConnector,
|
||||
DocumentSource.IMAP: ImapConnector,
|
||||
DocumentSource.BITBUCKET: BitbucketConnector,
|
||||
# just for integration tests
|
||||
DocumentSource.MOCK_CONNECTOR: MockConnector,
|
||||
}
|
||||
connector_by_source = connector_map.get(source, {})
|
||||
# Load the connector class using lazy loading
|
||||
connector = _load_connector_class(source)
|
||||
|
||||
if isinstance(connector_by_source, dict):
|
||||
if input_type is None:
|
||||
# If not specified, default to most exhaustive update
|
||||
connector = connector_by_source.get(InputType.LOAD_STATE)
|
||||
else:
|
||||
connector = connector_by_source.get(input_type)
|
||||
else:
|
||||
connector = connector_by_source
|
||||
if connector is None:
|
||||
raise ConnectorMissingException(f"Connector not found for source={source}")
|
||||
# Validate connector supports the requested input_type
|
||||
_validate_connector_supports_input_type(connector, input_type, source)
|
||||
|
||||
if any(
|
||||
[
|
||||
(
|
||||
input_type == InputType.LOAD_STATE
|
||||
and not issubclass(connector, LoadConnector)
|
||||
),
|
||||
(
|
||||
input_type == InputType.POLL
|
||||
# either poll or checkpoint works for this, in the future
|
||||
# all connectors should be checkpoint connectors
|
||||
and (
|
||||
not issubclass(connector, PollConnector)
|
||||
and not issubclass(connector, CheckpointedConnector)
|
||||
)
|
||||
),
|
||||
(
|
||||
input_type == InputType.EVENT
|
||||
and not issubclass(connector, EventConnector)
|
||||
),
|
||||
]
|
||||
):
|
||||
raise ConnectorMissingException(
|
||||
f"Connector for source={source} does not accept input_type={input_type}"
|
||||
)
|
||||
return connector
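A minimal usage sketch for the lazy-loading path above; the import locations of identify_connector_class and InputType are assumed from context, not confirmed by this diff.
# Sketch only: the module paths for these imports are assumptions, not shown in this diff.
from onyx.configs.constants import DocumentSource
from onyx.connectors.factory import identify_connector_class  # assumed location
from onyx.connectors.models import InputType  # assumed location
# The first call for a source imports its module via importlib and caches the class;
# repeat calls hit _connector_cache and skip the import entirely.
slack_cls = identify_connector_class(DocumentSource.SLACK, InputType.POLL)
web_cls = identify_connector_class(DocumentSource.WEB)  # input_type=None skips the input-type validation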
|
||||
|
||||
|
||||
|
||||
@@ -219,12 +219,19 @@ def _get_batch_rate_limited(
|
||||
|
||||
|
||||
def _get_userinfo(user: NamedUser) -> dict[str, str]:
|
||||
def _safe_get(attr_name: str) -> str | None:
|
||||
try:
|
||||
return cast(str | None, getattr(user, attr_name))
|
||||
except GithubException:
|
||||
logger.debug(f"Error getting {attr_name} for user")
|
||||
return None
|
||||
|
||||
return {
|
||||
k: v
|
||||
for k, v in {
|
||||
"login": user.login,
|
||||
"name": user.name,
|
||||
"email": user.email,
|
||||
"login": _safe_get("login"),
|
||||
"name": _safe_get("name"),
|
||||
"email": _safe_get("email"),
|
||||
}.items()
|
||||
if v is not None
|
||||
}
|
||||
|
||||
@@ -28,7 +28,7 @@ from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import ImageSection
|
||||
@@ -232,7 +232,7 @@ def thread_to_document(
|
||||
)
|
||||
|
||||
|
||||
class GmailConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
class GmailConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
|
||||
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
|
||||
self.batch_size = batch_size
|
||||
|
||||
@@ -397,10 +397,10 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
raise RuntimeError(
|
||||
"retrieve_all_slim_documents: Stop signal detected"
|
||||
"retrieve_all_slim_docs_perm_sync: Stop signal detected"
|
||||
)
|
||||
|
||||
callback.progress("retrieve_all_slim_documents", 1)
|
||||
callback.progress("retrieve_all_slim_docs_perm_sync", 1)
|
||||
except HttpError as e:
|
||||
if _is_mail_service_disabled_error(e):
|
||||
logger.warning(
|
||||
@@ -431,7 +431,7 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
|
||||
raise e
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
|
||||
@@ -64,7 +64,7 @@ from onyx.connectors.interfaces import CheckpointedConnectorWithPermSync
|
||||
from onyx.connectors.interfaces import CheckpointOutput
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
@@ -153,7 +153,7 @@ class DriveIdStatus(Enum):
|
||||
|
||||
|
||||
class GoogleDriveConnector(
|
||||
SlimConnector, CheckpointedConnectorWithPermSync[GoogleDriveCheckpoint]
|
||||
SlimConnectorWithPermSync, CheckpointedConnectorWithPermSync[GoogleDriveCheckpoint]
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -1137,7 +1137,9 @@ class GoogleDriveConnector(
|
||||
convert_func,
|
||||
(
|
||||
[file.user_email, self.primary_admin_email]
|
||||
+ get_file_owners(file.drive_file),
|
||||
+ get_file_owners(
|
||||
file.drive_file, self.primary_admin_email
|
||||
),
|
||||
file.drive_file,
|
||||
),
|
||||
)
|
||||
@@ -1294,7 +1296,7 @@ class GoogleDriveConnector(
|
||||
callback.progress("_extract_slim_docs_from_google_drive", 1)
|
||||
yield slim_batch
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
|
||||
@@ -97,14 +97,15 @@ def _execute_with_retry(request: Any) -> Any:
|
||||
raise Exception(f"Failed to execute request after {max_attempts} attempts")
|
||||
|
||||
|
||||
def get_file_owners(file: GoogleDriveFileType) -> list[str]:
|
||||
def get_file_owners(file: GoogleDriveFileType, primary_admin_email: str) -> list[str]:
|
||||
"""
|
||||
Get the owners of a file if the attribute is present.
|
||||
"""
|
||||
return [
|
||||
owner.get("emailAddress")
|
||||
email
|
||||
for owner in file.get("owners", [])
|
||||
if owner.get("emailAddress")
|
||||
if (email := owner.get("emailAddress"))
|
||||
and email.split("@")[-1] == primary_admin_email.split("@")[-1]
|
||||
]
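A small, self-contained illustration of the tightened owner filter above; the file dict and emails are made up.
# Mirrors the list comprehension above with made-up data.
drive_file = {
    "owners": [
        {"emailAddress": "alice@acme.com"},   # same domain as the admin -> kept
        {"emailAddress": "bob@other.org"},    # external domain -> dropped
        {"displayName": "No email field"},    # missing emailAddress -> dropped
    ]
}
primary_admin_email = "admin@acme.com"
owners = [
    email
    for owner in drive_file.get("owners", [])
    if (email := owner.get("emailAddress"))
    and email.split("@")[-1] == primary_admin_email.split("@")[-1]
]
assert owners == ["alice@acme.com"]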
|
||||
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import SlimDocument
|
||||
@@ -38,7 +38,7 @@ class HighspotSpot(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
class HighspotConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
class HighspotConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
|
||||
"""
|
||||
Connector for loading data from Highspot.
|
||||
|
||||
@@ -362,7 +362,7 @@ class HighspotConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
description = item_details.get("description", "")
|
||||
return title, description
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
|
||||
@@ -1,14 +1,18 @@
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import TypeVar
|
||||
|
||||
import requests
|
||||
from hubspot import HubSpot # type: ignore
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.hubspot.rate_limit import HubSpotRateLimiter
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
@@ -25,6 +29,10 @@ HUBSPOT_API_URL = "https://api.hubapi.com/integrations/v1/me"
|
||||
# Available HubSpot object types
|
||||
AVAILABLE_OBJECT_TYPES = {"tickets", "companies", "deals", "contacts"}
|
||||
|
||||
HUBSPOT_PAGE_SIZE = 100
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@@ -38,6 +46,7 @@ class HubSpotConnector(LoadConnector, PollConnector):
|
||||
self.batch_size = batch_size
|
||||
self._access_token = access_token
|
||||
self._portal_id: str | None = None
|
||||
self._rate_limiter = HubSpotRateLimiter()
|
||||
|
||||
# Set object types to fetch, default to all available types
|
||||
if object_types is None:
|
||||
@@ -77,6 +86,37 @@ class HubSpotConnector(LoadConnector, PollConnector):
|
||||
"""Set the portal ID."""
|
||||
self._portal_id = value
|
||||
|
||||
def _call_hubspot(self, func: Callable[..., T], *args: Any, **kwargs: Any) -> T:
|
||||
return self._rate_limiter.call(func, *args, **kwargs)
|
||||
|
||||
def _paginated_results(
|
||||
self,
|
||||
fetch_page: Callable[..., Any],
|
||||
**kwargs: Any,
|
||||
) -> Generator[Any, None, None]:
|
||||
base_kwargs = dict(kwargs)
|
||||
base_kwargs.setdefault("limit", HUBSPOT_PAGE_SIZE)
|
||||
|
||||
after: str | None = None
|
||||
while True:
|
||||
page_kwargs = base_kwargs.copy()
|
||||
if after is not None:
|
||||
page_kwargs["after"] = after
|
||||
|
||||
page = self._call_hubspot(fetch_page, **page_kwargs)
|
||||
results = getattr(page, "results", [])
|
||||
for result in results:
|
||||
yield result
|
||||
|
||||
paging = getattr(page, "paging", None)
|
||||
next_page = getattr(paging, "next", None) if paging else None
|
||||
if next_page is None:
|
||||
break
|
||||
|
||||
after = getattr(next_page, "after", None)
|
||||
if after is None:
|
||||
break
|
||||
|
||||
def _clean_html_content(self, html_content: str) -> str:
|
||||
"""Clean HTML content and extract raw text"""
|
||||
if not html_content:
|
||||
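A self-contained sketch of the paging contract _paginated_results relies on: each page exposes .results and a .paging.next.after cursor (attribute names follow the HubSpot SDK; the fake pages below are illustrative only).
# Fake two-page sequence mimicking the SDK's page objects.
from types import SimpleNamespace

pages = {
    None: SimpleNamespace(results=["t1", "t2"], paging=SimpleNamespace(next=SimpleNamespace(after="2"))),
    "2": SimpleNamespace(results=["t3"], paging=None),
}

def fake_get_page(limit: int = 100, after: str | None = None) -> SimpleNamespace:
    return pages[after]

# The same walk _paginated_results performs: follow `after` until no next cursor remains.
collected, after = [], None
while True:
    page = fake_get_page(after=after) if after is not None else fake_get_page()
    collected.extend(page.results)
    next_page = getattr(page.paging, "next", None) if page.paging else None
    if next_page is None or getattr(next_page, "after", None) is None:
        break
    after = next_page.after
assert collected == ["t1", "t2", "t3"]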
@@ -150,78 +190,82 @@ class HubSpotConnector(LoadConnector, PollConnector):
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Get associated objects for a given object"""
|
||||
try:
|
||||
associations = api_client.crm.associations.v4.basic_api.get_page(
|
||||
associations_iter = self._paginated_results(
|
||||
api_client.crm.associations.v4.basic_api.get_page,
|
||||
object_type=from_object_type,
|
||||
object_id=object_id,
|
||||
to_object_type=to_object_type,
|
||||
)
|
||||
|
||||
associated_objects = []
|
||||
if associations.results:
|
||||
object_ids = [assoc.to_object_id for assoc in associations.results]
|
||||
object_ids = [assoc.to_object_id for assoc in associations_iter]
|
||||
|
||||
# Batch get the associated objects
|
||||
if to_object_type == "contacts":
|
||||
for obj_id in object_ids:
|
||||
try:
|
||||
obj = api_client.crm.contacts.basic_api.get_by_id(
|
||||
contact_id=obj_id,
|
||||
properties=[
|
||||
"firstname",
|
||||
"lastname",
|
||||
"email",
|
||||
"company",
|
||||
"jobtitle",
|
||||
],
|
||||
)
|
||||
associated_objects.append(obj.to_dict())
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch contact {obj_id}: {e}")
|
||||
associated_objects: list[dict[str, Any]] = []
|
||||
|
||||
elif to_object_type == "companies":
|
||||
for obj_id in object_ids:
|
||||
try:
|
||||
obj = api_client.crm.companies.basic_api.get_by_id(
|
||||
company_id=obj_id,
|
||||
properties=[
|
||||
"name",
|
||||
"domain",
|
||||
"industry",
|
||||
"city",
|
||||
"state",
|
||||
],
|
||||
)
|
||||
associated_objects.append(obj.to_dict())
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch company {obj_id}: {e}")
|
||||
if to_object_type == "contacts":
|
||||
for obj_id in object_ids:
|
||||
try:
|
||||
obj = self._call_hubspot(
|
||||
api_client.crm.contacts.basic_api.get_by_id,
|
||||
contact_id=obj_id,
|
||||
properties=[
|
||||
"firstname",
|
||||
"lastname",
|
||||
"email",
|
||||
"company",
|
||||
"jobtitle",
|
||||
],
|
||||
)
|
||||
associated_objects.append(obj.to_dict())
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch contact {obj_id}: {e}")
|
||||
|
||||
elif to_object_type == "deals":
|
||||
for obj_id in object_ids:
|
||||
try:
|
||||
obj = api_client.crm.deals.basic_api.get_by_id(
|
||||
deal_id=obj_id,
|
||||
properties=[
|
||||
"dealname",
|
||||
"amount",
|
||||
"dealstage",
|
||||
"closedate",
|
||||
"pipeline",
|
||||
],
|
||||
)
|
||||
associated_objects.append(obj.to_dict())
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch deal {obj_id}: {e}")
|
||||
elif to_object_type == "companies":
|
||||
for obj_id in object_ids:
|
||||
try:
|
||||
obj = self._call_hubspot(
|
||||
api_client.crm.companies.basic_api.get_by_id,
|
||||
company_id=obj_id,
|
||||
properties=[
|
||||
"name",
|
||||
"domain",
|
||||
"industry",
|
||||
"city",
|
||||
"state",
|
||||
],
|
||||
)
|
||||
associated_objects.append(obj.to_dict())
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch company {obj_id}: {e}")
|
||||
|
||||
elif to_object_type == "tickets":
|
||||
for obj_id in object_ids:
|
||||
try:
|
||||
obj = api_client.crm.tickets.basic_api.get_by_id(
|
||||
ticket_id=obj_id,
|
||||
properties=["subject", "content", "hs_ticket_priority"],
|
||||
)
|
||||
associated_objects.append(obj.to_dict())
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch ticket {obj_id}: {e}")
|
||||
elif to_object_type == "deals":
|
||||
for obj_id in object_ids:
|
||||
try:
|
||||
obj = self._call_hubspot(
|
||||
api_client.crm.deals.basic_api.get_by_id,
|
||||
deal_id=obj_id,
|
||||
properties=[
|
||||
"dealname",
|
||||
"amount",
|
||||
"dealstage",
|
||||
"closedate",
|
||||
"pipeline",
|
||||
],
|
||||
)
|
||||
associated_objects.append(obj.to_dict())
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch deal {obj_id}: {e}")
|
||||
|
||||
elif to_object_type == "tickets":
|
||||
for obj_id in object_ids:
|
||||
try:
|
||||
obj = self._call_hubspot(
|
||||
api_client.crm.tickets.basic_api.get_by_id,
|
||||
ticket_id=obj_id,
|
||||
properties=["subject", "content", "hs_ticket_priority"],
|
||||
)
|
||||
associated_objects.append(obj.to_dict())
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch ticket {obj_id}: {e}")
|
||||
|
||||
return associated_objects
|
||||
|
||||
@@ -239,33 +283,33 @@ class HubSpotConnector(LoadConnector, PollConnector):
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Get notes associated with a given object"""
|
||||
try:
|
||||
# Get associations to notes (engagement type)
|
||||
associations = api_client.crm.associations.v4.basic_api.get_page(
|
||||
associations_iter = self._paginated_results(
|
||||
api_client.crm.associations.v4.basic_api.get_page,
|
||||
object_type=object_type,
|
||||
object_id=object_id,
|
||||
to_object_type="notes",
|
||||
)
|
||||
|
||||
associated_notes = []
|
||||
if associations.results:
|
||||
note_ids = [assoc.to_object_id for assoc in associations.results]
|
||||
note_ids = [assoc.to_object_id for assoc in associations_iter]
|
||||
|
||||
# Batch get the associated notes
|
||||
for note_id in note_ids:
|
||||
try:
|
||||
# Notes are engagements in HubSpot, use the engagements API
|
||||
note = api_client.crm.objects.notes.basic_api.get_by_id(
|
||||
note_id=note_id,
|
||||
properties=[
|
||||
"hs_note_body",
|
||||
"hs_timestamp",
|
||||
"hs_created_by",
|
||||
"hubspot_owner_id",
|
||||
],
|
||||
)
|
||||
associated_notes.append(note.to_dict())
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch note {note_id}: {e}")
|
||||
associated_notes = []
|
||||
|
||||
for note_id in note_ids:
|
||||
try:
|
||||
# Notes are engagements in HubSpot, use the engagements API
|
||||
note = self._call_hubspot(
|
||||
api_client.crm.objects.notes.basic_api.get_by_id,
|
||||
note_id=note_id,
|
||||
properties=[
|
||||
"hs_note_body",
|
||||
"hs_timestamp",
|
||||
"hs_created_by",
|
||||
"hubspot_owner_id",
|
||||
],
|
||||
)
|
||||
associated_notes.append(note.to_dict())
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch note {note_id}: {e}")
|
||||
|
||||
return associated_notes
|
||||
|
||||
@@ -358,7 +402,9 @@ class HubSpotConnector(LoadConnector, PollConnector):
|
||||
self, start: datetime | None = None, end: datetime | None = None
|
||||
) -> GenerateDocumentsOutput:
|
||||
api_client = HubSpot(access_token=self.access_token)
|
||||
all_tickets = api_client.crm.tickets.get_all(
|
||||
|
||||
tickets_iter = self._paginated_results(
|
||||
api_client.crm.tickets.basic_api.get_page,
|
||||
properties=[
|
||||
"subject",
|
||||
"content",
|
||||
@@ -371,7 +417,7 @@ class HubSpotConnector(LoadConnector, PollConnector):
|
||||
|
||||
doc_batch: list[Document] = []
|
||||
|
||||
for ticket in all_tickets:
|
||||
for ticket in tickets_iter:
|
||||
updated_at = ticket.updated_at.replace(tzinfo=None)
|
||||
if start is not None and updated_at < start.replace(tzinfo=None):
|
||||
continue
|
||||
@@ -459,7 +505,9 @@ class HubSpotConnector(LoadConnector, PollConnector):
|
||||
self, start: datetime | None = None, end: datetime | None = None
|
||||
) -> GenerateDocumentsOutput:
|
||||
api_client = HubSpot(access_token=self.access_token)
|
||||
all_companies = api_client.crm.companies.get_all(
|
||||
|
||||
companies_iter = self._paginated_results(
|
||||
api_client.crm.companies.basic_api.get_page,
|
||||
properties=[
|
||||
"name",
|
||||
"domain",
|
||||
@@ -475,7 +523,7 @@ class HubSpotConnector(LoadConnector, PollConnector):
|
||||
|
||||
doc_batch: list[Document] = []
|
||||
|
||||
for company in all_companies:
|
||||
for company in companies_iter:
|
||||
updated_at = company.updated_at.replace(tzinfo=None)
|
||||
if start is not None and updated_at < start.replace(tzinfo=None):
|
||||
continue
|
||||
@@ -582,7 +630,9 @@ class HubSpotConnector(LoadConnector, PollConnector):
|
||||
self, start: datetime | None = None, end: datetime | None = None
|
||||
) -> GenerateDocumentsOutput:
|
||||
api_client = HubSpot(access_token=self.access_token)
|
||||
all_deals = api_client.crm.deals.get_all(
|
||||
|
||||
deals_iter = self._paginated_results(
|
||||
api_client.crm.deals.basic_api.get_page,
|
||||
properties=[
|
||||
"dealname",
|
||||
"amount",
|
||||
@@ -598,7 +648,7 @@ class HubSpotConnector(LoadConnector, PollConnector):
|
||||
|
||||
doc_batch: list[Document] = []
|
||||
|
||||
for deal in all_deals:
|
||||
for deal in deals_iter:
|
||||
updated_at = deal.updated_at.replace(tzinfo=None)
|
||||
if start is not None and updated_at < start.replace(tzinfo=None):
|
||||
continue
|
||||
@@ -703,7 +753,9 @@ class HubSpotConnector(LoadConnector, PollConnector):
|
||||
self, start: datetime | None = None, end: datetime | None = None
|
||||
) -> GenerateDocumentsOutput:
|
||||
api_client = HubSpot(access_token=self.access_token)
|
||||
all_contacts = api_client.crm.contacts.get_all(
|
||||
|
||||
contacts_iter = self._paginated_results(
|
||||
api_client.crm.contacts.basic_api.get_page,
|
||||
properties=[
|
||||
"firstname",
|
||||
"lastname",
|
||||
@@ -721,7 +773,7 @@ class HubSpotConnector(LoadConnector, PollConnector):
|
||||
|
||||
doc_batch: list[Document] = []
|
||||
|
||||
for contact in all_contacts:
|
||||
for contact in contacts_iter:
|
||||
updated_at = contact.updated_at.replace(tzinfo=None)
|
||||
if start is not None and updated_at < start.replace(tzinfo=None):
|
||||
continue
|
||||
|
||||
145
backend/onyx/connectors/hubspot/rate_limit.py
Normal file
@@ -0,0 +1,145 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
from typing import TypeVar
|
||||
|
||||
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||
rate_limit_builder,
|
||||
)
|
||||
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||
RateLimitTriedTooManyTimesError,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
# HubSpot exposes a ten second rolling window (x-hubspot-ratelimit-interval-milliseconds)
|
||||
# with a maximum of 190 requests, and a per-second limit of 19 requests.
|
||||
_HUBSPOT_TEN_SECOND_LIMIT = 190
|
||||
_HUBSPOT_TEN_SECOND_PERIOD = 10 # seconds
|
||||
_HUBSPOT_SECONDLY_LIMIT = 19
|
||||
_HUBSPOT_SECONDLY_PERIOD = 1 # second
|
||||
_DEFAULT_SLEEP_SECONDS = 10
|
||||
_SLEEP_PADDING_SECONDS = 1.0
|
||||
_MAX_RATE_LIMIT_RETRIES = 5
|
||||
|
||||
|
||||
def _extract_header(headers: Any, key: str) -> str | None:
|
||||
if headers is None:
|
||||
return None
|
||||
|
||||
getter = getattr(headers, "get", None)
|
||||
if callable(getter):
|
||||
value = getter(key)
|
||||
if value is not None:
|
||||
return value
|
||||
|
||||
if isinstance(headers, dict):
|
||||
value = headers.get(key)
|
||||
if value is not None:
|
||||
return value
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def is_rate_limit_error(exception: Exception) -> bool:
|
||||
status = getattr(exception, "status", None)
|
||||
if status == 429:
|
||||
return True
|
||||
|
||||
headers = getattr(exception, "headers", None)
|
||||
if headers is not None:
|
||||
remaining = _extract_header(headers, "x-hubspot-ratelimit-remaining")
|
||||
if remaining == "0":
|
||||
return True
|
||||
secondly_remaining = _extract_header(
|
||||
headers, "x-hubspot-ratelimit-secondly-remaining"
|
||||
)
|
||||
if secondly_remaining == "0":
|
||||
return True
|
||||
|
||||
message = str(exception)
|
||||
return "RATE_LIMIT" in message or "Too Many Requests" in message
|
||||
|
||||
|
||||
def get_rate_limit_retry_delay_seconds(exception: Exception) -> float:
|
||||
headers = getattr(exception, "headers", None)
|
||||
|
||||
retry_after = _extract_header(headers, "Retry-After")
|
||||
if retry_after:
|
||||
try:
|
||||
return float(retry_after) + _SLEEP_PADDING_SECONDS
|
||||
except ValueError:
|
||||
logger.debug(
|
||||
"Failed to parse Retry-After header '%s' as float", retry_after
|
||||
)
|
||||
|
||||
interval_ms = _extract_header(headers, "x-hubspot-ratelimit-interval-milliseconds")
|
||||
if interval_ms:
|
||||
try:
|
||||
return float(interval_ms) / 1000.0 + _SLEEP_PADDING_SECONDS
|
||||
except ValueError:
|
||||
logger.debug(
|
||||
"Failed to parse x-hubspot-ratelimit-interval-milliseconds '%s' as float",
|
||||
interval_ms,
|
||||
)
|
||||
|
||||
secondly_limit = _extract_header(headers, "x-hubspot-ratelimit-secondly")
|
||||
if secondly_limit:
|
||||
try:
|
||||
per_second = max(float(secondly_limit), 1.0)
|
||||
return (1.0 / per_second) + _SLEEP_PADDING_SECONDS
|
||||
except ValueError:
|
||||
logger.debug(
|
||||
"Failed to parse x-hubspot-ratelimit-secondly '%s' as float",
|
||||
secondly_limit,
|
||||
)
|
||||
|
||||
return _DEFAULT_SLEEP_SECONDS + _SLEEP_PADDING_SECONDS
|
||||
|
||||
|
||||
class HubSpotRateLimiter:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ten_second_limit: int = _HUBSPOT_TEN_SECOND_LIMIT,
|
||||
ten_second_period: int = _HUBSPOT_TEN_SECOND_PERIOD,
|
||||
secondly_limit: int = _HUBSPOT_SECONDLY_LIMIT,
|
||||
secondly_period: int = _HUBSPOT_SECONDLY_PERIOD,
|
||||
max_retries: int = _MAX_RATE_LIMIT_RETRIES,
|
||||
) -> None:
|
||||
self._max_retries = max_retries
|
||||
|
||||
@rate_limit_builder(max_calls=secondly_limit, period=secondly_period)
|
||||
@rate_limit_builder(max_calls=ten_second_limit, period=ten_second_period)
|
||||
def _execute(callable_: Callable[[], T]) -> T:
|
||||
return callable_()
|
||||
|
||||
self._execute = _execute
|
||||
|
||||
def call(self, func: Callable[..., T], *args: Any, **kwargs: Any) -> T:
|
||||
attempts = 0
|
||||
|
||||
while True:
|
||||
try:
|
||||
return self._execute(lambda: func(*args, **kwargs))
|
||||
except Exception as exc: # pylint: disable=broad-except
|
||||
if not is_rate_limit_error(exc):
|
||||
raise
|
||||
|
||||
attempts += 1
|
||||
if attempts > self._max_retries:
|
||||
raise RateLimitTriedTooManyTimesError(
|
||||
"Exceeded configured HubSpot rate limit retries"
|
||||
) from exc
|
||||
|
||||
wait_time = get_rate_limit_retry_delay_seconds(exc)
|
||||
logger.notice(
|
||||
"HubSpot rate limit reached. Sleeping %.2f seconds before retrying.",
|
||||
wait_time,
|
||||
)
|
||||
time.sleep(wait_time)
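A minimal usage sketch for the limiter above; the wrapped function below is a placeholder, not a real SDK call.
# Sketch: route every HubSpot SDK call through the limiter so the 19 req/s and
# 190 req/10s budgets are enforced and 429 responses are retried after a delay.
limiter = HubSpotRateLimiter()

def fetch_ticket(ticket_id: str) -> dict:
    # Placeholder for api_client.crm.tickets.basic_api.get_by_id(ticket_id=...).
    return {"id": ticket_id}

ticket = limiter.call(fetch_ticket, "123")
assert ticket == {"id": "123"}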
|
||||
@@ -97,11 +97,20 @@ class PollConnector(BaseConnector):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
# Slim connectors can retrieve just the ids and
|
||||
# permission syncing information for connected documents
|
||||
# Slim connectors retrieve just the ids of documents
|
||||
class SlimConnector(BaseConnector):
|
||||
@abc.abstractmethod
|
||||
def retrieve_all_slim_documents(
|
||||
def retrieve_all_slim_docs(
|
||||
self,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
# Slim connectors retrieve both the ids AND
|
||||
# permission syncing information for connected documents
|
||||
class SlimConnectorWithPermSync(BaseConnector):
|
||||
@abc.abstractmethod
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
|
||||
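A hedged sketch of a connector filling the new permission-sync interface; the abstract signature is truncated in the hunk above, so any extra parameters (e.g. a callback) are omitted, and the SlimDocument fields follow their usage elsewhere in this diff.
# Sketch only: assumes SlimDocument(id=..., external_access=...) as used in the Jira changes below.
from collections.abc import Iterator

from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.models import SlimDocument

class ExampleSlimConnector(SlimConnectorWithPermSync):
    def retrieve_all_slim_docs_perm_sync(
        self,
        start=None,
        end=None,
    ) -> Iterator[list[SlimDocument]]:
        # Yield document ids together with their permission info, one batch at a time.
        yield [SlimDocument(id="example-doc-1", external_access=None)]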
@@ -25,11 +25,11 @@ from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import CredentialExpiredError
|
||||
from onyx.connectors.exceptions import InsufficientPermissionsError
|
||||
from onyx.connectors.exceptions import UnexpectedValidationError
|
||||
from onyx.connectors.interfaces import CheckpointedConnector
|
||||
from onyx.connectors.interfaces import CheckpointedConnectorWithPermSync
|
||||
from onyx.connectors.interfaces import CheckpointOutput
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.jira.access import get_project_permissions
|
||||
from onyx.connectors.jira.utils import best_effort_basic_expert_info
|
||||
from onyx.connectors.jira.utils import best_effort_get_field_from_issue
|
||||
@@ -247,7 +247,7 @@ def _perform_jql_search_v2(
|
||||
|
||||
|
||||
def process_jira_issue(
|
||||
jira_client: JIRA,
|
||||
jira_base_url: str,
|
||||
issue: Issue,
|
||||
comment_email_blacklist: tuple[str, ...] = (),
|
||||
labels_to_skip: set[str] | None = None,
|
||||
@@ -281,7 +281,7 @@ def process_jira_issue(
|
||||
)
|
||||
return None
|
||||
|
||||
page_url = build_jira_url(jira_client, issue.key)
|
||||
page_url = build_jira_url(jira_base_url, issue.key)
|
||||
|
||||
metadata_dict: dict[str, str | list[str]] = {}
|
||||
people = set()
|
||||
@@ -359,7 +359,10 @@ class JiraConnectorCheckpoint(ConnectorCheckpoint):
|
||||
offset: int | None = None
|
||||
|
||||
|
||||
class JiraConnector(CheckpointedConnector[JiraConnectorCheckpoint], SlimConnector):
|
||||
class JiraConnector(
|
||||
CheckpointedConnectorWithPermSync[JiraConnectorCheckpoint],
|
||||
SlimConnectorWithPermSync,
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
jira_base_url: str,
|
||||
@@ -372,15 +375,23 @@ class JiraConnector(CheckpointedConnector[JiraConnectorCheckpoint], SlimConnecto
|
||||
labels_to_skip: list[str] = JIRA_CONNECTOR_LABELS_TO_SKIP,
|
||||
# Custom JQL query to filter Jira issues
|
||||
jql_query: str | None = None,
|
||||
scoped_token: bool = False,
|
||||
) -> None:
|
||||
self.batch_size = batch_size
|
||||
|
||||
# dealing with scoped tokens is a bit tricky because we need to hit api.atlassian.net
|
||||
# when making jira requests but still want correct links to issues in the UI.
|
||||
# So, the user's base url is stored here, but converted to a scoped url when passed
|
||||
# to the jira client.
|
||||
self.jira_base = jira_base_url.rstrip("/") # Remove trailing slash if present
|
||||
self.jira_project = project_key
|
||||
self._comment_email_blacklist = comment_email_blacklist or []
|
||||
self.labels_to_skip = set(labels_to_skip)
|
||||
self.jql_query = jql_query
|
||||
|
||||
self.scoped_token = scoped_token
|
||||
self._jira_client: JIRA | None = None
|
||||
# Cache project permissions to avoid fetching them repeatedly across runs
|
||||
self._project_permissions_cache: dict[str, Any] = {}
|
||||
|
||||
@property
|
||||
def comment_email_blacklist(self) -> tuple:
|
||||
@@ -399,10 +410,26 @@ class JiraConnector(CheckpointedConnector[JiraConnectorCheckpoint], SlimConnecto
|
||||
return ""
|
||||
return f'"{self.jira_project}"'
|
||||
|
||||
def _get_project_permissions(self, project_key: str) -> Any:
|
||||
"""Get project permissions with caching.
|
||||
|
||||
Args:
|
||||
project_key: The Jira project key
|
||||
|
||||
Returns:
|
||||
The external access permissions for the project
|
||||
"""
|
||||
if project_key not in self._project_permissions_cache:
|
||||
self._project_permissions_cache[project_key] = get_project_permissions(
|
||||
jira_client=self.jira_client, jira_project=project_key
|
||||
)
|
||||
return self._project_permissions_cache[project_key]
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
self._jira_client = build_jira_client(
|
||||
credentials=credentials,
|
||||
jira_base=self.jira_base,
|
||||
scoped_token=self.scoped_token,
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -442,15 +469,37 @@ class JiraConnector(CheckpointedConnector[JiraConnectorCheckpoint], SlimConnecto
|
||||
) -> CheckpointOutput[JiraConnectorCheckpoint]:
|
||||
jql = self._get_jql_query(start, end)
|
||||
try:
|
||||
return self._load_from_checkpoint(jql, checkpoint)
|
||||
return self._load_from_checkpoint(
|
||||
jql, checkpoint, include_permissions=False
|
||||
)
|
||||
except Exception as e:
|
||||
if is_atlassian_date_error(e):
|
||||
jql = self._get_jql_query(start - ONE_HOUR, end)
|
||||
return self._load_from_checkpoint(jql, checkpoint)
|
||||
return self._load_from_checkpoint(
|
||||
jql, checkpoint, include_permissions=False
|
||||
)
|
||||
raise e
|
||||
|
||||
def load_from_checkpoint_with_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: JiraConnectorCheckpoint,
|
||||
) -> CheckpointOutput[JiraConnectorCheckpoint]:
|
||||
"""Load documents from checkpoint with permission information included."""
|
||||
jql = self._get_jql_query(start, end)
|
||||
try:
|
||||
return self._load_from_checkpoint(jql, checkpoint, include_permissions=True)
|
||||
except Exception as e:
|
||||
if is_atlassian_date_error(e):
|
||||
jql = self._get_jql_query(start - ONE_HOUR, end)
|
||||
return self._load_from_checkpoint(
|
||||
jql, checkpoint, include_permissions=True
|
||||
)
|
||||
raise e
|
||||
|
||||
def _load_from_checkpoint(
|
||||
self, jql: str, checkpoint: JiraConnectorCheckpoint
|
||||
self, jql: str, checkpoint: JiraConnectorCheckpoint, include_permissions: bool
|
||||
) -> CheckpointOutput[JiraConnectorCheckpoint]:
|
||||
# Get the current offset from checkpoint or start at 0
|
||||
starting_offset = checkpoint.offset or 0
|
||||
@@ -472,18 +521,25 @@ class JiraConnector(CheckpointedConnector[JiraConnectorCheckpoint], SlimConnecto
|
||||
issue_key = issue.key
|
||||
try:
|
||||
if document := process_jira_issue(
|
||||
jira_client=self.jira_client,
|
||||
jira_base_url=self.jira_base,
|
||||
issue=issue,
|
||||
comment_email_blacklist=self.comment_email_blacklist,
|
||||
labels_to_skip=self.labels_to_skip,
|
||||
):
|
||||
# Add permission information to the document if requested
|
||||
if include_permissions:
|
||||
project_key = get_jira_project_key_from_issue(issue=issue)
|
||||
if project_key:
|
||||
document.external_access = self._get_project_permissions(
|
||||
project_key
|
||||
)
|
||||
yield document
|
||||
|
||||
except Exception as e:
|
||||
yield ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=issue_key,
|
||||
document_link=build_jira_url(self.jira_client, issue_key),
|
||||
document_link=build_jira_url(self.jira_base, issue_key),
|
||||
),
|
||||
failure_message=f"Failed to process Jira issue: {str(e)}",
|
||||
exception=e,
|
||||
@@ -515,7 +571,7 @@ class JiraConnector(CheckpointedConnector[JiraConnectorCheckpoint], SlimConnecto
|
||||
# if we didn't retrieve a full batch, we're done
|
||||
checkpoint.has_more = current_offset - starting_offset == page_size
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
@@ -534,6 +590,7 @@ class JiraConnector(CheckpointedConnector[JiraConnectorCheckpoint], SlimConnecto
|
||||
prev_offset = 0
|
||||
current_offset = 0
|
||||
slim_doc_batch = []
|
||||
|
||||
while checkpoint.has_more:
|
||||
for issue in _perform_jql_search(
|
||||
jira_client=self.jira_client,
|
||||
@@ -550,13 +607,12 @@ class JiraConnector(CheckpointedConnector[JiraConnectorCheckpoint], SlimConnecto
|
||||
continue
|
||||
|
||||
issue_key = best_effort_get_field_from_issue(issue, _FIELD_KEY)
|
||||
id = build_jira_url(self.jira_client, issue_key)
|
||||
id = build_jira_url(self.jira_base, issue_key)
|
||||
|
||||
slim_doc_batch.append(
|
||||
SlimDocument(
|
||||
id=id,
|
||||
external_access=get_project_permissions(
|
||||
jira_client=self.jira_client, jira_project=project_key
|
||||
),
|
||||
external_access=self._get_project_permissions(project_key),
|
||||
)
|
||||
)
|
||||
current_offset += 1
|
||||
@@ -701,7 +757,7 @@ if __name__ == "__main__":
|
||||
start = 0
|
||||
end = datetime.now().timestamp()
|
||||
|
||||
for slim_doc in connector.retrieve_all_slim_documents(
|
||||
for slim_doc in connector.retrieve_all_slim_docs_perm_sync(
|
||||
start=start,
|
||||
end=end,
|
||||
):
|
||||
|
||||
@@ -10,6 +10,7 @@ from jira.resources import CustomFieldOption
|
||||
from jira.resources import Issue
|
||||
from jira.resources import User
|
||||
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import scoped_url
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -74,11 +75,18 @@ def extract_text_from_adf(adf: dict | None) -> str:
|
||||
return " ".join(texts)
|
||||
|
||||
|
||||
def build_jira_url(jira_client: JIRA, issue_key: str) -> str:
|
||||
return f"{jira_client.client_info()}/browse/{issue_key}"
|
||||
def build_jira_url(jira_base_url: str, issue_key: str) -> str:
|
||||
"""
|
||||
Get the url used to access an issue in the UI.
|
||||
"""
|
||||
return f"{jira_base_url}/browse/{issue_key}"
|
||||
|
||||
|
||||
def build_jira_client(credentials: dict[str, Any], jira_base: str) -> JIRA:
|
||||
def build_jira_client(
|
||||
credentials: dict[str, Any], jira_base: str, scoped_token: bool = False
|
||||
) -> JIRA:
|
||||
|
||||
jira_base = scoped_url(jira_base, "jira") if scoped_token else jira_base
|
||||
api_token = credentials["jira_api_token"]
|
||||
# if the user provides an email we assume it's cloud
|
||||
if "jira_user_email" in credentials:
|
||||
|
||||
208
backend/onyx/connectors/registry.py
Normal file
@@ -0,0 +1,208 @@
|
||||
"""Registry mapping for connector classes."""
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
|
||||
|
||||
class ConnectorMapping(BaseModel):
|
||||
module_path: str
|
||||
class_name: str
|
||||
|
||||
|
||||
# Mapping of DocumentSource to connector details for lazy loading
|
||||
CONNECTOR_CLASS_MAP = {
|
||||
DocumentSource.WEB: ConnectorMapping(
|
||||
module_path="onyx.connectors.web.connector",
|
||||
class_name="WebConnector",
|
||||
),
|
||||
DocumentSource.FILE: ConnectorMapping(
|
||||
module_path="onyx.connectors.file.connector",
|
||||
class_name="LocalFileConnector",
|
||||
),
|
||||
DocumentSource.SLACK: ConnectorMapping(
|
||||
module_path="onyx.connectors.slack.connector",
|
||||
class_name="SlackConnector",
|
||||
),
|
||||
DocumentSource.GITHUB: ConnectorMapping(
|
||||
module_path="onyx.connectors.github.connector",
|
||||
class_name="GithubConnector",
|
||||
),
|
||||
DocumentSource.GMAIL: ConnectorMapping(
|
||||
module_path="onyx.connectors.gmail.connector",
|
||||
class_name="GmailConnector",
|
||||
),
|
||||
DocumentSource.GITLAB: ConnectorMapping(
|
||||
module_path="onyx.connectors.gitlab.connector",
|
||||
class_name="GitlabConnector",
|
||||
),
|
||||
DocumentSource.GITBOOK: ConnectorMapping(
|
||||
module_path="onyx.connectors.gitbook.connector",
|
||||
class_name="GitbookConnector",
|
||||
),
|
||||
DocumentSource.GOOGLE_DRIVE: ConnectorMapping(
|
||||
module_path="onyx.connectors.google_drive.connector",
|
||||
class_name="GoogleDriveConnector",
|
||||
),
|
||||
DocumentSource.BOOKSTACK: ConnectorMapping(
|
||||
module_path="onyx.connectors.bookstack.connector",
|
||||
class_name="BookstackConnector",
|
||||
),
|
||||
DocumentSource.OUTLINE: ConnectorMapping(
|
||||
module_path="onyx.connectors.outline.connector",
|
||||
class_name="OutlineConnector",
|
||||
),
|
||||
DocumentSource.CONFLUENCE: ConnectorMapping(
|
||||
module_path="onyx.connectors.confluence.connector",
|
||||
class_name="ConfluenceConnector",
|
||||
),
|
||||
DocumentSource.JIRA: ConnectorMapping(
|
||||
module_path="onyx.connectors.jira.connector",
|
||||
class_name="JiraConnector",
|
||||
),
|
||||
DocumentSource.PRODUCTBOARD: ConnectorMapping(
|
||||
module_path="onyx.connectors.productboard.connector",
|
||||
class_name="ProductboardConnector",
|
||||
),
|
||||
DocumentSource.SLAB: ConnectorMapping(
|
||||
module_path="onyx.connectors.slab.connector",
|
||||
class_name="SlabConnector",
|
||||
),
|
||||
DocumentSource.NOTION: ConnectorMapping(
|
||||
module_path="onyx.connectors.notion.connector",
|
||||
class_name="NotionConnector",
|
||||
),
|
||||
DocumentSource.ZULIP: ConnectorMapping(
|
||||
module_path="onyx.connectors.zulip.connector",
|
||||
class_name="ZulipConnector",
|
||||
),
|
||||
DocumentSource.GURU: ConnectorMapping(
|
||||
module_path="onyx.connectors.guru.connector",
|
||||
class_name="GuruConnector",
|
||||
),
|
||||
DocumentSource.LINEAR: ConnectorMapping(
|
||||
module_path="onyx.connectors.linear.connector",
|
||||
class_name="LinearConnector",
|
||||
),
|
||||
DocumentSource.HUBSPOT: ConnectorMapping(
|
||||
module_path="onyx.connectors.hubspot.connector",
|
||||
class_name="HubSpotConnector",
|
||||
),
|
||||
DocumentSource.DOCUMENT360: ConnectorMapping(
|
||||
module_path="onyx.connectors.document360.connector",
|
||||
class_name="Document360Connector",
|
||||
),
|
||||
DocumentSource.GONG: ConnectorMapping(
|
||||
module_path="onyx.connectors.gong.connector",
|
||||
class_name="GongConnector",
|
||||
),
|
||||
DocumentSource.GOOGLE_SITES: ConnectorMapping(
|
||||
module_path="onyx.connectors.google_site.connector",
|
||||
class_name="GoogleSitesConnector",
|
||||
),
|
||||
DocumentSource.ZENDESK: ConnectorMapping(
|
||||
module_path="onyx.connectors.zendesk.connector",
|
||||
class_name="ZendeskConnector",
|
||||
),
|
||||
DocumentSource.LOOPIO: ConnectorMapping(
|
||||
module_path="onyx.connectors.loopio.connector",
|
||||
class_name="LoopioConnector",
|
||||
),
|
||||
DocumentSource.DROPBOX: ConnectorMapping(
|
||||
module_path="onyx.connectors.dropbox.connector",
|
||||
class_name="DropboxConnector",
|
||||
),
|
||||
DocumentSource.SHAREPOINT: ConnectorMapping(
|
||||
module_path="onyx.connectors.sharepoint.connector",
|
||||
class_name="SharepointConnector",
|
||||
),
|
||||
DocumentSource.TEAMS: ConnectorMapping(
|
||||
module_path="onyx.connectors.teams.connector",
|
||||
class_name="TeamsConnector",
|
||||
),
|
||||
DocumentSource.SALESFORCE: ConnectorMapping(
|
||||
module_path="onyx.connectors.salesforce.connector",
|
||||
class_name="SalesforceConnector",
|
||||
),
|
||||
DocumentSource.DISCOURSE: ConnectorMapping(
|
||||
module_path="onyx.connectors.discourse.connector",
|
||||
class_name="DiscourseConnector",
|
||||
),
|
||||
DocumentSource.AXERO: ConnectorMapping(
|
||||
module_path="onyx.connectors.axero.connector",
|
||||
class_name="AxeroConnector",
|
||||
),
|
||||
DocumentSource.CLICKUP: ConnectorMapping(
|
||||
module_path="onyx.connectors.clickup.connector",
|
||||
class_name="ClickupConnector",
|
||||
),
|
||||
DocumentSource.MEDIAWIKI: ConnectorMapping(
|
||||
module_path="onyx.connectors.mediawiki.wiki",
|
||||
class_name="MediaWikiConnector",
|
||||
),
|
||||
DocumentSource.WIKIPEDIA: ConnectorMapping(
|
||||
module_path="onyx.connectors.wikipedia.connector",
|
||||
class_name="WikipediaConnector",
|
||||
),
|
||||
DocumentSource.ASANA: ConnectorMapping(
|
||||
module_path="onyx.connectors.asana.connector",
|
||||
class_name="AsanaConnector",
|
||||
),
|
||||
DocumentSource.S3: ConnectorMapping(
|
||||
module_path="onyx.connectors.blob.connector",
|
||||
class_name="BlobStorageConnector",
|
||||
),
|
||||
DocumentSource.R2: ConnectorMapping(
|
||||
module_path="onyx.connectors.blob.connector",
|
||||
class_name="BlobStorageConnector",
|
||||
),
|
||||
DocumentSource.GOOGLE_CLOUD_STORAGE: ConnectorMapping(
|
||||
module_path="onyx.connectors.blob.connector",
|
||||
class_name="BlobStorageConnector",
|
||||
),
|
||||
DocumentSource.OCI_STORAGE: ConnectorMapping(
|
||||
module_path="onyx.connectors.blob.connector",
|
||||
class_name="BlobStorageConnector",
|
||||
),
|
||||
DocumentSource.XENFORO: ConnectorMapping(
|
||||
module_path="onyx.connectors.xenforo.connector",
|
||||
class_name="XenforoConnector",
|
||||
),
|
||||
DocumentSource.DISCORD: ConnectorMapping(
|
||||
module_path="onyx.connectors.discord.connector",
|
||||
class_name="DiscordConnector",
|
||||
),
|
||||
DocumentSource.FRESHDESK: ConnectorMapping(
|
||||
module_path="onyx.connectors.freshdesk.connector",
|
||||
class_name="FreshdeskConnector",
|
||||
),
|
||||
DocumentSource.FIREFLIES: ConnectorMapping(
|
||||
module_path="onyx.connectors.fireflies.connector",
|
||||
class_name="FirefliesConnector",
|
||||
),
|
||||
DocumentSource.EGNYTE: ConnectorMapping(
|
||||
module_path="onyx.connectors.egnyte.connector",
|
||||
class_name="EgnyteConnector",
|
||||
),
|
||||
DocumentSource.AIRTABLE: ConnectorMapping(
|
||||
module_path="onyx.connectors.airtable.airtable_connector",
|
||||
class_name="AirtableConnector",
|
||||
),
|
||||
DocumentSource.HIGHSPOT: ConnectorMapping(
|
||||
module_path="onyx.connectors.highspot.connector",
|
||||
class_name="HighspotConnector",
|
||||
),
|
||||
DocumentSource.IMAP: ConnectorMapping(
|
||||
module_path="onyx.connectors.imap.connector",
|
||||
class_name="ImapConnector",
|
||||
),
|
||||
DocumentSource.BITBUCKET: ConnectorMapping(
|
||||
module_path="onyx.connectors.bitbucket.connector",
|
||||
class_name="BitbucketConnector",
|
||||
),
|
||||
# just for integration tests
|
||||
DocumentSource.MOCK_CONNECTOR: ConnectorMapping(
|
||||
module_path="onyx.connectors.mock_connector.connector",
|
||||
class_name="MockConnector",
|
||||
),
|
||||
}
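With the registry in place, lookup and resolution stay decoupled from imports; a short sketch of the resolution step (mirroring _load_connector_class), where adding a new source would simply mean one more ConnectorMapping entry like those above.
# Resolve an existing entry lazily: import the module, then pull the class off it by name.
import importlib

mapping = CONNECTOR_CLASS_MAP[DocumentSource.WEB]
module = importlib.import_module(mapping.module_path)    # "onyx.connectors.web.connector"
web_connector_cls = getattr(module, mapping.class_name)  # WebConnector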
|
||||
@@ -16,7 +16,7 @@ from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
@@ -151,7 +151,7 @@ def _validate_custom_query_config(config: dict[str, Any]) -> None:
|
||||
)
|
||||
|
||||
|
||||
class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
class SalesforceConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
|
||||
"""Approach outline
|
||||
|
||||
Goal
|
||||
@@ -1119,7 +1119,7 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
return self._delta_sync(temp_dir, start, end)
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
|
||||
@@ -41,7 +41,7 @@ from onyx.connectors.interfaces import CheckpointOutput
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import IndexingHeartbeatInterface
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
@@ -73,7 +73,8 @@ class SiteDescriptor(BaseModel):
|
||||
"""Data class for storing SharePoint site information.
|
||||
|
||||
Args:
|
||||
url: The base site URL (e.g. https://danswerai.sharepoint.com/sites/sharepoint-tests)
|
||||
url: The base site URL (e.g. https://danswerai.sharepoint.com/sites/sharepoint-tests
|
||||
or https://danswerai.sharepoint.com/teams/team-name)
|
||||
drive_name: The name of the drive to access (e.g. "Shared Documents", "Other Library")
|
||||
If None, all drives will be accessed.
|
||||
folder_path: The folder path within the drive to access (e.g. "test/nested with spaces")
|
||||
@@ -672,7 +673,7 @@ def _convert_sitepage_to_slim_document(
|
||||
|
||||
|
||||
class SharepointConnector(
|
||||
SlimConnector,
|
||||
SlimConnectorWithPermSync,
|
||||
CheckpointedConnectorWithPermSync[SharepointConnectorCheckpoint],
|
||||
):
|
||||
def __init__(
|
||||
@@ -703,9 +704,11 @@ class SharepointConnector(
|
||||
|
||||
# Ensure sites are sharepoint urls
|
||||
for site_url in self.sites:
|
||||
if not site_url.startswith("https://") or "/sites/" not in site_url:
|
||||
if not site_url.startswith("https://") or not (
|
||||
"/sites/" in site_url or "/teams/" in site_url
|
||||
):
|
||||
raise ConnectorValidationError(
|
||||
"Site URLs must be full Sharepoint URLs (e.g. https://your-tenant.sharepoint.com/sites/your-site)"
|
||||
"Site URLs must be full Sharepoint URLs (e.g. https://your-tenant.sharepoint.com/sites/your-site or https://your-tenant.sharepoint.com/teams/your-team)"
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -720,10 +723,17 @@ class SharepointConnector(
|
||||
site_data_list = []
|
||||
for url in site_urls:
|
||||
parts = url.strip().split("/")
|
||||
|
||||
site_type_index = None
|
||||
if "sites" in parts:
|
||||
sites_index = parts.index("sites")
|
||||
site_url = "/".join(parts[: sites_index + 2])
|
||||
remaining_parts = parts[sites_index + 2 :]
|
||||
site_type_index = parts.index("sites")
|
||||
elif "teams" in parts:
|
||||
site_type_index = parts.index("teams")
|
||||
|
||||
if site_type_index is not None:
|
||||
# Extract the base site URL (up to and including the site/team name)
|
||||
site_url = "/".join(parts[: site_type_index + 2])
|
||||
remaining_parts = parts[site_type_index + 2 :]
|
||||
|
||||
# Extract drive name and folder path
|
||||
if remaining_parts:
|
||||
@@ -745,7 +755,9 @@ class SharepointConnector(
|
||||
)
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Site URL '{url}' is not a valid Sharepoint URL")
|
||||
logger.warning(
|
||||
f"Site URL '{url}' is not a valid Sharepoint URL (must contain /sites/ or /teams/)"
|
||||
)
|
||||
return site_data_list
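An illustration of how the extended parsing splits a /teams/ URL; the values are made up, and the drive/folder handling follows the pattern suggested above (the exact joining is an assumption).
# Made-up /teams/ URL walked through the same splitting logic as above.
url = "https://contoso.sharepoint.com/teams/eng/Shared Documents/specs/2024"
parts = url.strip().split("/")
site_type_index = parts.index("teams")
site_url = "/".join(parts[: site_type_index + 2])
remaining_parts = parts[site_type_index + 2 :]
drive_name = remaining_parts[0]              # "Shared Documents" (assumed handling)
folder_path = "/".join(remaining_parts[1:])  # "specs/2024" (assumed handling)
assert site_url == "https://contoso.sharepoint.com/teams/eng"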
|
||||
|
||||
def _get_drive_items_for_drive_name(
|
||||
@@ -1597,7 +1609,7 @@ class SharepointConnector(
|
||||
) -> SharepointConnectorCheckpoint:
|
||||
return SharepointConnectorCheckpoint.model_validate_json(checkpoint_json)
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
|
||||
@@ -16,7 +16,7 @@ from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import SlimDocument
|
||||
@@ -164,7 +164,7 @@ def get_slab_url_from_title_id(base_url: str, title: str, page_id: str) -> str:
|
||||
return urljoin(urljoin(base_url, "posts/"), url_id)
|
||||
|
||||
|
||||
class SlabConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
class SlabConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str,
|
||||
@@ -239,7 +239,7 @@ class SlabConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
time_filter=lambda t: start_time <= t <= end_time
|
||||
)
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
|
||||
@@ -42,7 +42,7 @@ from onyx.connectors.interfaces import CredentialsConnector
|
||||
from onyx.connectors.interfaces import CredentialsProviderInterface
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
@@ -581,7 +581,7 @@ def _process_message(
|
||||
|
||||
|
||||
class SlackConnector(
|
||||
SlimConnector,
|
||||
SlimConnectorWithPermSync,
|
||||
CredentialsConnector,
|
||||
CheckpointedConnectorWithPermSync[SlackCheckpoint],
|
||||
):
|
||||
@@ -732,7 +732,7 @@ class SlackConnector(
|
||||
self.text_cleaner = SlackTextCleaner(client=self.client)
|
||||
self.credentials_provider = credentials_provider
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
|
||||
@@ -22,7 +22,7 @@ from onyx.connectors.interfaces import CheckpointedConnector
|
||||
from onyx.connectors.interfaces import CheckpointOutput
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
@@ -51,7 +51,7 @@ class TeamsCheckpoint(ConnectorCheckpoint):
|
||||
|
||||
class TeamsConnector(
|
||||
CheckpointedConnector[TeamsCheckpoint],
|
||||
SlimConnector,
|
||||
SlimConnectorWithPermSync,
|
||||
):
|
||||
MAX_WORKERS = 10
|
||||
AUTHORITY_URL_PREFIX = "https://login.microsoftonline.com/"
|
||||
@@ -228,9 +228,9 @@ class TeamsConnector(
|
||||
has_more=bool(todos),
|
||||
)
|
||||
|
||||
# impls for SlimConnector
|
||||
# impls for SlimConnectorWithPermSync
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
@@ -572,7 +572,7 @@ if __name__ == "__main__":
|
||||
)
|
||||
teams_connector.validate_connector_settings()
|
||||
|
||||
for slim_doc in teams_connector.retrieve_all_slim_documents():
|
||||
for slim_doc in teams_connector.retrieve_all_slim_docs_perm_sync():
|
||||
...
|
||||
|
||||
for doc in load_everything_from_checkpoint_connector(
|
||||
|
||||
@@ -219,6 +219,25 @@ def is_valid_url(url: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _same_site(base_url: str, candidate_url: str) -> bool:
|
||||
base, candidate = urlparse(base_url), urlparse(candidate_url)
|
||||
base_netloc = base.netloc.lower().removeprefix("www.")
|
||||
candidate_netloc = candidate.netloc.lower().removeprefix("www.")
|
||||
if base_netloc != candidate_netloc:
|
||||
return False
|
||||
|
||||
base_path = (base.path or "/").rstrip("/")
|
||||
if base_path in ("", "/"):
|
||||
return True
|
||||
|
||||
candidate_path = candidate.path or "/"
|
||||
if candidate_path == base_path:
|
||||
return True
|
||||
|
||||
boundary = f"{base_path}/"
|
||||
return candidate_path.startswith(boundary)
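A few illustrative cases for the scoping helper above; all URLs are made up.
# Same host and inside the base path -> treated as internal.
assert _same_site("https://docs.example.com/guide", "https://docs.example.com/guide/intro")
# Host comparison ignores case and a leading "www.".
assert _same_site("https://example.com", "https://www.EXAMPLE.com/pricing")
# Same host but outside the base path -> not internal.
assert not _same_site("https://docs.example.com/guide", "https://docs.example.com/blog/post")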
|
||||
|
||||
|
||||
def get_internal_links(
|
||||
base_url: str, url: str, soup: BeautifulSoup, should_ignore_pound: bool = True
|
||||
) -> set[str]:
|
||||
@@ -239,7 +258,7 @@ def get_internal_links(
|
||||
# Relative path handling
|
||||
href = urljoin(url, href)
|
||||
|
||||
if urlparse(href).netloc == urlparse(url).netloc and base_url in href:
|
||||
if _same_site(base_url, href):
|
||||
internal_links.add(href)
|
||||
return internal_links
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ from onyx.connectors.interfaces import CheckpointOutput
|
||||
from onyx.connectors.interfaces import ConnectorFailure
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.connectors.models import Document
|
||||
@@ -376,7 +376,7 @@ class ZendeskConnectorCheckpoint(ConnectorCheckpoint):
|
||||
|
||||
|
||||
class ZendeskConnector(
|
||||
SlimConnector, CheckpointedConnector[ZendeskConnectorCheckpoint]
|
||||
SlimConnectorWithPermSync, CheckpointedConnector[ZendeskConnectorCheckpoint]
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -565,7 +565,7 @@ class ZendeskConnector(
|
||||
)
|
||||
return checkpoint
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
|
||||
@@ -62,6 +62,14 @@ class MCPAuthenticationType(str, PyEnum):
     OAUTH = "OAUTH"


+class MCPTransport(str, PyEnum):
+    """MCP transport types"""
+
+    STDIO = "STDIO"  # TODO: currently unsupported, need to add a user guide for setup
+    SSE = "SSE"  # Server-Sent Events (deprecated but still used)
+    STREAMABLE_HTTP = "STREAMABLE_HTTP"  # Modern HTTP streaming
+
+
 class MCPAuthenticationPerformer(str, PyEnum):
     ADMIN = "ADMIN"
     PER_USER = "PER_USER"
@@ -1,132 +0,0 @@
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.chat import delete_chat_session
|
||||
from onyx.db.models import ChatFolder
|
||||
from onyx.db.models import ChatSession
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def get_user_folders(
|
||||
user_id: UUID | None,
|
||||
db_session: Session,
|
||||
) -> list[ChatFolder]:
|
||||
return db_session.query(ChatFolder).filter(ChatFolder.user_id == user_id).all()
|
||||
|
||||
|
||||
def update_folder_display_priority(
|
||||
user_id: UUID | None,
|
||||
display_priority_map: dict[int, int],
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
folders = get_user_folders(user_id=user_id, db_session=db_session)
|
||||
folder_ids = {folder.id for folder in folders}
|
||||
if folder_ids != set(display_priority_map.keys()):
|
||||
raise ValueError("Invalid Folder IDs provided")
|
||||
|
||||
for folder in folders:
|
||||
folder.display_priority = display_priority_map[folder.id]
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def get_folder_by_id(
|
||||
user_id: UUID | None,
|
||||
folder_id: int,
|
||||
db_session: Session,
|
||||
) -> ChatFolder:
|
||||
folder = (
|
||||
db_session.query(ChatFolder).filter(ChatFolder.id == folder_id).one_or_none()
|
||||
)
|
||||
if not folder:
|
||||
raise ValueError("Folder by specified id does not exist")
|
||||
|
||||
if folder.user_id != user_id:
|
||||
raise PermissionError(f"Folder does not belong to user: {user_id}")
|
||||
|
||||
return folder
|
||||
|
||||
|
||||
def create_folder(
|
||||
user_id: UUID | None, folder_name: str | None, db_session: Session
|
||||
) -> int:
|
||||
new_folder = ChatFolder(
|
||||
user_id=user_id,
|
||||
name=folder_name,
|
||||
)
|
||||
db_session.add(new_folder)
|
||||
db_session.commit()
|
||||
|
||||
return new_folder.id
|
||||
|
||||
|
||||
def rename_folder(
|
||||
user_id: UUID | None, folder_id: int, folder_name: str | None, db_session: Session
|
||||
) -> None:
|
||||
folder = get_folder_by_id(
|
||||
user_id=user_id, folder_id=folder_id, db_session=db_session
|
||||
)
|
||||
|
||||
folder.name = folder_name
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def add_chat_to_folder(
|
||||
user_id: UUID | None, folder_id: int, chat_session: ChatSession, db_session: Session
|
||||
) -> None:
|
||||
folder = get_folder_by_id(
|
||||
user_id=user_id, folder_id=folder_id, db_session=db_session
|
||||
)
|
||||
|
||||
chat_session.folder_id = folder.id
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def remove_chat_from_folder(
|
||||
user_id: UUID | None, folder_id: int, chat_session: ChatSession, db_session: Session
|
||||
) -> None:
|
||||
folder = get_folder_by_id(
|
||||
user_id=user_id, folder_id=folder_id, db_session=db_session
|
||||
)
|
||||
|
||||
if chat_session.folder_id != folder.id:
|
||||
raise ValueError("The chat session is not in the specified folder.")
|
||||
|
||||
if folder.user_id != user_id:
|
||||
raise ValueError(
|
||||
f"Tried to remove a chat session from a folder that does not below to "
|
||||
f"this user, user id: {user_id}"
|
||||
)
|
||||
|
||||
chat_session.folder_id = None
|
||||
if chat_session in folder.chat_sessions:
|
||||
folder.chat_sessions.remove(chat_session)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def delete_folder(
|
||||
user_id: UUID | None,
|
||||
folder_id: int,
|
||||
including_chats: bool,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
folder = get_folder_by_id(
|
||||
user_id=user_id, folder_id=folder_id, db_session=db_session
|
||||
)
|
||||
|
||||
# Assuming there will not be a massive number of chats in any given folder
|
||||
if including_chats:
|
||||
for chat_session in folder.chat_sessions:
|
||||
delete_chat_session(
|
||||
user_id=user_id,
|
||||
chat_session_id=chat_session.id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
db_session.delete(folder)
|
||||
db_session.commit()
|
||||
@@ -112,6 +112,7 @@ def upsert_llm_provider(
                 name=model_configuration.name,
                 is_visible=model_configuration.is_visible,
                 max_input_tokens=model_configuration.max_input_tokens,
+                supports_image_input=model_configuration.supports_image_input,
             )
             .on_conflict_do_nothing()
         )
@@ -4,8 +4,10 @@ from uuid import UUID
|
||||
from sqlalchemy import and_
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from onyx.db.enums import MCPAuthenticationPerformer
|
||||
from onyx.db.enums import MCPTransport
|
||||
from onyx.db.models import MCPAuthenticationType
|
||||
from onyx.db.models import MCPConnectionConfig
|
||||
from onyx.db.models import MCPServer
|
||||
@@ -89,6 +91,8 @@ def create_mcp_server__no_commit(
|
||||
description: str | None,
|
||||
server_url: str,
|
||||
auth_type: MCPAuthenticationType,
|
||||
transport: MCPTransport,
|
||||
auth_performer: MCPAuthenticationPerformer,
|
||||
db_session: Session,
|
||||
admin_connection_config_id: int | None = None,
|
||||
) -> MCPServer:
|
||||
@@ -98,7 +102,9 @@ def create_mcp_server__no_commit(
|
||||
name=name,
|
||||
description=description,
|
||||
server_url=server_url,
|
||||
transport=transport,
|
||||
auth_type=auth_type,
|
||||
auth_performer=auth_performer,
|
||||
admin_connection_config_id=admin_connection_config_id,
|
||||
)
|
||||
db_session.add(new_server)
|
||||
@@ -114,6 +120,8 @@ def update_mcp_server__no_commit(
|
||||
server_url: str | None = None,
|
||||
auth_type: MCPAuthenticationType | None = None,
|
||||
admin_connection_config_id: int | None = None,
|
||||
auth_performer: MCPAuthenticationPerformer | None = None,
|
||||
transport: MCPTransport | None = None,
|
||||
) -> MCPServer:
|
||||
"""Update an existing MCP server"""
|
||||
server = get_mcp_server_by_id(server_id, db_session)
|
||||
@@ -128,6 +136,10 @@ def update_mcp_server__no_commit(
|
||||
server.auth_type = auth_type
|
||||
if admin_connection_config_id is not None:
|
||||
server.admin_connection_config_id = admin_connection_config_id
|
||||
if auth_performer is not None:
|
||||
server.auth_performer = auth_performer
|
||||
if transport is not None:
|
||||
server.transport = transport
|
||||
|
||||
db_session.flush() # Don't commit yet, let caller decide when to commit
|
||||
return server
|
||||
@@ -147,18 +159,6 @@ def delete_mcp_server(server_id: int, db_session: Session) -> None:
|
||||
logger.info(f"Successfully deleted MCP server {server_id} and its tools")
|
||||
|
||||
|
||||
# TODO: this is pretty hacky
|
||||
def get_mcp_server_auth_performer(mcp_server: MCPServer) -> MCPAuthenticationPerformer:
|
||||
"""Get the authentication performer for an MCP server"""
|
||||
if mcp_server.auth_type == MCPAuthenticationType.OAUTH:
|
||||
return MCPAuthenticationPerformer.PER_USER
|
||||
if not mcp_server.admin_connection_config:
|
||||
return MCPAuthenticationPerformer.ADMIN
|
||||
if not mcp_server.admin_connection_config.config.get("header_substitutions"):
|
||||
return MCPAuthenticationPerformer.ADMIN
|
||||
return MCPAuthenticationPerformer.PER_USER
|
||||
|
||||
|
||||
def get_all_mcp_tools_for_server(server_id: int, db_session: Session) -> list[Tool]:
|
||||
"""Get all MCP tools for a server"""
|
||||
return list(
|
||||
@@ -259,6 +259,8 @@ def update_connection_config(
|
||||
|
||||
if config_data is not None:
|
||||
config.config = config_data
|
||||
# Force SQLAlchemy to detect the change by marking the field as modified
|
||||
flag_modified(config, "config")
|
||||
|
||||
db_session.commit()
|
||||
return config
|
||||
@@ -295,7 +297,7 @@ def get_server_auth_template(
|
||||
if not server.admin_connection_config_id:
|
||||
return None
|
||||
|
||||
if get_mcp_server_auth_performer(server) == MCPAuthenticationPerformer.ADMIN:
|
||||
if server.auth_performer == MCPAuthenticationPerformer.ADMIN:
|
||||
return None # admin server implies no template
|
||||
return server.admin_connection_config
|
||||
|
||||
|
||||
@@ -63,6 +63,8 @@ from onyx.db.enums import (
     SyncStatus,
     MCPAuthenticationType,
     UserFileStatus,
+    MCPAuthenticationPerformer,
+    MCPTransport,
 )
 from onyx.configs.constants import NotificationType
 from onyx.configs.constants import SearchFeedbackType

@@ -2351,6 +2353,8 @@ class ModelConfiguration(Base):
     # - The end-user is configuring a model and chooses not to set a max-input-tokens limit.
     max_input_tokens: Mapped[int | None] = mapped_column(Integer, nullable=True)

+    supports_image_input: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
+
     llm_provider: Mapped["LLMProvider"] = relationship(
         "LLMProvider",
         back_populates="model_configurations",

@@ -3468,10 +3472,18 @@ class MCPServer(Base):
     name: Mapped[str] = mapped_column(String, nullable=False)
     description: Mapped[str | None] = mapped_column(String, nullable=True)
     server_url: Mapped[str] = mapped_column(String, nullable=False)
+    # Transport type for connecting to the MCP server
+    transport: Mapped[MCPTransport] = mapped_column(
+        Enum(MCPTransport, native_enum=False), nullable=False
+    )
     # Auth type: "none", "api_token", or "oauth"
     auth_type: Mapped[MCPAuthenticationType] = mapped_column(
         Enum(MCPAuthenticationType, native_enum=False), nullable=False
     )
+    # Who performs authentication for this server (ADMIN or PER_USER)
+    auth_performer: Mapped[MCPAuthenticationPerformer] = mapped_column(
+        Enum(MCPAuthenticationPerformer, native_enum=False), nullable=False
+    )
     # Admin connection config - used for the config page
-    # and (when applicable) admin-managed auth
+    # and (when applicable) per-user auth
@@ -1,478 +0,0 @@
|
||||
import datetime
|
||||
import time
|
||||
from typing import List
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import UploadFile
|
||||
from sqlalchemy import and_
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import get_current_tenant_id
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import InputType
|
||||
from onyx.db.connector import create_connector
|
||||
from onyx.db.connector_credential_pair import add_credential_to_connector
|
||||
from onyx.db.credentials import create_credential
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import Document
|
||||
from onyx.db.models import DocumentByConnectorCredentialPair
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Persona__UserFile
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.models import UserFolder
|
||||
from onyx.server.documents.connector import trigger_indexing_for_cc_pair
|
||||
from onyx.server.documents.connector import upload_files
|
||||
from onyx.server.documents.models import ConnectorBase
|
||||
from onyx.server.documents.models import CredentialBase
|
||||
from onyx.server.models import StatusResponse
|
||||
|
||||
USER_FILE_CONSTANT = "USER_FILE_CONNECTOR"
|
||||
|
||||
|
||||
def create_user_files(
|
||||
files: List[UploadFile],
|
||||
folder_id: int | None,
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
link_url: str | None = None,
|
||||
) -> list[UserFile]:
|
||||
"""NOTE(rkuo): This function can take -1 (RECENT_DOCS_FOLDER_ID for folder_id.
|
||||
Document what this does?
|
||||
"""
|
||||
|
||||
# NOTE: At the moment, zip metadata is not used for user files.
|
||||
# Should revisit to decide whether this should be a feature.
|
||||
upload_response = upload_files(files)
|
||||
user_files = []
|
||||
|
||||
for file_path, file in zip(upload_response.file_paths, files):
|
||||
new_file = UserFile(
|
||||
user_id=user.id if user else None,
|
||||
folder_id=folder_id,
|
||||
file_id=file_path,
|
||||
document_id="USER_FILE_CONNECTOR__" + file_path,
|
||||
name=file.filename,
|
||||
token_count=None,
|
||||
link_url=link_url,
|
||||
content_type=file.content_type,
|
||||
)
|
||||
db_session.add(new_file)
|
||||
user_files.append(new_file)
|
||||
db_session.commit()
|
||||
return user_files
|
||||
|
||||
|
||||
def upload_files_to_user_files_with_indexing(
|
||||
files: List[UploadFile],
|
||||
folder_id: int | None,
|
||||
user: User,
|
||||
db_session: Session,
|
||||
trigger_index: bool = True,
|
||||
) -> list[UserFile]:
|
||||
"""NOTE(rkuo): This function can take -1 (RECENT_DOCS_FOLDER_ID for folder_id.
|
||||
Document what this does?
|
||||
|
||||
Create user files and trigger immediate indexing"""
|
||||
# Create the user files first
|
||||
user_files = create_user_files(files, folder_id, user, db_session)
|
||||
|
||||
# Create connector and credential for each file
|
||||
for user_file in user_files:
|
||||
cc_pair = create_file_connector_credential(user_file, user, db_session)
|
||||
user_file.cc_pair_id = cc_pair.data
|
||||
|
||||
db_session.commit()
|
||||
|
||||
# Trigger immediate high-priority indexing for all created files
|
||||
if trigger_index:
|
||||
tenant_id = get_current_tenant_id()
|
||||
for user_file in user_files:
|
||||
# Use the existing trigger_indexing_for_cc_pair function but with highest priority
|
||||
if user_file.cc_pair_id:
|
||||
trigger_indexing_for_cc_pair(
|
||||
[],
|
||||
user_file.cc_pair.connector_id,
|
||||
False,
|
||||
tenant_id,
|
||||
db_session,
|
||||
is_user_file=True,
|
||||
)
|
||||
|
||||
return user_files
|
||||
|
||||
|
||||
def create_file_connector_credential(
|
||||
user_file: UserFile, user: User, db_session: Session
|
||||
) -> StatusResponse:
|
||||
"""Create connector and credential for a user file"""
|
||||
connector_base = ConnectorBase(
|
||||
name=f"UserFile-{user_file.file_id}-{int(time.time())}",
|
||||
source=DocumentSource.FILE,
|
||||
input_type=InputType.LOAD_STATE,
|
||||
connector_specific_config={
|
||||
"file_locations": [user_file.file_id],
|
||||
"file_names": [user_file.name],
|
||||
"zip_metadata": {},
|
||||
},
|
||||
refresh_freq=None,
|
||||
prune_freq=None,
|
||||
indexing_start=None,
|
||||
)
|
||||
|
||||
connector = create_connector(db_session=db_session, connector_data=connector_base)
|
||||
|
||||
credential_info = CredentialBase(
|
||||
credential_json={},
|
||||
admin_public=True,
|
||||
source=DocumentSource.FILE,
|
||||
curator_public=True,
|
||||
groups=[],
|
||||
name=f"UserFileCredential-{user_file.file_id}-{int(time.time())}",
|
||||
is_user_file=True,
|
||||
)
|
||||
|
||||
credential = create_credential(credential_info, user, db_session)
|
||||
|
||||
return add_credential_to_connector(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
connector_id=connector.id,
|
||||
credential_id=credential.id,
|
||||
cc_pair_name=f"UserFileCCPair-{user_file.file_id}-{int(time.time())}",
|
||||
access_type=AccessType.PRIVATE,
|
||||
auto_sync_options=None,
|
||||
groups=[],
|
||||
is_user_file=True,
|
||||
)
|
||||
|
||||
|
||||
def get_user_file_indexing_status(
|
||||
file_ids: list[int], db_session: Session
|
||||
) -> dict[int, bool]:
|
||||
"""Get indexing status for multiple user files"""
|
||||
status_dict = {}
|
||||
|
||||
# Query UserFile with cc_pair join
|
||||
files_with_pairs = (
|
||||
db_session.query(UserFile)
|
||||
.filter(UserFile.id.in_(file_ids))
|
||||
.options(joinedload(UserFile.cc_pair))
|
||||
.all()
|
||||
)
|
||||
|
||||
for file in files_with_pairs:
|
||||
if file.cc_pair and file.cc_pair.last_successful_index_time:
|
||||
status_dict[file.id] = True
|
||||
else:
|
||||
status_dict[file.id] = False
|
||||
|
||||
return status_dict
|
||||
|
||||
|
||||
def calculate_user_files_token_count(
|
||||
file_ids: list[int], folder_ids: list[int], db_session: Session
|
||||
) -> int:
|
||||
"""Calculate total token count for specified files and folders"""
|
||||
total_tokens = 0
|
||||
|
||||
# Get tokens from individual files
|
||||
if file_ids:
|
||||
file_tokens = (
|
||||
db_session.query(func.sum(UserFile.token_count))
|
||||
.filter(UserFile.id.in_(file_ids))
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
total_tokens += file_tokens
|
||||
|
||||
# Get tokens from folders
|
||||
if folder_ids:
|
||||
folder_files_tokens = (
|
||||
db_session.query(func.sum(UserFile.token_count))
|
||||
.filter(UserFile.folder_id.in_(folder_ids))
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
total_tokens += folder_files_tokens
|
||||
|
||||
return total_tokens
|
||||
|
||||
|
||||
def load_all_user_files(
|
||||
file_ids: list[int], folder_ids: list[int], db_session: Session
|
||||
) -> list[UserFile]:
|
||||
"""Load all user files from specified file IDs and folder IDs"""
|
||||
result = []
|
||||
|
||||
# Get individual files
|
||||
if file_ids:
|
||||
files = db_session.query(UserFile).filter(UserFile.id.in_(file_ids)).all()
|
||||
result.extend(files)
|
||||
|
||||
# Get files from folders
|
||||
if folder_ids:
|
||||
folder_files = (
|
||||
db_session.query(UserFile).filter(UserFile.folder_id.in_(folder_ids)).all()
|
||||
)
|
||||
result.extend(folder_files)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def get_user_files_from_folder(folder_id: int, db_session: Session) -> list[UserFile]:
|
||||
return db_session.query(UserFile).filter(UserFile.folder_id == folder_id).all()
|
||||
|
||||
|
||||
def share_file_with_assistant(
|
||||
file_id: int, assistant_id: int, db_session: Session
|
||||
) -> None:
|
||||
file = db_session.query(UserFile).filter(UserFile.id == file_id).first()
|
||||
assistant = db_session.query(Persona).filter(Persona.id == assistant_id).first()
|
||||
|
||||
if file and assistant:
|
||||
file.assistants.append(assistant)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def unshare_file_with_assistant(
|
||||
file_id: int, assistant_id: int, db_session: Session
|
||||
) -> None:
|
||||
db_session.query(Persona__UserFile).filter(
|
||||
and_(
|
||||
Persona__UserFile.user_file_id == file_id,
|
||||
Persona__UserFile.persona_id == assistant_id,
|
||||
)
|
||||
).delete()
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def share_folder_with_assistant(
|
||||
folder_id: int, assistant_id: int, db_session: Session
|
||||
) -> None:
|
||||
folder = db_session.query(UserFolder).filter(UserFolder.id == folder_id).first()
|
||||
assistant = db_session.query(Persona).filter(Persona.id == assistant_id).first()
|
||||
|
||||
if folder and assistant:
|
||||
for file in folder.files:
|
||||
share_file_with_assistant(file.id, assistant_id, db_session)
|
||||
|
||||
|
||||
def unshare_folder_with_assistant(
|
||||
folder_id: int, assistant_id: int, db_session: Session
|
||||
) -> None:
|
||||
folder = db_session.query(UserFolder).filter(UserFolder.id == folder_id).first()
|
||||
|
||||
if folder:
|
||||
for file in folder.files:
|
||||
unshare_file_with_assistant(file.id, assistant_id, db_session)
|
||||
|
||||
|
||||
def fetch_user_files_for_documents(
|
||||
document_ids: list[str],
|
||||
db_session: Session,
|
||||
) -> dict[str, int | None]:
|
||||
"""
|
||||
Fetches user file IDs for the given document IDs.
|
||||
|
||||
Args:
|
||||
document_ids: List of document IDs to fetch user files for
|
||||
db_session: Database session
|
||||
|
||||
Returns:
|
||||
Dictionary mapping document IDs to user file IDs (or None if no user file exists)
|
||||
"""
|
||||
# First, get the document to cc_pair mapping
|
||||
doc_cc_pairs = (
|
||||
db_session.query(Document.id, ConnectorCredentialPair.id)
|
||||
.join(
|
||||
DocumentByConnectorCredentialPair,
|
||||
Document.id == DocumentByConnectorCredentialPair.id,
|
||||
)
|
||||
.join(
|
||||
ConnectorCredentialPair,
|
||||
and_(
|
||||
DocumentByConnectorCredentialPair.connector_id
|
||||
== ConnectorCredentialPair.connector_id,
|
||||
DocumentByConnectorCredentialPair.credential_id
|
||||
== ConnectorCredentialPair.credential_id,
|
||||
),
|
||||
)
|
||||
.filter(Document.id.in_(document_ids))
|
||||
.all()
|
||||
)
|
||||
|
||||
# Get cc_pair to user_file mapping
|
||||
cc_pair_to_user_file = (
|
||||
db_session.query(ConnectorCredentialPair.id, UserFile.id)
|
||||
.join(UserFile, UserFile.cc_pair_id == ConnectorCredentialPair.id)
|
||||
.filter(
|
||||
ConnectorCredentialPair.id.in_(
|
||||
[cc_pair_id for _, cc_pair_id in doc_cc_pairs]
|
||||
)
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
# Create mapping from cc_pair_id to user_file_id
|
||||
cc_pair_to_user_file_dict = {
|
||||
cc_pair_id: user_file_id for cc_pair_id, user_file_id in cc_pair_to_user_file
|
||||
}
|
||||
|
||||
# Create the final result mapping document_id to user_file_id
|
||||
result: dict[str, int | None] = {doc_id: None for doc_id in document_ids}
|
||||
for doc_id, cc_pair_id in doc_cc_pairs:
|
||||
if cc_pair_id in cc_pair_to_user_file_dict:
|
||||
result[doc_id] = cc_pair_to_user_file_dict[cc_pair_id]
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def fetch_user_folders_for_documents(
|
||||
document_ids: list[str],
|
||||
db_session: Session,
|
||||
) -> dict[str, int | None]:
|
||||
"""
|
||||
Fetches user folder IDs for the given document IDs.
|
||||
|
||||
For each document, returns the folder ID that the document's associated user file belongs to.
|
||||
|
||||
Args:
|
||||
document_ids: List of document IDs to fetch user folders for
|
||||
db_session: Database session
|
||||
|
||||
Returns:
|
||||
Dictionary mapping document IDs to user folder IDs (or None if no user folder exists)
|
||||
"""
|
||||
# First, get the document to cc_pair mapping
|
||||
doc_cc_pairs = (
|
||||
db_session.query(Document.id, ConnectorCredentialPair.id)
|
||||
.join(
|
||||
DocumentByConnectorCredentialPair,
|
||||
Document.id == DocumentByConnectorCredentialPair.id,
|
||||
)
|
||||
.join(
|
||||
ConnectorCredentialPair,
|
||||
and_(
|
||||
DocumentByConnectorCredentialPair.connector_id
|
||||
== ConnectorCredentialPair.connector_id,
|
||||
DocumentByConnectorCredentialPair.credential_id
|
||||
== ConnectorCredentialPair.credential_id,
|
||||
),
|
||||
)
|
||||
.filter(Document.id.in_(document_ids))
|
||||
.all()
|
||||
)
|
||||
|
||||
# Get cc_pair to user_file and folder mapping
|
||||
cc_pair_to_folder = (
|
||||
db_session.query(ConnectorCredentialPair.id, UserFile.folder_id)
|
||||
.join(UserFile, UserFile.cc_pair_id == ConnectorCredentialPair.id)
|
||||
.filter(
|
||||
ConnectorCredentialPair.id.in_(
|
||||
[cc_pair_id for _, cc_pair_id in doc_cc_pairs]
|
||||
)
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
# Create mapping from cc_pair_id to folder_id
|
||||
cc_pair_to_folder_dict = {
|
||||
cc_pair_id: folder_id for cc_pair_id, folder_id in cc_pair_to_folder
|
||||
}
|
||||
|
||||
# Create the final result mapping document_id to folder_id
|
||||
result: dict[str, int | None] = {doc_id: None for doc_id in document_ids}
|
||||
for doc_id, cc_pair_id in doc_cc_pairs:
|
||||
if cc_pair_id in cc_pair_to_folder_dict:
|
||||
result[doc_id] = cc_pair_to_folder_dict[cc_pair_id]
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def get_user_file_from_id(db_session: Session, user_file_id: int) -> UserFile | None:
|
||||
return db_session.query(UserFile).filter(UserFile.id == user_file_id).first()
|
||||
|
||||
|
||||
# def fetch_user_files_for_documents(
|
||||
# # document_ids: list[str],
|
||||
# # db_session: Session,
|
||||
# # ) -> dict[str, int | None]:
|
||||
# # # Query UserFile objects for the given document_ids
|
||||
# # user_files = (
|
||||
# # db_session.query(UserFile).filter(UserFile.document_id.in_(document_ids)).all()
|
||||
# # )
|
||||
|
||||
# # # Create a dictionary mapping document_ids to UserFile objects
|
||||
# # result: dict[str, int | None] = {doc_id: None for doc_id in document_ids}
|
||||
# # for user_file in user_files:
|
||||
# # result[user_file.document_id] = user_file.id
|
||||
|
||||
# # return result
|
||||
|
||||
|
||||
def upsert_user_folder(
|
||||
db_session: Session,
|
||||
id: int | None = None,
|
||||
user_id: UUID | None = None,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
created_at: datetime.datetime | None = None,
|
||||
user: User | None = None,
|
||||
files: list[UserFile] | None = None,
|
||||
assistants: list[Persona] | None = None,
|
||||
) -> UserFolder:
|
||||
if id is not None:
|
||||
user_folder = db_session.query(UserFolder).filter_by(id=id).first()
|
||||
else:
|
||||
user_folder = (
|
||||
db_session.query(UserFolder).filter_by(name=name, user_id=user_id).first()
|
||||
)
|
||||
|
||||
if user_folder:
|
||||
if user_id is not None:
|
||||
user_folder.user_id = user_id
|
||||
if name is not None:
|
||||
user_folder.name = name
|
||||
if description is not None:
|
||||
user_folder.description = description
|
||||
if created_at is not None:
|
||||
user_folder.created_at = created_at
|
||||
if user is not None:
|
||||
user_folder.user = user
|
||||
if files is not None:
|
||||
user_folder.files = files
|
||||
if assistants is not None:
|
||||
user_folder.assistants = assistants
|
||||
else:
|
||||
user_folder = UserFolder(
|
||||
id=id,
|
||||
user_id=user_id,
|
||||
name=name,
|
||||
description=description,
|
||||
created_at=created_at or datetime.datetime.utcnow(),
|
||||
user=user,
|
||||
files=files or [],
|
||||
assistants=assistants or [],
|
||||
)
|
||||
db_session.add(user_folder)
|
||||
|
||||
db_session.flush()
|
||||
return user_folder
|
||||
|
||||
|
||||
def get_user_folder_by_name(db_session: Session, name: str) -> UserFolder | None:
|
||||
return db_session.query(UserFolder).filter(UserFolder.name == name).first()
|
||||
|
||||
|
||||
def update_user_file_token_count__no_commit(
|
||||
user_file_id_to_token_count: dict[int, int | None],
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
for user_file_id, token_count in user_file_id_to_token_count.items():
|
||||
db_session.query(UserFile).filter(UserFile.id == user_file_id).update(
|
||||
{UserFile.token_count: token_count}
|
||||
)
|
||||
@@ -1,15 +1,52 @@
 """Factory for creating federated connector instances."""

+import importlib
 from typing import Any
+from typing import Type

 from onyx.configs.constants import FederatedConnectorSource
 from onyx.federated_connectors.interfaces import FederatedConnector
-from onyx.federated_connectors.slack.federated_connector import SlackFederatedConnector
+from onyx.federated_connectors.registry import FEDERATED_CONNECTOR_CLASS_MAP
 from onyx.utils.logger import setup_logger

 logger = setup_logger()


 class FederatedConnectorMissingException(Exception):
     pass
+
+
+# Cache for already imported federated connector classes
+_federated_connector_cache: dict[FederatedConnectorSource, Type[FederatedConnector]] = (
+    {}
+)
+
+
+def _load_federated_connector_class(
+    source: FederatedConnectorSource,
+) -> Type[FederatedConnector]:
+    """Dynamically load and cache a federated connector class."""
+    if source in _federated_connector_cache:
+        return _federated_connector_cache[source]
+
+    if source not in FEDERATED_CONNECTOR_CLASS_MAP:
+        raise FederatedConnectorMissingException(
+            f"Federated connector not found for source={source}"
+        )
+
+    mapping = FEDERATED_CONNECTOR_CLASS_MAP[source]
+
+    try:
+        module = importlib.import_module(mapping.module_path)
+        connector_class = getattr(module, mapping.class_name)
+        _federated_connector_cache[source] = connector_class
+        return connector_class
+    except (ImportError, AttributeError) as e:
+        raise FederatedConnectorMissingException(
+            f"Failed to import {mapping.class_name} from {mapping.module_path}: {e}"
+        )


 def get_federated_connector(
     source: FederatedConnectorSource,
     credentials: dict[str, Any],

@@ -21,9 +58,6 @@ def get_federated_connector(

 def get_federated_connector_cls(
     source: FederatedConnectorSource,
-) -> type[FederatedConnector]:
+) -> Type[FederatedConnector]:
     """Get the class of the appropriate federated connector."""
-    if source == FederatedConnectorSource.FEDERATED_SLACK:
-        return SlackFederatedConnector
-    else:
-        raise ValueError(f"Unsupported federated connector source: {source}")
+    return _load_federated_connector_class(source)

@@ -135,12 +135,16 @@ def get_federated_retrieval_functions(
     # At this point, user_id is guaranteed to be not None since we're in the else branch
     assert user_id is not None

+    # If no source types are specified, don't use any federated connectors
+    if source_types is None:
+        logger.info("No source types specified, skipping all federated connectors")
+        return []
+
     federated_retrieval_infos: list[FederatedRetrievalInfo] = []
     federated_oauth_tokens = list_federated_connector_oauth_tokens(db_session, user_id)
     for oauth_token in federated_oauth_tokens:
         if (
-            source_types is not None
-            and oauth_token.federated_connector.source.to_non_federated_source()
+            oauth_token.federated_connector.source.to_non_federated_source()
             not in source_types
         ):
             continue
backend/onyx/federated_connectors/registry.py (new file, 19 lines)
@@ -0,0 +1,19 @@
+"""Registry mapping for federated connector classes."""
+
+from pydantic import BaseModel
+
+from onyx.configs.constants import FederatedConnectorSource
+
+
+class FederatedConnectorMapping(BaseModel):
+    module_path: str
+    class_name: str
+
+
+# Mapping of FederatedConnectorSource to connector details for lazy loading
+FEDERATED_CONNECTOR_CLASS_MAP = {
+    FederatedConnectorSource.FEDERATED_SLACK: FederatedConnectorMapping(
+        module_path="onyx.federated_connectors.slack.federated_connector",
+        class_name="SlackFederatedConnector",
+    ),
+}
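For illustration, registering a second lazily loaded connector would only take another entry in this map. The source value, module path, and class below are hypothetical and are not part of this change.

# Hypothetical example only: FEDERATED_CONFLUENCE and the module/class named
# here do not exist in this change; they illustrate the registry pattern.
FEDERATED_CONNECTOR_CLASS_MAP[FederatedConnectorSource.FEDERATED_CONFLUENCE] = FederatedConnectorMapping(
    module_path="onyx.federated_connectors.confluence.federated_connector",
    class_name="ConfluenceFederatedConnector",
)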
@@ -54,6 +54,7 @@ ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS = [
     ".xml",
     ".yml",
     ".yaml",
+    ".sql",
 ]

 ACCEPTED_DOCUMENT_FILE_EXTENSIONS = [
@@ -2,7 +2,6 @@ from typing import Any
 from typing import cast
 from typing import IO

-from unstructured.staging.base import dict_to_elements
 from unstructured_client import UnstructuredClient  # type: ignore
 from unstructured_client.models import operations  # type: ignore
 from unstructured_client.models import shared

@@ -52,6 +51,8 @@ def _sdk_partition_request(


 def unstructured_to_text(file: IO[Any], file_name: str) -> str:
+    from unstructured.staging.base import dict_to_elements
+
     logger.debug(f"Starting to read file: {file_name}")
     req = _sdk_partition_request(file, file_name, strategy="fast")
@@ -1,7 +1,6 @@
 import base64
 from enum import Enum
 from typing import NotRequired
-from uuid import UUID
 from typing_extensions import TypedDict  # noreorder

 from pydantic import BaseModel

@@ -36,7 +35,7 @@ class FileDescriptor(TypedDict):
     id: str
     type: ChatFileType
     name: NotRequired[str | None]
-    user_file_id: NotRequired[UUID | None]
+    user_file_id: NotRequired[str | None]


 class InMemoryChatFile(BaseModel):

@@ -58,5 +57,5 @@ class InMemoryChatFile(BaseModel):
             "id": str(self.file_id),
             "type": self.file_type,
             "name": self.filename,
-            "user_file_id": UUID(str(self.file_id)) if self.file_id else None,
+            "user_file_id": str(self.file_id) if self.file_id else None,
         }
@@ -5,8 +5,9 @@ from collections.abc import Iterator
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Union
|
||||
|
||||
import litellm # type: ignore
|
||||
from httpx import RemoteProtocolError
|
||||
from langchain.schema.language_model import LanguageModelInput
|
||||
from langchain_core.messages import AIMessage
|
||||
@@ -24,9 +25,7 @@ from langchain_core.messages import SystemMessageChunk
|
||||
from langchain_core.messages.tool import ToolCallChunk
|
||||
from langchain_core.messages.tool import ToolMessage
|
||||
from langchain_core.prompt_values import PromptValue
|
||||
from litellm.utils import get_supported_openai_params
|
||||
|
||||
from onyx.configs.app_configs import BRAINTRUST_ENABLED
|
||||
from onyx.configs.app_configs import LOG_ONYX_MODEL_INTERACTIONS
|
||||
from onyx.configs.app_configs import MOCK_LLM_RESPONSE
|
||||
from onyx.configs.chat_configs import QA_TIMEOUT
|
||||
@@ -45,13 +44,9 @@ from onyx.utils.long_term_log import LongTermLogger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# If a user configures a different model and it doesn't support all the same
|
||||
# parameters like frequency and presence, just ignore them
|
||||
litellm.drop_params = True
|
||||
litellm.telemetry = False
|
||||
if TYPE_CHECKING:
|
||||
from litellm import ModelResponse, CustomStreamWrapper, Message
|
||||
|
||||
if BRAINTRUST_ENABLED:
|
||||
litellm.callbacks = ["braintrust"]
|
||||
|
||||
_LLM_PROMPT_LONG_TERM_LOG_CATEGORY = "llm_prompt"
|
||||
VERTEX_CREDENTIALS_FILE_KWARG = "vertex_credentials"
|
||||
@@ -85,8 +80,10 @@ def _base_msg_to_role(msg: BaseMessage) -> str:
|
||||
|
||||
|
||||
def _convert_litellm_message_to_langchain_message(
|
||||
litellm_message: litellm.Message,
|
||||
litellm_message: "Message",
|
||||
) -> BaseMessage:
|
||||
from onyx.llm.litellm_singleton import litellm
|
||||
|
||||
# Extracting the basic attributes from the litellm message
|
||||
content = litellm_message.content or ""
|
||||
role = litellm_message.role
|
||||
@@ -176,15 +173,15 @@ def _convert_delta_to_message_chunk(
|
||||
curr_msg: BaseMessage | None,
|
||||
stop_reason: str | None = None,
|
||||
) -> BaseMessageChunk:
|
||||
from litellm.utils import ChatCompletionDeltaToolCall
|
||||
|
||||
"""Adapted from langchain_community.chat_models.litellm._convert_delta_to_message_chunk"""
|
||||
role = _dict.get("role") or (_base_msg_to_role(curr_msg) if curr_msg else "unknown")
|
||||
content = _dict.get("content") or ""
|
||||
additional_kwargs = {}
|
||||
if _dict.get("function_call"):
|
||||
additional_kwargs.update({"function_call": dict(_dict["function_call"])})
|
||||
tool_calls = cast(
|
||||
list[litellm.utils.ChatCompletionDeltaToolCall] | None, _dict.get("tool_calls")
|
||||
)
|
||||
tool_calls = cast(list[ChatCompletionDeltaToolCall] | None, _dict.get("tool_calls"))
|
||||
|
||||
if role == "user":
|
||||
return HumanMessageChunk(content=content)
|
||||
@@ -321,6 +318,8 @@ class DefaultMultiLLM(LLM):
|
||||
|
||||
self._max_token_param = LEGACY_MAX_TOKENS_KWARG
|
||||
try:
|
||||
from litellm.utils import get_supported_openai_params
|
||||
|
||||
params = get_supported_openai_params(model_name, model_provider)
|
||||
if STANDARD_MAX_TOKENS_KWARG in (params or []):
|
||||
self._max_token_param = STANDARD_MAX_TOKENS_KWARG
|
||||
@@ -388,11 +387,12 @@ class DefaultMultiLLM(LLM):
|
||||
structured_response_format: dict | None = None,
|
||||
timeout_override: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
) -> litellm.ModelResponse | litellm.CustomStreamWrapper:
|
||||
) -> Union["ModelResponse", "CustomStreamWrapper"]:
|
||||
# litellm doesn't accept LangChain BaseMessage objects, so we need to convert them
|
||||
# to a dict representation
|
||||
processed_prompt = _prompt_to_dict(prompt)
|
||||
self._record_call(processed_prompt)
|
||||
from onyx.llm.litellm_singleton import litellm
|
||||
|
||||
try:
|
||||
return litellm.completion(
|
||||
@@ -437,6 +437,16 @@ class DefaultMultiLLM(LLM):
|
||||
]
|
||||
else {}
|
||||
), # TODO: remove once LITELLM has patched
|
||||
**(
|
||||
{"reasoning_effort": "minimal"}
|
||||
if self.config.model_name
|
||||
in [
|
||||
"gpt-5",
|
||||
"gpt-5-mini",
|
||||
"gpt-5-nano",
|
||||
]
|
||||
else {}
|
||||
), # TODO: remove once LITELLM has better support/we change API
|
||||
**(
|
||||
{"response_format": structured_response_format}
|
||||
if structured_response_format
|
||||
@@ -485,11 +495,13 @@ class DefaultMultiLLM(LLM):
|
||||
timeout_override: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
) -> BaseMessage:
|
||||
from litellm import ModelResponse
|
||||
|
||||
if LOG_ONYX_MODEL_INTERACTIONS:
|
||||
self.log_model_configs()
|
||||
|
||||
response = cast(
|
||||
litellm.ModelResponse,
|
||||
ModelResponse,
|
||||
self._completion(
|
||||
prompt=prompt,
|
||||
tools=tools,
|
||||
@@ -518,6 +530,8 @@ class DefaultMultiLLM(LLM):
|
||||
timeout_override: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
) -> Iterator[BaseMessage]:
|
||||
from litellm import CustomStreamWrapper
|
||||
|
||||
if LOG_ONYX_MODEL_INTERACTIONS:
|
||||
self.log_model_configs()
|
||||
|
||||
@@ -534,7 +548,7 @@ class DefaultMultiLLM(LLM):
|
||||
|
||||
output = None
|
||||
response = cast(
|
||||
litellm.CustomStreamWrapper,
|
||||
CustomStreamWrapper,
|
||||
self._completion(
|
||||
prompt=prompt,
|
||||
tools=tools,
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
from typing import Any
|
||||
|
||||
from onyx.chat.models import PersonaOverrideConfig
|
||||
from onyx.configs.app_configs import DISABLE_GENERATIVE_AI
|
||||
from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
|
||||
from onyx.configs.model_configs import GEN_AI_TEMPERATURE
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.llm import fetch_default_provider
|
||||
@@ -13,6 +10,8 @@ from onyx.db.models import Persona
|
||||
from onyx.llm.chat_llm import DefaultMultiLLM
|
||||
from onyx.llm.exceptions import GenAIDisabledException
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.llm_provider_options import OLLAMA_API_KEY_CONFIG_KEY
|
||||
from onyx.llm.llm_provider_options import OLLAMA_PROVIDER_NAME
|
||||
from onyx.llm.override_models import LLMOverride
|
||||
from onyx.llm.utils import get_max_input_tokens_from_llm_provider
|
||||
from onyx.llm.utils import model_supports_image_input
|
||||
@@ -24,13 +23,22 @@ from onyx.utils.long_term_log import LongTermLogger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _build_extra_model_kwargs(provider: str) -> dict[str, Any]:
|
||||
"""Ollama requires us to specify the max context window.
|
||||
def _build_provider_extra_headers(
|
||||
provider: str, custom_config: dict[str, str] | None
|
||||
) -> dict[str, str]:
|
||||
if provider != OLLAMA_PROVIDER_NAME or not custom_config:
|
||||
return {}
|
||||
|
||||
For now, just using the GEN_AI_MODEL_FALLBACK_MAX_TOKENS value.
|
||||
TODO: allow model-specific values to be configured via the UI.
|
||||
"""
|
||||
return {"num_ctx": GEN_AI_MODEL_FALLBACK_MAX_TOKENS} if provider == "ollama" else {}
|
||||
raw_api_key = custom_config.get(OLLAMA_API_KEY_CONFIG_KEY)
|
||||
|
||||
api_key = raw_api_key.strip() if raw_api_key else None
|
||||
if not api_key:
|
||||
return {}
|
||||
|
||||
if not api_key.lower().startswith("bearer "):
|
||||
api_key = f"Bearer {api_key}"
|
||||
|
||||
return {"Authorization": api_key}
|
||||
|
||||
|
||||
def get_main_llm_from_tuple(
|
||||
@@ -272,6 +280,16 @@ def get_llm(
|
||||
) -> LLM:
|
||||
if temperature is None:
|
||||
temperature = GEN_AI_TEMPERATURE
|
||||
|
||||
extra_headers = build_llm_extra_headers(additional_headers)
|
||||
|
||||
# NOTE: this is needed since Ollama API key is optional
|
||||
# User may access Ollama cloud via locally hosted instance (logged in)
|
||||
# or just via the cloud API (not logged in, using API key)
|
||||
provider_extra_headers = _build_provider_extra_headers(provider, custom_config)
|
||||
if provider_extra_headers:
|
||||
extra_headers.update(provider_extra_headers)
|
||||
|
||||
return DefaultMultiLLM(
|
||||
model_provider=provider,
|
||||
model_name=model,
|
||||
@@ -282,8 +300,8 @@ def get_llm(
|
||||
timeout=timeout,
|
||||
temperature=temperature,
|
||||
custom_config=custom_config,
|
||||
extra_headers=build_llm_extra_headers(additional_headers),
|
||||
model_kwargs=_build_extra_model_kwargs(provider),
|
||||
extra_headers=extra_headers,
|
||||
model_kwargs={},
|
||||
long_term_logger=long_term_logger,
|
||||
max_input_tokens=max_input_tokens,
|
||||
)
|
||||
|
||||
backend/onyx/llm/litellm_singleton.py (new file, 23 lines)
@@ -0,0 +1,23 @@
+"""
+Singleton module for litellm configuration.
+This ensures litellm is configured exactly once when first imported.
+All other modules should import litellm from here instead of directly.
+"""
+
+import litellm
+
+from onyx.configs.app_configs import BRAINTRUST_ENABLED
+
+# Import litellm
+
+# Configure litellm settings immediately on import
+# If a user configures a different model and it doesn't support all the same
+# parameters like frequency and presence, just ignore them
+litellm.drop_params = True
+litellm.telemetry = False
+
+if BRAINTRUST_ENABLED:
+    litellm.callbacks = ["braintrust"]
+
+# Export the configured litellm module
+__all__ = ["litellm"]
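Downstream usage is then just an import swap. A minimal sketch follows; `litellm.completion()` is litellm's standard API, shown only to make the point, and the model name is an arbitrary example.

# Instead of `import litellm`, pull the pre-configured module so that
# drop_params/telemetry (and Braintrust callbacks, when enabled) are set once.
from onyx.llm.litellm_singleton import litellm

response = litellm.completion(
    model="gpt-4o-mini",  # arbitrary example model
    messages=[{"role": "user", "content": "ping"}],
)
print(response.choices[0].message.content)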
@@ -39,6 +39,7 @@ class WellKnownLLMProviderDescriptor(BaseModel):
|
||||
model_configurations: list[ModelConfigurationView]
|
||||
default_model: str | None = None
|
||||
default_fast_model: str | None = None
|
||||
default_api_base: str | None = None
|
||||
# set for providers like Azure, which require a deployment name.
|
||||
deployment_name_required: bool = False
|
||||
# set for providers like Azure, which support a single model per deployment.
|
||||
@@ -95,7 +96,9 @@ BEDROCK_MODEL_NAMES = [
|
||||
for model in list(litellm.bedrock_models.union(litellm.bedrock_converse_models))
|
||||
if "/" not in model and "embed" not in model
|
||||
][::-1]
|
||||
BEDROCK_DEFAULT_MODEL = "anthropic.claude-3-5-sonnet-20241022-v2:0"
|
||||
|
||||
OLLAMA_PROVIDER_NAME = "ollama"
|
||||
OLLAMA_API_KEY_CONFIG_KEY = "OLLAMA_API_KEY"
|
||||
|
||||
IGNORABLE_ANTHROPIC_MODELS = [
|
||||
"claude-2",
|
||||
@@ -109,8 +112,8 @@ ANTHROPIC_MODEL_NAMES = [
|
||||
if model not in IGNORABLE_ANTHROPIC_MODELS
|
||||
][::-1]
|
||||
ANTHROPIC_VISIBLE_MODEL_NAMES = [
|
||||
"claude-3-5-sonnet-20241022",
|
||||
"claude-3-7-sonnet-20250219",
|
||||
"claude-sonnet-4-5-20250929",
|
||||
"claude-sonnet-4-20250514",
|
||||
]
|
||||
|
||||
AZURE_PROVIDER_NAME = "azure"
|
||||
@@ -160,13 +163,15 @@ _PROVIDER_TO_MODELS_MAP = {
|
||||
BEDROCK_PROVIDER_NAME: BEDROCK_MODEL_NAMES,
|
||||
ANTHROPIC_PROVIDER_NAME: ANTHROPIC_MODEL_NAMES,
|
||||
VERTEXAI_PROVIDER_NAME: VERTEXAI_MODEL_NAMES,
|
||||
OLLAMA_PROVIDER_NAME: [],
|
||||
}
|
||||
|
||||
_PROVIDER_TO_VISIBLE_MODELS_MAP = {
|
||||
OPENAI_PROVIDER_NAME: OPEN_AI_VISIBLE_MODEL_NAMES,
|
||||
BEDROCK_PROVIDER_NAME: [BEDROCK_DEFAULT_MODEL],
|
||||
BEDROCK_PROVIDER_NAME: [],
|
||||
ANTHROPIC_PROVIDER_NAME: ANTHROPIC_VISIBLE_MODEL_NAMES,
|
||||
VERTEXAI_PROVIDER_NAME: VERTEXAI_VISIBLE_MODEL_NAMES,
|
||||
OLLAMA_PROVIDER_NAME: [],
|
||||
}
|
||||
|
||||
|
||||
@@ -185,6 +190,28 @@ def fetch_available_well_known_llms() -> list[WellKnownLLMProviderDescriptor]:
|
||||
default_model="gpt-4o",
|
||||
default_fast_model="gpt-4o-mini",
|
||||
),
|
||||
WellKnownLLMProviderDescriptor(
|
||||
name=OLLAMA_PROVIDER_NAME,
|
||||
display_name="Ollama",
|
||||
api_key_required=False,
|
||||
api_base_required=True,
|
||||
api_version_required=False,
|
||||
custom_config_keys=[
|
||||
CustomConfigKey(
|
||||
name=OLLAMA_API_KEY_CONFIG_KEY,
|
||||
display_name="Ollama API Key",
|
||||
description="Optional API key used when connecting to Ollama Cloud (i.e. API base is https://ollama.com).",
|
||||
is_required=False,
|
||||
is_secret=True,
|
||||
)
|
||||
],
|
||||
model_configurations=fetch_model_configurations_for_provider(
|
||||
OLLAMA_PROVIDER_NAME
|
||||
),
|
||||
default_model=None,
|
||||
default_fast_model=None,
|
||||
default_api_base="http://127.0.0.1:11434",
|
||||
),
|
||||
WellKnownLLMProviderDescriptor(
|
||||
name=ANTHROPIC_PROVIDER_NAME,
|
||||
display_name="Anthropic",
|
||||
@@ -195,8 +222,8 @@ def fetch_available_well_known_llms() -> list[WellKnownLLMProviderDescriptor]:
|
||||
model_configurations=fetch_model_configurations_for_provider(
|
||||
ANTHROPIC_PROVIDER_NAME
|
||||
),
|
||||
default_model="claude-3-7-sonnet-20250219",
|
||||
default_fast_model="claude-3-5-sonnet-20241022",
|
||||
default_model="claude-sonnet-4-5-20250929",
|
||||
default_fast_model="claude-sonnet-4-20250514",
|
||||
),
|
||||
WellKnownLLMProviderDescriptor(
|
||||
name=AZURE_PROVIDER_NAME,
|
||||
@@ -248,7 +275,7 @@ def fetch_available_well_known_llms() -> list[WellKnownLLMProviderDescriptor]:
|
||||
model_configurations=fetch_model_configurations_for_provider(
|
||||
BEDROCK_PROVIDER_NAME
|
||||
),
|
||||
default_model=BEDROCK_DEFAULT_MODEL,
|
||||
default_model=None,
|
||||
default_fast_model=None,
|
||||
),
|
||||
WellKnownLLMProviderDescriptor(
|
||||
|
||||
@@ -16,6 +16,7 @@ from langchain.schema.messages import AIMessage
|
||||
from langchain.schema.messages import BaseMessage
|
||||
from langchain.schema.messages import HumanMessage
|
||||
from langchain.schema.messages import SystemMessage
|
||||
from sqlalchemy import select
|
||||
|
||||
from onyx.configs.app_configs import LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS
|
||||
from onyx.configs.app_configs import MAX_TOKENS_FOR_FULL_INCLUSION
|
||||
@@ -26,6 +27,9 @@ from onyx.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
|
||||
from onyx.configs.model_configs import GEN_AI_MAX_TOKENS
|
||||
from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
|
||||
from onyx.configs.model_configs import GEN_AI_NUM_RESERVED_OUTPUT_TOKENS
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.models import LLMProvider
|
||||
from onyx.db.models import ModelConfiguration
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.file_store.models import InMemoryChatFile
|
||||
from onyx.llm.interfaces import LLM
|
||||
@@ -460,6 +464,7 @@ def get_llm_contextual_cost(
|
||||
this does not account for the cost of documents that fit within a single chunk
|
||||
which do not get contextualized.
|
||||
"""
|
||||
|
||||
import litellm
|
||||
|
||||
# calculate input costs
|
||||
@@ -639,6 +644,30 @@ def get_max_input_tokens_from_llm_provider(
|
||||
|
||||
|
||||
def model_supports_image_input(model_name: str, model_provider: str) -> bool:
|
||||
# TODO: Add support to check model config for any provider
|
||||
# TODO: Circular import means OLLAMA_PROVIDER_NAME is not available here
|
||||
|
||||
if model_provider == "ollama":
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
model_config = db_session.scalar(
|
||||
select(ModelConfiguration)
|
||||
.join(
|
||||
LLMProvider,
|
||||
ModelConfiguration.llm_provider_id == LLMProvider.id,
|
||||
)
|
||||
.where(
|
||||
ModelConfiguration.name == model_name,
|
||||
LLMProvider.provider == model_provider,
|
||||
)
|
||||
)
|
||||
if model_config and model_config.supports_image_input is not None:
|
||||
return model_config.supports_image_input
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to query database for {model_provider} model {model_name} image support: {e}"
|
||||
)
|
||||
|
||||
model_map = get_model_map()
|
||||
try:
|
||||
model_obj = find_model_obj(
|
||||
|
||||
@@ -714,13 +714,15 @@ information that will be necessary to provide a succinct answer to the specific
|
||||
the documents. Again, start out here as well with a brief statement whether the SPECIFIC CONTEXT is \
|
||||
mentioned in the documents. (Example: 'I was not able to find information about yellow curry specifically, \
|
||||
but I found information about curry...'). But this should be be precise and concise, and specifically \
|
||||
answer the question. Please cite the document sources inline in format [[1]][[7]], etc.>",
|
||||
answer the question. Please cite the document sources inline in format [[1]][[7]], etc., where it \
|
||||
is essential that the document NUMBERS are in the brackets, not any titles.>",
|
||||
"claims": "<a list of short claims discussed in the documents as they pertain to the query and/or \
|
||||
the original question. These will later be used for follow-up questions and verifications. Note that \
|
||||
these may not actually be in the succinct answer above. Note also that each claim \
|
||||
should include ONE fact that contains enough context to be verified/questioned by a different system \
|
||||
without the need for going back to these documents for additional context. Also here, please cite the \
|
||||
document sources inline in format [[1]][[7]], etc.. So this should have format like \
|
||||
document sources inline in format [[1]][[7]], etc., where it is essential that the document NUMBERS are \
|
||||
in the brackets, not any titles. So this should have format like \
|
||||
[<claim 1>, <claim 2>, <claim 3>, ...], each with citations.>"
|
||||
}
|
||||
"""
|
||||
@@ -1043,8 +1045,9 @@ find information about yellow curry specifically, but here is what I found about
|
||||
- do not make anything up! Only use the information provided in the documents, or, \
|
||||
if no documents are provided for a sub-answer, in the actual sub-answer.
|
||||
- Provide a thoughtful answer that is concise and to the point, but that is detailed.
|
||||
- Please cite your sources INLINE in format [[2]][[4]], etc! The numbers of the documents \
|
||||
are provided above. So the appropriate citation number should be close to the corresponding /
|
||||
- Please cite your sources INLINE in format [[2]][[4]], etc! The NUMBERS of the documents \
|
||||
are provided above, and the NUMBERS need to be in the brackets. And the appropriate citation \
|
||||
should be close to the corresponding /
|
||||
information it supports!
|
||||
- If you are not that certain that the information does relate to the question topic, \
|
||||
point out the ambiguity in your answer. But DO NOT say something like 'I was not able to find \
|
||||
@@ -1098,14 +1101,16 @@ find information about yellow curry specifically, but here is what I found about
|
||||
- do not make anything up! Only use the information provided in the documents, or, \
|
||||
if no documents are provided for a sub-answer, in the actual sub-answer.
|
||||
- Provide a thoughtful answer that is concise and to the point, but that is detailed.
|
||||
- Please cite your sources inline in format [[2]][[4]], etc! The numbers of the documents \
|
||||
are provided above. So the appropriate citation number should be close to the corresponding /
|
||||
- Please cite your sources inline in format [[2]][[4]], etc! The NUMBERS of the documents \
|
||||
are provided above, and the NUMBERS need to be in the brackets. And the appropriate citation \
|
||||
should be close to the corresponding /
|
||||
information it supports!
|
||||
- If you are not that certain that the information does relate to the question topic, \
|
||||
point out the ambiguity in your answer. But DO NOT say something like 'I was not able to find \
|
||||
information on <X> specifically, but here is what I found about <X> generally....'. Rather say, \
|
||||
'Here is what I found about <X> and I hope this is the <X> you were looking for...', or similar.
|
||||
- Again... CITE YOUR SOURCES INLINE IN FORMAT [[2]][[4]], etc! This is CRITICAL!
|
||||
- Again... CITE YOUR SOURCES INLINE IN FORMAT [[2]][[4]], etc! This is CRITICAL! Note that \
|
||||
the DOCUMENT NUMBERS need to be in the brackets.
|
||||
|
||||
ANSWER:
|
||||
"""
|
||||
@@ -1150,8 +1155,9 @@ find information about yellow curry specifically, but here is what I found about
|
||||
- do not make anything up! Only use the information provided in the documents, or, \
|
||||
if no documents are provided for a sub-answer, in the actual sub-answer.
|
||||
- Provide a thoughtful answer that is concise and to the point, but that is detailed.
|
||||
- THIS IS VERY IMPORTANT: Please cite your sources inline in format [[2]][[4]], etc! The numbers of the documents \
|
||||
are provided above. Also, if you refer to sub-answers, the provided reference numbers \
|
||||
- THIS IS VERY IMPORTANT: Please cite your sources inline in format [[2]][[4]], etc! \
|
||||
The NUMBERS of the documents - provided above -need to be in the brackets. \
|
||||
Also, if you refer to sub-answers, the provided reference numbers \
|
||||
in the sub-answers are the same as the ones provided for the documents!
|
||||
|
||||
ANSWER:
|
||||
|
||||
@@ -3,7 +3,6 @@ from typing import Any
 from typing import cast
 from typing import List

-from litellm import get_supported_openai_params
 from sqlalchemy.orm import Session

 from onyx.configs.chat_configs import NUM_PERSONA_PROMPT_GENERATION_CHUNKS

@@ -123,6 +122,8 @@ def generate_starter_messages(
     """
     _, fast_llm = get_default_llms(temperature=0.5)

+    from litellm.utils import get_supported_openai_params
+
     provider = fast_llm.config.model_provider
     model = fast_llm.config.model_name
@@ -184,7 +184,7 @@ def seed_initial_documents(
             "base_url": "https://docs.onyx.app/",
             "web_connector_type": "recursive",
         },
-        refresh_freq=None,  # Never refresh by default
+        refresh_freq=3600,  # 1 hour
         prune_freq=None,
         indexing_start=None,
     )
@@ -1,6 +0,0 @@
-user_folders:
-  - id: -1
-    name: "Recent Documents"
-    description: "Documents uploaded by the user"
-    files: []
-    assistants: []
@@ -22,7 +22,7 @@ from onyx.db.models import User
 from onyx.redis.redis_pool import get_redis_client
 from onyx.server.documents.models import CredentialBase
 from onyx.utils.logger import setup_logger
-from onyx.utils.subclasses import find_all_subclasses_in_dir
+from onyx.utils.subclasses import find_all_subclasses_in_package
 from shared_configs.contextvars import get_current_tenant_id

 logger = setup_logger()

@@ -44,7 +44,8 @@ def _discover_oauth_connectors() -> dict[DocumentSource, type[OAuthConnector]]:
     if _OAUTH_CONNECTORS:  # Return cached connectors if already discovered
         return _OAUTH_CONNECTORS

-    oauth_connectors = find_all_subclasses_in_dir(
+    # Import submodules using package-based discovery to avoid sys.path mutations
+    oauth_connectors = find_all_subclasses_in_package(
         cast(type[OAuthConnector], OAuthConnector), "onyx.connectors"
     )
@@ -1,177 +0,0 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Path
from sqlalchemy.orm import Session

from onyx.auth.users import current_user
from onyx.db.chat import get_chat_session_by_id
from onyx.db.engine.sql_engine import get_session
from onyx.db.folder import add_chat_to_folder
from onyx.db.folder import create_folder
from onyx.db.folder import delete_folder
from onyx.db.folder import get_user_folders
from onyx.db.folder import remove_chat_from_folder
from onyx.db.folder import rename_folder
from onyx.db.folder import update_folder_display_priority
from onyx.db.models import User
from onyx.server.features.folder.models import DeleteFolderOptions
from onyx.server.features.folder.models import FolderChatSessionRequest
from onyx.server.features.folder.models import FolderCreationRequest
from onyx.server.features.folder.models import FolderUpdateRequest
from onyx.server.features.folder.models import GetUserFoldersResponse
from onyx.server.features.folder.models import UserFolderSnapshot
from onyx.server.models import DisplayPriorityRequest
from onyx.server.query_and_chat.models import ChatSessionDetails

router = APIRouter(prefix="/folder")


@router.get("")
def get_folders(
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> GetUserFoldersResponse:
folders = get_user_folders(
user_id=user.id if user else None,
db_session=db_session,
)
folders.sort()
return GetUserFoldersResponse(
folders=[
UserFolderSnapshot(
folder_id=folder.id,
folder_name=folder.name,
display_priority=folder.display_priority,
chat_sessions=[
ChatSessionDetails(
id=chat_session.id,
name=chat_session.description,
persona_id=chat_session.persona_id,
time_created=chat_session.time_created.isoformat(),
time_updated=chat_session.time_updated.isoformat(),
shared_status=chat_session.shared_status,
folder_id=folder.id,
)
for chat_session in folder.chat_sessions
if not chat_session.deleted
],
)
for folder in folders
]
)


@router.put("/reorder")
def put_folder_display_priority(
display_priority_request: DisplayPriorityRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
update_folder_display_priority(
user_id=user.id if user else None,
display_priority_map=display_priority_request.display_priority_map,
db_session=db_session,
)


@router.post("")
def create_folder_endpoint(
request: FolderCreationRequest,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> int:
return create_folder(
user_id=user.id if user else None,
folder_name=request.folder_name,
db_session=db_session,
)


@router.patch("/{folder_id}")
def patch_folder_endpoint(
request: FolderUpdateRequest,
folder_id: int = Path(..., description="The ID of the folder to rename"),
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
try:
rename_folder(
user_id=user.id if user else None,
folder_id=folder_id,
folder_name=request.folder_name,
db_session=db_session,
)
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))


@router.delete("/{folder_id}")
def delete_folder_endpoint(
request: DeleteFolderOptions,
folder_id: int = Path(..., description="The ID of the folder to delete"),
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
user_id = user.id if user else None
try:
delete_folder(
user_id=user_id,
folder_id=folder_id,
including_chats=request.including_chats,
db_session=db_session,
)
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))


@router.post("/{folder_id}/add-chat-session")
def add_chat_to_folder_endpoint(
request: FolderChatSessionRequest,
folder_id: int = Path(
..., description="The ID of the folder in which to add the chat session"
),
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
user_id = user.id if user else None
try:
chat_session = get_chat_session_by_id(
chat_session_id=request.chat_session_id,
user_id=user_id,
db_session=db_session,
)
add_chat_to_folder(
user_id=user.id if user else None,
folder_id=folder_id,
chat_session=chat_session,
db_session=db_session,
)
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))


@router.post("/{folder_id}/remove-chat-session")
def remove_chat_from_folder_endpoint(
request: FolderChatSessionRequest,
folder_id: int = Path(
..., description="The ID of the folder from which to remove the chat session"
),
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
user_id = user.id if user else None
try:
chat_session = get_chat_session_by_id(
chat_session_id=request.chat_session_id,
user_id=user_id,
db_session=db_session,
)
remove_chat_from_folder(
user_id=user_id,
folder_id=folder_id,
chat_session=chat_session,
db_session=db_session,
)
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
@@ -1,32 +0,0 @@
from uuid import UUID

from pydantic import BaseModel

from onyx.server.query_and_chat.models import ChatSessionDetails


class UserFolderSnapshot(BaseModel):
folder_id: int
folder_name: str | None
display_priority: int
chat_sessions: list[ChatSessionDetails]


class GetUserFoldersResponse(BaseModel):
folders: list[UserFolderSnapshot]


class FolderCreationRequest(BaseModel):
folder_name: str | None = None


class FolderUpdateRequest(BaseModel):
folder_name: str | None = None


class FolderChatSessionRequest(BaseModel):
chat_session_id: UUID


class DeleteFolderOptions(BaseModel):
including_chats: bool = False
File diff suppressed because it is too large.
@@ -1,3 +1,5 @@
from enum import Enum
from typing import Any
from typing import List
from typing import NotRequired
from typing import Optional

@@ -10,20 +12,35 @@ from pydantic import model_validator

from onyx.db.enums import MCPAuthenticationPerformer
from onyx.db.enums import MCPAuthenticationType
from onyx.db.enums import MCPTransport


# This should be updated along with MCPConnectionData
class MCPOAuthKeys(str, Enum):
"""MCP OAuth keys types"""

CLIENT_INFO = "client_info"
TOKENS = "tokens"
METADATA = "metadata"


class MCPConnectionData(TypedDict):
"""TypedDict to allow use as a type hint for a JSONB column
in Postgres"""

refresh_token: NotRequired[str]
access_token: NotRequired[str]
headers: dict[str, str]
header_substitutions: NotRequired[dict[str, str]]
client_id: NotRequired[str]
client_secret: NotRequired[str]
registration_access_token: NotRequired[str]
registration_client_uri: NotRequired[str]

# For OAuth only
# Note: Update MCPOAuthKeys if necessary when modifying these
# Unfortunately we can't use the actual models here because basemodels aren't compatible
# with SQLAlchemy
client_info: NotRequired[dict[str, Any]] # OAuthClientInformationFull
tokens: NotRequired[dict[str, Any]] # OAuthToken
metadata: NotRequired[dict[str, Any]] # OAuthClientMetadata

# the actual models are defined in mcp.shared.auth
# from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken


class MCPAuthTemplate(BaseModel):

@@ -48,14 +65,17 @@ class MCPToolCreateRequest(BaseModel):
description: Optional[str] = Field(None, description="Description of the MCP tool")
server_url: str = Field(..., description="URL of the MCP server")
auth_type: MCPAuthenticationType = Field(..., description="Authentication type")
auth_performer: Optional[MCPAuthenticationPerformer] = Field(
None, description="Who performs authentication"
auth_performer: MCPAuthenticationPerformer = Field(
..., description="Who performs authentication"
)
api_token: Optional[str] = Field(
None, description="API token for api_token auth type"
)
oauth_client_id: Optional[str] = Field(None, description="OAuth client ID")
oauth_client_secret: Optional[str] = Field(None, description="OAuth client secret")
transport: MCPTransport | None = Field(
None, description="MCP transport type (STREAMABLE_HTTP or SSE)"
)
auth_template: Optional[MCPAuthTemplate] = Field(
None, description="Template configuration for per-user authentication"
)

@@ -104,15 +124,9 @@ class MCPToolCreateRequest(BaseModel):
"admin_credentials is required when auth_performer is 'per_user'"
)

if self.auth_type == MCPAuthenticationType.OAUTH and not self.oauth_client_id:
raise ValueError("oauth_client_id is required when auth_type is 'oauth'")
if (
self.auth_type == MCPAuthenticationType.OAUTH
and not self.oauth_client_secret
):
raise ValueError(
"oauth_client_secret is required when auth_type is 'oauth'"
)
# OAuth client ID/secret are optional. If provided, they will seed the
# OAuth client info; otherwise, the MCP client will attempt dynamic
# client registration.

return self

@@ -140,7 +154,7 @@ class MCPToolResponse(BaseModel):
is_authenticated: bool


class MCPOAuthInitiateRequest(BaseModel):
class MCPOAuthConnectRequest(BaseModel):
name: str = Field(..., description="Name of the MCP tool")
description: Optional[str] = Field(None, description="Description of the MCP tool")
server_url: str = Field(..., description="URL of the MCP server")

@@ -152,32 +166,33 @@ class MCPOAuthInitiateRequest(BaseModel):
)


class MCPOAuthInitiateResponse(BaseModel):
class MCPOAuthConnectResponse(BaseModel):
oauth_url: str = Field(..., description="OAuth URL to redirect user to")
state: str = Field(..., description="OAuth state parameter")
pending_tool: dict = Field(..., description="Pending tool configuration")


class MCPUserOAuthInitiateRequest(BaseModel):
class MCPUserOAuthConnectRequest(BaseModel):
server_id: int = Field(..., description="ID of the MCP server")
return_path: str = Field(..., description="Path to redirect to after callback")
include_resource_param: bool = Field(..., description="Include resource parameter")
oauth_client_id: str | None = Field(
None, description="OAuth client ID (optional for DCR)"
)
oauth_client_secret: str | None = Field(
None, description="OAuth client secret (optional for DCR)"
)

@model_validator(mode="after")
def validate_return_path(self) -> "MCPUserOAuthInitiateRequest":
def validate_return_path(self) -> "MCPUserOAuthConnectRequest":
if not self.return_path.startswith("/"):
raise ValueError("return_path must start with a slash")
return self


class MCPUserOAuthInitiateResponse(BaseModel):
class MCPUserOAuthConnectResponse(BaseModel):
server_id: int
oauth_url: str = Field(..., description="OAuth URL to redirect user to")
state: str = Field(..., description="OAuth state parameter")
server_id: int = Field(..., description="Server ID")
server_name: str = Field(..., description="Server name")
code_verifier: Optional[str] = Field(
None, description="PKCE code verifier to be used at callback"
)


class MCPOAuthCallbackRequest(BaseModel):

@@ -194,7 +209,6 @@ class MCPOAuthCallbackResponse(BaseModel):
message: str
server_id: int
server_name: str
authenticated: bool
redirect_url: str


@@ -255,6 +269,7 @@ class MCPServer(BaseModel):
name: str
description: Optional[str] = None
server_url: str
transport: MCPTransport
auth_type: MCPAuthenticationType
auth_performer: MCPAuthenticationPerformer
is_authenticated: bool
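For orientation, a minimal usage sketch (not part of the diff) of how the MCPConnectionData TypedDict and MCPOAuthKeys enum shown above can fit together; the dict literal values are made up and the helper function is hypothetical:

# Hypothetical sketch; assumes the definitions above are importable.
from typing import Any

connection_data: MCPConnectionData = {
    "headers": {"Authorization": "Bearer <token>"},
    "tokens": {"access_token": "<token>", "token_type": "bearer"},
}

def get_oauth_section(data: MCPConnectionData, key: MCPOAuthKeys) -> dict[str, Any] | None:
    # MCPOAuthKeys values ("client_info", "tokens", "metadata") mirror the optional keys above.
    return data.get(key.value)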
@@ -162,7 +162,7 @@ def unlink_user_file_from_project(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE_PROJECT_SYNC,
kwargs={"user_file_id": user_file.id, "tenant_id": tenant_id},
queue=OnyxCeleryQueues.USER_FILE_PROJECT_SYNC,
priority=OnyxCeleryPriority.HIGH,
priority=OnyxCeleryPriority.HIGHEST,
)
logger.info(
f"Triggered project sync for user_file_id={user_file.id} with task_id={task.id}"

@@ -210,7 +210,7 @@ def link_user_file_to_project(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE_PROJECT_SYNC,
kwargs={"user_file_id": user_file.id, "tenant_id": tenant_id},
queue=OnyxCeleryQueues.USER_FILE_PROJECT_SYNC,
priority=OnyxCeleryPriority.HIGH,
priority=OnyxCeleryPriority.HIGHEST,
)
logger.info(
f"Triggered project sync for user_file_id={user_file.id} with task_id={task.id}"
@@ -4,6 +4,7 @@ from datetime import datetime
from datetime import timezone

import boto3
import httpx
from botocore.exceptions import BotoCoreError
from botocore.exceptions import ClientError
from botocore.exceptions import NoCredentialsError

@@ -11,10 +12,12 @@ from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Query
from pydantic import ValidationError
from sqlalchemy.orm import Session

from onyx.auth.users import current_admin_user
from onyx.auth.users import current_chat_accessible_user
from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
from onyx.db.engine.sql_engine import get_session
from onyx.db.llm import fetch_existing_llm_provider
from onyx.db.llm import fetch_existing_llm_providers

@@ -40,6 +43,9 @@ from onyx.server.manage.llm.models import LLMProviderDescriptor
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import LLMProviderView
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
from onyx.server.manage.llm.models import OllamaFinalModelResponse
from onyx.server.manage.llm.models import OllamaModelDetails
from onyx.server.manage.llm.models import OllamaModelsRequest
from onyx.server.manage.llm.models import TestLLMRequest
from onyx.server.manage.llm.models import VisionProviderResponse
from onyx.utils.logger import setup_logger

@@ -474,3 +480,100 @@ def get_bedrock_available_models(
raise HTTPException(
status_code=500, detail=f"Unexpected error fetching Bedrock models: {e}"
)


def _get_ollama_available_model_names(api_base: str) -> set[str]:
"""Fetch available model names from Ollama server."""
tags_url = f"{api_base}/api/tags"
try:
response = httpx.get(tags_url, timeout=5.0)
response.raise_for_status()
response_json = response.json()
except Exception as e:
raise HTTPException(
status_code=400,
detail=f"Failed to fetch Ollama models: {e}",
)

models = response_json.get("models", [])
return {model.get("name") for model in models if model.get("name")}


@admin_router.post("/ollama/available-models")
def get_ollama_available_models(
request: OllamaModelsRequest,
_: User | None = Depends(current_admin_user),
) -> list[OllamaFinalModelResponse]:
"""Fetch the list of available models from an Ollama server."""

cleaned_api_base = request.api_base.strip().rstrip("/")
if not cleaned_api_base:
raise HTTPException(
status_code=400, detail="API base URL is required to fetch Ollama models."
)

model_names = _get_ollama_available_model_names(cleaned_api_base)
if not model_names:
raise HTTPException(
status_code=400,
detail="No models found from your Ollama server",
)

all_models_with_context_size_and_vision: list[OllamaFinalModelResponse] = []
show_url = f"{cleaned_api_base}/api/show"

for model_name in model_names:
context_limit: int | None = None
supports_image_input: bool | None = None
try:
show_response = httpx.post(
show_url,
json={"model": model_name},
timeout=5.0,
)
show_response.raise_for_status()
show_response_json = show_response.json()

# Parse the response into the expected format
ollama_model_details = OllamaModelDetails.model_validate(show_response_json)

# Check if this model supports completion/chat
if not ollama_model_details.supports_completion():
continue

# Optimistically access. Context limit is stored as "model_architecture.context" = int
architecture = ollama_model_details.model_info.get(
"general.architecture", ""
)
context_limit = ollama_model_details.model_info.get(
architecture + ".context_length", None
)
supports_image_input = ollama_model_details.supports_image_input()
except ValidationError as e:
logger.warning(
"Invalid model details from Ollama server",
extra={"model": model_name, "validation_error": str(e)},
)
except Exception as e:
logger.warning(
"Failed to fetch Ollama model details",
extra={"model": model_name, "error": str(e)},
)

# If we fail at any point attempting to extract context limit,
# still allow this model to be used with a fallback max context size
if not context_limit:
context_limit = GEN_AI_MODEL_FALLBACK_MAX_TOKENS

if not supports_image_input:
supports_image_input = False

all_models_with_context_size_and_vision.append(
OllamaFinalModelResponse(
name=model_name,
max_input_tokens=context_limit,
supports_image_input=supports_image_input,
)
)

return all_models_with_context_size_and_vision
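The new endpoint above discovers models by calling the Ollama server's /api/tags and /api/show endpoints. A standalone sketch of that same flow, assuming a local Ollama server at a hypothetical default address (adjust API_BASE as needed):

# Sketch of the discovery flow performed by the endpoint above.
import httpx

API_BASE = "http://localhost:11434"  # assumed address, not from the diff

tags = httpx.get(f"{API_BASE}/api/tags", timeout=5.0).json()
names = {m["name"] for m in tags.get("models", []) if m.get("name")}

for name in names:
    show = httpx.post(f"{API_BASE}/api/show", json={"model": name}, timeout=5.0).json()
    arch = show.get("model_info", {}).get("general.architecture", "")
    context = show.get("model_info", {}).get(f"{arch}.context_length")
    supports_vision = "vision" in show.get("capabilities", [])
    print(name, context, supports_vision)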
@@ -1,3 +1,4 @@
from typing import Any
from typing import TYPE_CHECKING

from pydantic import BaseModel

@@ -138,8 +139,9 @@ class LLMProviderView(LLMProvider):

class ModelConfigurationUpsertRequest(BaseModel):
name: str
is_visible: bool | None = False
is_visible: bool
max_input_tokens: int | None = None
supports_image_input: bool | None = None

@classmethod
def from_model(

@@ -149,12 +151,13 @@ class ModelConfigurationUpsertRequest(BaseModel):
name=model_configuration_model.name,
is_visible=model_configuration_model.is_visible,
max_input_tokens=model_configuration_model.max_input_tokens,
supports_image_input=model_configuration_model.supports_image_input,
)


class ModelConfigurationView(BaseModel):
name: str
is_visible: bool | None = False
is_visible: bool
max_input_tokens: int | None = None
supports_image_input: bool

@@ -196,3 +199,28 @@ class BedrockModelsRequest(BaseModel):
aws_secret_access_key: str | None = None
aws_bearer_token_bedrock: str | None = None
provider_name: str | None = None # Optional: to save models to existing provider


class OllamaModelsRequest(BaseModel):
api_base: str


class OllamaFinalModelResponse(BaseModel):
name: str
max_input_tokens: int
supports_image_input: bool


class OllamaModelDetails(BaseModel):
"""Response model for Ollama /api/show endpoint"""

model_info: dict[str, Any]
capabilities: list[str] = []

def supports_completion(self) -> bool:
"""Check if this model supports completion/chat"""
return "completion" in self.capabilities

def supports_image_input(self) -> bool:
"""Check if this model supports image input"""
return "vision" in self.capabilities
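A quick sketch (not part of the diff) of validating a minimal /api/show-style payload with the OllamaModelDetails model defined above; the payload values are made up:

# Assumes OllamaModelDetails from the hunk above is importable (pydantic v2).
sample = {
    "model_info": {"general.architecture": "llama", "llama.context_length": 8192},
    "capabilities": ["completion", "vision"],
}
details = OllamaModelDetails.model_validate(sample)
assert details.supports_completion()
assert details.supports_image_input()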
@@ -1,3 +1,5 @@
import csv
import io
import re
from datetime import datetime
from datetime import timedelta

@@ -14,6 +16,7 @@ from fastapi import Depends
from fastapi import HTTPException
from fastapi import Query
from fastapi import Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from sqlalchemy.orm import Session

@@ -297,6 +300,43 @@ def list_all_users(
)


@router.get("/manage/users/download")
def download_users_csv(
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> StreamingResponse:
"""Download all users as a CSV file."""
# Get all users from the database
users = get_all_users(db_session)

# Create CSV content in memory
output = io.StringIO()
writer = csv.writer(output)

# Write CSV header
writer.writerow(["Email", "Role", "Status"])

# Write user data
for user in users:
writer.writerow(
[
user.email,
user.role.value if user.role else "",
"Active" if user.is_active else "Inactive",
]
)

# Prepare the CSV content for download
csv_content = output.getvalue()
output.close()

return StreamingResponse(
io.BytesIO(csv_content.encode("utf-8")),
media_type="text/csv",
headers={"Content-Disposition": "attachment;"},
)


@router.put("/manage/admin/users")
def bulk_invite_users(
emails: list[str] = Body(..., embed=True),
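One observation on the endpoint above: it sends "Content-Disposition: attachment;" without a filename. If a suggested download name is wanted, the header could carry one; this is a possible refinement rather than what the diff does, and "users.csv" is an arbitrary choice:

# Hypothetical variant of the response headers used above.
headers = {"Content-Disposition": 'attachment; filename="users.csv"'}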
@@ -1,596 +0,0 @@
import io
import time
from datetime import datetime
from datetime import timedelta
from typing import List

import requests
import sqlalchemy.exc
from bs4 import BeautifulSoup
from fastapi import APIRouter
from fastapi import Depends
from fastapi import File
from fastapi import Form
from fastapi import HTTPException
from fastapi import Query
from fastapi import UploadFile
from pydantic import BaseModel
from sqlalchemy.orm import Session

from onyx.auth.users import current_user
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import InputType
from onyx.db.connector import create_connector
from onyx.db.connector_credential_pair import add_credential_to_connector
from onyx.db.credentials import create_credential
from onyx.db.engine.sql_engine import get_session
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import User
from onyx.db.models import UserFile
from onyx.db.models import UserFolder
from onyx.db.user_documents import calculate_user_files_token_count
from onyx.db.user_documents import create_user_files
from onyx.db.user_documents import get_user_file_indexing_status
from onyx.db.user_documents import share_file_with_assistant
from onyx.db.user_documents import share_folder_with_assistant
from onyx.db.user_documents import unshare_file_with_assistant
from onyx.db.user_documents import unshare_folder_with_assistant
from onyx.db.user_documents import upload_files_to_user_files_with_indexing
from onyx.file_processing.html_utils import web_html_cleanup
from onyx.server.documents.connector import trigger_indexing_for_cc_pair
from onyx.server.documents.models import ConnectorBase
from onyx.server.documents.models import CredentialBase
from onyx.server.query_and_chat.chat_backend import RECENT_DOCS_FOLDER_ID
from onyx.server.user_documents.models import MessageResponse
from onyx.server.user_documents.models import UserFileSnapshot
from onyx.server.user_documents.models import UserFolderSnapshot
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id

logger = setup_logger()

router = APIRouter()


class FolderCreationRequest(BaseModel):
name: str
description: str


@router.post("/user/folder")
def create_folder(
request: FolderCreationRequest,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> UserFolderSnapshot:
try:
new_folder = UserFolder(
user_id=user.id if user else None,
name=request.name,
description=request.description,
)
db_session.add(new_folder)
db_session.commit()
return UserFolderSnapshot.from_model(new_folder)
except sqlalchemy.exc.DataError as e:
if "StringDataRightTruncation" in str(e):
raise HTTPException(
status_code=400,
detail="Folder name or description is too long. Please use a shorter name or description.",
)
raise


@router.get(
"/user/folder",
)
def user_get_folders(
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> list[UserFolderSnapshot]:
user_id = user.id if user else None
# Get folders that belong to the user or have the RECENT_DOCS_FOLDER_ID
folders = (
db_session.query(UserFolder)
.filter(
(UserFolder.user_id == user_id) | (UserFolder.id == RECENT_DOCS_FOLDER_ID)
)
.all()
)

# For each folder, filter files to only include those belonging to the current user
result = []
for folder in folders:
folder_snapshot = UserFolderSnapshot.from_model(folder)
folder_snapshot.files = [
file for file in folder_snapshot.files if file.user_id == user_id
]
result.append(folder_snapshot)

return result


@router.get("/user/folder/{folder_id}")
def get_folder(
folder_id: int,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> UserFolderSnapshot:
user_id = user.id if user else None
folder = (
db_session.query(UserFolder)
.filter(
UserFolder.id == folder_id,
(
(UserFolder.user_id == user_id)
| (UserFolder.id == RECENT_DOCS_FOLDER_ID)
),
)
.first()
)
if not folder:
raise HTTPException(status_code=404, detail="Folder not found")

folder_snapshot = UserFolderSnapshot.from_model(folder)
# Filter files to only include those belonging to the current user
folder_snapshot.files = [
file for file in folder_snapshot.files if file.user_id == user_id
]

return folder_snapshot


@router.post("/user/file/upload")
def upload_user_files(
files: List[UploadFile] = File(...),
folder_id: int | None = Form(None),
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> list[UserFileSnapshot]:
if folder_id == 0:
folder_id = None

try:
# Use our consolidated function that handles indexing properly
user_files = upload_files_to_user_files_with_indexing(
files, folder_id or RECENT_DOCS_FOLDER_ID, user, db_session
)

return [UserFileSnapshot.from_model(user_file) for user_file in user_files]

except Exception as e:
logger.error(f"Error uploading files: {str(e)}")
raise HTTPException(status_code=500, detail=f"Failed to upload files: {str(e)}")


class FolderUpdateRequest(BaseModel):
name: str | None = None
description: str | None = None


@router.put("/user/folder/{folder_id}")
def update_folder(
folder_id: int,
request: FolderUpdateRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> UserFolderSnapshot:
user_id = user.id if user else None
folder = (
db_session.query(UserFolder)
.filter(UserFolder.id == folder_id, UserFolder.user_id == user_id)
.first()
)
if not folder:
raise HTTPException(status_code=404, detail="Folder not found")
if request.name:
folder.name = request.name
if request.description:
folder.description = request.description
db_session.commit()

return UserFolderSnapshot.from_model(folder)


@router.delete("/user/folder/{folder_id}")
def delete_folder(
folder_id: int,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> MessageResponse:
user_id = user.id if user else None
folder = (
db_session.query(UserFolder)
.filter(UserFolder.id == folder_id, UserFolder.user_id == user_id)
.first()
)
if not folder:
raise HTTPException(status_code=404, detail="Folder not found")
db_session.delete(folder)
db_session.commit()
return MessageResponse(message="Folder deleted successfully")


@router.delete("/user/file/{file_id}")
def delete_file(
file_id: int,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> MessageResponse:
user_id = user.id if user else None
file = (
db_session.query(UserFile)
.filter(UserFile.id == file_id, UserFile.user_id == user_id)
.first()
)
if not file:
raise HTTPException(status_code=404, detail="File not found")
db_session.delete(file)
db_session.commit()
return MessageResponse(message="File deleted successfully")


class FileMoveRequest(BaseModel):
new_folder_id: int | None


@router.put("/user/file/{file_id}/move")
def move_file(
file_id: int,
request: FileMoveRequest,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> UserFileSnapshot:
user_id = user.id if user else None
file = (
db_session.query(UserFile)
.filter(UserFile.id == file_id, UserFile.user_id == user_id)
.first()
)
if not file:
raise HTTPException(status_code=404, detail="File not found")
file.folder_id = request.new_folder_id
db_session.commit()
return UserFileSnapshot.from_model(file)


@router.get("/user/file-system")
def get_file_system(
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> list[UserFolderSnapshot]:
user_id = user.id if user else None
folders = db_session.query(UserFolder).filter(UserFolder.user_id == user_id).all()
return [UserFolderSnapshot.from_model(folder) for folder in folders]


@router.put("/user/file/{file_id}/rename")
def rename_file(
file_id: int,
name: str,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> UserFileSnapshot:
user_id = user.id if user else None
file = (
db_session.query(UserFile)
.filter(UserFile.id == file_id, UserFile.user_id == user_id)
.first()
)
if not file:
raise HTTPException(status_code=404, detail="File not found")
file.name = name
db_session.commit()
return UserFileSnapshot.from_model(file)


class ShareRequest(BaseModel):
assistant_id: int


@router.post("/user/file/{file_id}/share")
def share_file(
file_id: int,
request: ShareRequest,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> MessageResponse:
user_id = user.id if user else None
file = (
db_session.query(UserFile)
.filter(UserFile.id == file_id, UserFile.user_id == user_id)
.first()
)
if not file:
raise HTTPException(status_code=404, detail="File not found")

share_file_with_assistant(file_id, request.assistant_id, db_session)
return MessageResponse(message="File shared successfully with the assistant")


@router.post("/user/file/{file_id}/unshare")
def unshare_file(
file_id: int,
request: ShareRequest,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> MessageResponse:
user_id = user.id if user else None
file = (
db_session.query(UserFile)
.filter(UserFile.id == file_id, UserFile.user_id == user_id)
.first()
)
if not file:
raise HTTPException(status_code=404, detail="File not found")

unshare_file_with_assistant(file_id, request.assistant_id, db_session)
return MessageResponse(message="File unshared successfully from the assistant")


@router.post("/user/folder/{folder_id}/share")
def share_folder(
folder_id: int,
request: ShareRequest,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> MessageResponse:
user_id = user.id if user else None
folder = (
db_session.query(UserFolder)
.filter(UserFolder.id == folder_id, UserFolder.user_id == user_id)
.first()
)
if not folder:
raise HTTPException(status_code=404, detail="Folder not found")

share_folder_with_assistant(folder_id, request.assistant_id, db_session)
return MessageResponse(
message="Folder and its files shared successfully with the assistant"
)


@router.post("/user/folder/{folder_id}/unshare")
def unshare_folder(
folder_id: int,
request: ShareRequest,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> MessageResponse:
user_id = user.id if user else None
folder = (
db_session.query(UserFolder)
.filter(UserFolder.id == folder_id, UserFolder.user_id == user_id)
.first()
)
if not folder:
raise HTTPException(status_code=404, detail="Folder not found")

unshare_folder_with_assistant(folder_id, request.assistant_id, db_session)
return MessageResponse(
message="Folder and its files unshared successfully from the assistant"
)


class CreateFileFromLinkRequest(BaseModel):
url: str
folder_id: int | None


@router.post("/user/file/create-from-link")
def create_file_from_link(
request: CreateFileFromLinkRequest,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> list[UserFileSnapshot]:
try:
response = requests.get(request.url)
response.raise_for_status()
content = response.text
soup = BeautifulSoup(content, "html.parser")
parsed_html = web_html_cleanup(soup, mintlify_cleanup_enabled=False)

file_name = f"{parsed_html.title or 'Untitled'}.txt"
file_content = parsed_html.cleaned_text.encode()

file = UploadFile(filename=file_name, file=io.BytesIO(file_content))
user_files = create_user_files(
[file], request.folder_id or -1, user, db_session, link_url=request.url
)

# Create connector and credential (same as in upload_user_files)
for user_file in user_files:
connector_base = ConnectorBase(
name=f"UserFile-{user_file.file_id}-{int(time.time())}",
source=DocumentSource.FILE,
input_type=InputType.LOAD_STATE,
connector_specific_config={
"file_locations": [user_file.file_id],
"file_names": [user_file.name],
"zip_metadata": {},
},
refresh_freq=None,
prune_freq=None,
indexing_start=None,
)

connector = create_connector(
db_session=db_session,
connector_data=connector_base,
)

credential_info = CredentialBase(
credential_json={},
admin_public=True,
source=DocumentSource.FILE,
curator_public=True,
groups=[],
name=f"UserFileCredential-{user_file.file_id}-{int(time.time())}",
)
credential = create_credential(credential_info, user, db_session)

cc_pair = add_credential_to_connector(
db_session=db_session,
user=user,
connector_id=connector.id,
credential_id=credential.id,
cc_pair_name=f"UserFileCCPair-{int(time.time())}",
access_type=AccessType.PRIVATE,
auto_sync_options=None,
groups=[],
is_user_file=True,
)
user_file.cc_pair_id = cc_pair.data
db_session.commit()

# Trigger immediate indexing with highest priority
tenant_id = get_current_tenant_id()
trigger_indexing_for_cc_pair(
[], connector.id, False, tenant_id, db_session, is_user_file=True
)

db_session.commit()
return [UserFileSnapshot.from_model(user_file) for user_file in user_files]
except requests.RequestException as e:
raise HTTPException(status_code=400, detail=f"Failed to fetch URL: {str(e)}")


@router.get("/user/file/indexing-status")
def get_files_indexing_status(
file_ids: list[int] = Query(...),
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> dict[int, bool]:
"""Get indexing status for multiple files"""
return get_user_file_indexing_status(file_ids, db_session)


@router.get("/user/file/token-estimate")
def get_files_token_estimate(
file_ids: list[int] = Query([]),
folder_ids: list[int] = Query([]),
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> dict:
"""Get token estimate for files and folders"""
total_tokens = calculate_user_files_token_count(file_ids, folder_ids, db_session)
return {"total_tokens": total_tokens}


class ReindexFileRequest(BaseModel):
file_id: int


@router.post("/user/file/reindex")
def reindex_file(
request: ReindexFileRequest,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> MessageResponse:
user_id = user.id if user else None
user_file_to_reindex = (
db_session.query(UserFile)
.filter(UserFile.id == request.file_id, UserFile.user_id == user_id)
.first()
)

if not user_file_to_reindex:
raise HTTPException(status_code=404, detail="File not found")

if not user_file_to_reindex.cc_pair_id:
raise HTTPException(
status_code=400,
detail="File does not have an associated connector-credential pair",
)

# Get the connector id from the cc_pair
cc_pair = (
db_session.query(ConnectorCredentialPair)
.filter_by(id=user_file_to_reindex.cc_pair_id)
.first()
)
if not cc_pair:
raise HTTPException(
status_code=404, detail="Associated connector-credential pair not found"
)

# Trigger immediate reindexing with highest priority
tenant_id = get_current_tenant_id()
# Update the cc_pair status to ACTIVE to ensure it's processed
cc_pair.status = ConnectorCredentialPairStatus.ACTIVE
db_session.commit()
try:
trigger_indexing_for_cc_pair(
[], cc_pair.connector_id, True, tenant_id, db_session, is_user_file=True
)
return MessageResponse(
message="File reindexing has been triggered successfully"
)
except Exception as e:
logger.error(
f"Error triggering reindexing for file {request.file_id}: {str(e)}"
)
raise HTTPException(
status_code=500, detail=f"Failed to trigger reindexing: {str(e)}"
)


class BulkCleanupRequest(BaseModel):
folder_id: int
days_older_than: int | None = None


@router.post("/user/file/bulk-cleanup")
def bulk_cleanup_files(
request: BulkCleanupRequest,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> MessageResponse:
"""Bulk delete files older than specified days in a folder"""
user_id = user.id if user else None

logger.info(
f"Bulk cleanup request: folder_id={request.folder_id}, days_older_than={request.days_older_than}"
)

# Check if folder exists
if request.folder_id != RECENT_DOCS_FOLDER_ID:
folder = (
db_session.query(UserFolder)
.filter(UserFolder.id == request.folder_id, UserFolder.user_id == user_id)
.first()
)
if not folder:
raise HTTPException(status_code=404, detail="Folder not found")

filter_criteria = [UserFile.user_id == user_id]

# Filter by folder
if request.folder_id != -2: # -2 means all folders
filter_criteria.append(UserFile.folder_id == request.folder_id)

# Filter by date if days_older_than is provided
if request.days_older_than is not None:
cutoff_date = datetime.utcnow() - timedelta(days=request.days_older_than)
logger.info(f"Filtering files older than {cutoff_date} (UTC)")
filter_criteria.append(UserFile.created_at < cutoff_date)

# Get all files matching the criteria
files_to_delete = db_session.query(UserFile).filter(*filter_criteria).all()

logger.info(f"Found {len(files_to_delete)} files to delete")

# Delete files
delete_count = 0
for file in files_to_delete:
logger.debug(
f"Deleting file: id={file.id}, name={file.name}, created_at={file.created_at}"
)
db_session.delete(file)
delete_count += 1

db_session.commit()

return MessageResponse(message=f"Successfully deleted {delete_count} files")
Some files were not shown because too many files have changed in this diff.