forked from github/onyx
Compare commits
202 Commits
main
...
copilot/fi
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9d450555b6 | ||
|
|
c2758a28d5 | ||
|
|
5cda2e0173 | ||
|
|
9e885a68b3 | ||
|
|
376fc86b0c | ||
|
|
2eb1444d80 | ||
|
|
bd6ebe4718 | ||
|
|
691d63bc0f | ||
|
|
dfd4d9abef | ||
|
|
4cb39bc150 | ||
|
|
4e357478e0 | ||
|
|
b5b1b3287c | ||
|
|
2f58a972eb | ||
|
|
6b39d8eed9 | ||
|
|
f81c34d040 | ||
|
|
0771b1f476 | ||
|
|
eedd2ba3fe | ||
|
|
98554e5025 | ||
|
|
dcd2cad6b4 | ||
|
|
189f4bb071 | ||
|
|
7eeab8fb80 | ||
|
|
60f83dd0db | ||
|
|
2618602fd6 | ||
|
|
b80f96de85 | ||
|
|
74a15b2c01 | ||
|
|
408b80ce51 | ||
|
|
e82b68c1b0 | ||
|
|
af5eec648b | ||
|
|
d186c5e82e | ||
|
|
4420a50aed | ||
|
|
9caa6ea7ff | ||
|
|
8d7b217d33 | ||
|
|
57908769f1 | ||
|
|
600cec7c89 | ||
|
|
bb8ea536c4 | ||
|
|
f97869b91e | ||
|
|
aa5be56884 | ||
|
|
7580178c95 | ||
|
|
2e0bc8caf0 | ||
|
|
f9bd03c7f0 | ||
|
|
77466e1f2b | ||
|
|
8dd79345ed | ||
|
|
a049835c49 | ||
|
|
d186d8e8ed | ||
|
|
082897eb9b | ||
|
|
e38f79dec5 | ||
|
|
26e7bba25d | ||
|
|
3cde4ef77f | ||
|
|
f4d135d710 | ||
|
|
6094f70ac8 | ||
|
|
a90e58b39b | ||
|
|
e82e3141ed | ||
|
|
f8e9060bab | ||
|
|
24831fa1a1 | ||
|
|
f6a0e69b2a | ||
|
|
0394eaea7f | ||
|
|
898b8c316e | ||
|
|
4b0c6d1e54 | ||
|
|
da7dc33afa | ||
|
|
c558732ddd | ||
|
|
339ad9189b | ||
|
|
32d5e408b8 | ||
|
|
14ead457d9 | ||
|
|
458cd7e832 | ||
|
|
770a2692e9 | ||
|
|
5dd99b6acf | ||
|
|
6c7eb89374 | ||
|
|
fd11c16c6d | ||
|
|
11ec603c37 | ||
|
|
495d4cac44 | ||
|
|
fd2d74ae2e | ||
|
|
4c7a2e486b | ||
|
|
01e0ba6270 | ||
|
|
227dfc4a05 | ||
|
|
c3702b76b6 | ||
|
|
bb239d574c | ||
|
|
172e5f0e24 | ||
|
|
26b026fb88 | ||
|
|
870629e8a9 | ||
|
|
a547112321 | ||
|
|
da5a94815e | ||
|
|
e024472b74 | ||
|
|
e74855e633 | ||
|
|
e4c26a933d | ||
|
|
36c96f2d98 | ||
|
|
1ea94dcd8d | ||
|
|
2b1c5a0755 | ||
|
|
82b5f806ab | ||
|
|
6340c517d1 | ||
|
|
3baae2d4f0 | ||
|
|
d7c223ddd4 | ||
|
|
df4917243b | ||
|
|
a79ab713ce | ||
|
|
d1f7cee959 | ||
|
|
a3f41e20da | ||
|
|
458ed93da0 | ||
|
|
273d073bd7 | ||
|
|
9455c8e5ae | ||
|
|
d45d4389a0 | ||
|
|
bd901c0da1 | ||
|
|
2192605c95 | ||
|
|
d248d2f4e9 | ||
|
|
331c53871a | ||
|
|
f62d0d9144 | ||
|
|
427945e757 | ||
|
|
e55cdc6250 | ||
|
|
6a01db9ff2 | ||
|
|
82e9df5c22 | ||
|
|
16c2ef2852 | ||
|
|
224a70eea9 | ||
|
|
c457982120 | ||
|
|
0649748da2 | ||
|
|
ddceddaa28 | ||
|
|
c6733a5026 | ||
|
|
7db744a5de | ||
|
|
cd2a8b0def | ||
|
|
f15bc26cd6 | ||
|
|
65f35f0293 | ||
|
|
4e3e608249 | ||
|
|
719a092a12 | ||
|
|
6a8fde7eb1 | ||
|
|
4fdd0812a0 | ||
|
|
4913dc1e85 | ||
|
|
4a43a9642e | ||
|
|
cc48a0c38e | ||
|
|
01ccfd2df7 | ||
|
|
36d75786ee | ||
|
|
f9bc38ba65 | ||
|
|
3da283221d | ||
|
|
90568d3bbb | ||
|
|
7955ca938c | ||
|
|
f5d357eb28 | ||
|
|
d83f616214 | ||
|
|
275c1bec3d | ||
|
|
7d1ef912e8 | ||
|
|
2fe1d4c373 | ||
|
|
2396ad309e | ||
|
|
0b13ef963a | ||
|
|
83073f3ded | ||
|
|
439a27a775 | ||
|
|
91773a4789 | ||
|
|
185beca648 | ||
|
|
2dc564c8df | ||
|
|
b259f53972 | ||
|
|
f8beb08e2f | ||
|
|
83c88c7cf6 | ||
|
|
2372dd40e0 | ||
|
|
5cb6bafe81 | ||
|
|
a0309b31c7 | ||
|
|
0fd268dba7 | ||
|
|
f345da7487 | ||
|
|
f2dacf03f1 | ||
|
|
e0fef50cf0 | ||
|
|
6ba3eeefa5 | ||
|
|
aa158abaa9 | ||
|
|
255c2af1d6 | ||
|
|
9ece3b0310 | ||
|
|
9e3aca03a7 | ||
|
|
dbd5d4d8f1 | ||
|
|
cdb97c3ce4 | ||
|
|
f30ced31a9 | ||
|
|
6cc6c43234 | ||
|
|
224d934cf4 | ||
|
|
8ecdc61ad3 | ||
|
|
08161db7ea | ||
|
|
b139764631 | ||
|
|
2b23dbde8d | ||
|
|
2dec009d63 | ||
|
|
91eadae353 | ||
|
|
8bff616e27 | ||
|
|
2c049e170f | ||
|
|
23e6d7ef3c | ||
|
|
ed81e75edd | ||
|
|
de22fc3a58 | ||
|
|
009b7f60f1 | ||
|
|
9d997e20df | ||
|
|
e6423c4541 | ||
|
|
cb969ad06a | ||
|
|
c4076d16b6 | ||
|
|
04a607a718 | ||
|
|
c1e1aa9dfd | ||
|
|
1ed7abae6e | ||
|
|
cf4855822b | ||
|
|
e242b1319c | ||
|
|
eba4b6620e | ||
|
|
3534515e11 | ||
|
|
5602ff8666 | ||
|
|
2fc70781b4 | ||
|
|
f76b4dec4c | ||
|
|
a5a516fa8a | ||
|
|
811a198134 | ||
|
|
5867ab1d7d | ||
|
|
dd6653eb1f | ||
|
|
db457ef432 | ||
|
|
de7fe939b2 | ||
|
|
38114d9542 | ||
|
|
32f20f2e2e | ||
|
|
3dd27099f7 | ||
|
|
91c4d43a80 | ||
|
|
a63ba1bb03 | ||
|
|
7b6189e74c | ||
|
|
ba423e5773 |
19
.github/actions/custom-build-and-push/action.yml
vendored
19
.github/actions/custom-build-and-push/action.yml
vendored
@@ -35,6 +35,16 @@ inputs:
|
||||
cache-to:
|
||||
description: 'Cache destinations'
|
||||
required: false
|
||||
outputs:
|
||||
description: 'Output destinations'
|
||||
required: false
|
||||
provenance:
|
||||
description: 'Generate provenance attestation'
|
||||
required: false
|
||||
default: 'false'
|
||||
build-args:
|
||||
description: 'Build arguments'
|
||||
required: false
|
||||
retry-wait-time:
|
||||
description: 'Time to wait before attempt 2 in seconds'
|
||||
required: false
|
||||
@@ -62,6 +72,9 @@ runs:
|
||||
no-cache: ${{ inputs.no-cache }}
|
||||
cache-from: ${{ inputs.cache-from }}
|
||||
cache-to: ${{ inputs.cache-to }}
|
||||
outputs: ${{ inputs.outputs }}
|
||||
provenance: ${{ inputs.provenance }}
|
||||
build-args: ${{ inputs.build-args }}
|
||||
|
||||
- name: Wait before attempt 2
|
||||
if: steps.buildx1.outcome != 'success'
|
||||
@@ -85,6 +98,9 @@ runs:
|
||||
no-cache: ${{ inputs.no-cache }}
|
||||
cache-from: ${{ inputs.cache-from }}
|
||||
cache-to: ${{ inputs.cache-to }}
|
||||
outputs: ${{ inputs.outputs }}
|
||||
provenance: ${{ inputs.provenance }}
|
||||
build-args: ${{ inputs.build-args }}
|
||||
|
||||
- name: Wait before attempt 3
|
||||
if: steps.buildx1.outcome != 'success' && steps.buildx2.outcome != 'success'
|
||||
@@ -108,6 +124,9 @@ runs:
|
||||
no-cache: ${{ inputs.no-cache }}
|
||||
cache-from: ${{ inputs.cache-from }}
|
||||
cache-to: ${{ inputs.cache-to }}
|
||||
outputs: ${{ inputs.outputs }}
|
||||
provenance: ${{ inputs.provenance }}
|
||||
build-args: ${{ inputs.build-args }}
|
||||
|
||||
- name: Report failure
|
||||
if: steps.buildx1.outcome != 'success' && steps.buildx2.outcome != 'success' && steps.buildx3.outcome != 'success'
|
||||
|
||||
5
.github/pull_request_template.md
vendored
5
.github/pull_request_template.md
vendored
@@ -6,9 +6,6 @@
|
||||
|
||||
[Describe the tests you ran to verify your changes]
|
||||
|
||||
## Backporting (check the box to trigger backport action)
|
||||
## Additional Options
|
||||
|
||||
Note: You have to check that the action passes, otherwise resolve the conflicts manually and tag the patches.
|
||||
|
||||
- [ ] This PR should be backported (make sure to check that the backport attempt succeeds)
|
||||
- [ ] [Optional] Override Linear Check
|
||||
|
||||
24
.github/workflows/check-lazy-imports.yml
vendored
Normal file
24
.github/workflows/check-lazy-imports.yml
vendored
Normal file
@@ -0,0 +1,24 @@
|
||||
name: Check Lazy Imports
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
- 'release/**'
|
||||
|
||||
jobs:
|
||||
check-lazy-imports:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.11'
|
||||
|
||||
- name: Check lazy imports
|
||||
run: python3 backend/scripts/check_lazy_imports.py
|
||||
@@ -142,15 +142,25 @@ jobs:
|
||||
# can re-enable when they figure it out
|
||||
# https://github.com/aquasecurity/trivy/discussions/7538
|
||||
# https://github.com/aquasecurity/trivy-action/issues/389
|
||||
# Security: Using pinned digest (0.65.0@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436)
|
||||
# Security: No Docker socket mount needed for remote registry scanning
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@master
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
|
||||
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
|
||||
TRIVY_USERNAME: ${{ secrets.DOCKER_USERNAME }}
|
||||
TRIVY_PASSWORD: ${{ secrets.DOCKER_TOKEN }}
|
||||
uses: nick-fields/retry@v3
|
||||
with:
|
||||
# To run locally: trivy image --severity HIGH,CRITICAL onyxdotapp/onyx-backend
|
||||
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
severity: "CRITICAL,HIGH"
|
||||
trivyignores: ./backend/.trivyignore
|
||||
timeout_minutes: 30
|
||||
max_attempts: 3
|
||||
retry_wait_seconds: 10
|
||||
command: |
|
||||
docker run --rm -v $HOME/.cache/trivy:/root/.cache/trivy \
|
||||
-v ${{ github.workspace }}/backend/.trivyignore:/tmp/.trivyignore:ro \
|
||||
-e TRIVY_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-db:2" \
|
||||
-e TRIVY_JAVA_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-java-db:1" \
|
||||
-e TRIVY_USERNAME="${{ secrets.DOCKER_USERNAME }}" \
|
||||
-e TRIVY_PASSWORD="${{ secrets.DOCKER_TOKEN }}" \
|
||||
aquasec/trivy@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436 \
|
||||
image \
|
||||
--skip-version-check \
|
||||
--timeout 20m \
|
||||
--severity CRITICAL,HIGH \
|
||||
--ignorefile /tmp/.trivyignore \
|
||||
docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
|
||||
@@ -139,12 +139,20 @@ jobs:
|
||||
# https://github.com/aquasecurity/trivy/discussions/7538
|
||||
# https://github.com/aquasecurity/trivy-action/issues/389
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@master
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
|
||||
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
|
||||
TRIVY_USERNAME: ${{ secrets.DOCKER_USERNAME }}
|
||||
TRIVY_PASSWORD: ${{ secrets.DOCKER_TOKEN }}
|
||||
uses: nick-fields/retry@v3
|
||||
with:
|
||||
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
severity: "CRITICAL,HIGH"
|
||||
timeout_minutes: 30
|
||||
max_attempts: 3
|
||||
retry_wait_seconds: 10
|
||||
command: |
|
||||
docker run --rm -v $HOME/.cache/trivy:/root/.cache/trivy \
|
||||
-e TRIVY_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-db:2" \
|
||||
-e TRIVY_JAVA_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-java-db:1" \
|
||||
-e TRIVY_USERNAME="${{ secrets.DOCKER_USERNAME }}" \
|
||||
-e TRIVY_PASSWORD="${{ secrets.DOCKER_TOKEN }}" \
|
||||
aquasec/trivy@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436 \
|
||||
image \
|
||||
--skip-version-check \
|
||||
--timeout 20m \
|
||||
--severity CRITICAL,HIGH \
|
||||
docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
|
||||
@@ -99,7 +99,7 @@ jobs:
|
||||
needs: [check_model_server_changes]
|
||||
if: needs.check_model_server_changes.outputs.changed == 'true'
|
||||
runs-on:
|
||||
[runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}-arm64"]
|
||||
[runs-on, runner=8cpu-linux-arm64, "run-id=${{ github.run_id }}-arm64"]
|
||||
env:
|
||||
PLATFORM_PAIR: linux-arm64
|
||||
steps:
|
||||
@@ -164,13 +164,20 @@ jobs:
|
||||
fi
|
||||
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@master
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
|
||||
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
|
||||
TRIVY_USERNAME: ${{ secrets.DOCKER_USERNAME }}
|
||||
TRIVY_PASSWORD: ${{ secrets.DOCKER_TOKEN }}
|
||||
uses: nick-fields/retry@v3
|
||||
with:
|
||||
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
severity: "CRITICAL,HIGH"
|
||||
timeout: "10m"
|
||||
timeout_minutes: 30
|
||||
max_attempts: 3
|
||||
retry_wait_seconds: 10
|
||||
command: |
|
||||
docker run --rm -v $HOME/.cache/trivy:/root/.cache/trivy \
|
||||
-e TRIVY_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-db:2" \
|
||||
-e TRIVY_JAVA_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-java-db:1" \
|
||||
-e TRIVY_USERNAME="${{ secrets.DOCKER_USERNAME }}" \
|
||||
-e TRIVY_PASSWORD="${{ secrets.DOCKER_TOKEN }}" \
|
||||
aquasec/trivy@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436 \
|
||||
image \
|
||||
--skip-version-check \
|
||||
--timeout 20m \
|
||||
--severity CRITICAL,HIGH \
|
||||
docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
|
||||
@@ -150,12 +150,20 @@ jobs:
|
||||
# https://github.com/aquasecurity/trivy/discussions/7538
|
||||
# https://github.com/aquasecurity/trivy-action/issues/389
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@master
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
|
||||
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
|
||||
TRIVY_USERNAME: ${{ secrets.DOCKER_USERNAME }}
|
||||
TRIVY_PASSWORD: ${{ secrets.DOCKER_TOKEN }}
|
||||
uses: nick-fields/retry@v3
|
||||
with:
|
||||
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
severity: "CRITICAL,HIGH"
|
||||
timeout_minutes: 30
|
||||
max_attempts: 3
|
||||
retry_wait_seconds: 10
|
||||
command: |
|
||||
docker run --rm -v $HOME/.cache/trivy:/root/.cache/trivy \
|
||||
-e TRIVY_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-db:2" \
|
||||
-e TRIVY_JAVA_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-java-db:1" \
|
||||
-e TRIVY_USERNAME="${{ secrets.DOCKER_USERNAME }}" \
|
||||
-e TRIVY_PASSWORD="${{ secrets.DOCKER_TOKEN }}" \
|
||||
aquasec/trivy@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436 \
|
||||
image \
|
||||
--skip-version-check \
|
||||
--timeout 20m \
|
||||
--severity CRITICAL,HIGH \
|
||||
docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
|
||||
3
.github/workflows/helm-chart-releases.yml
vendored
3
.github/workflows/helm-chart-releases.yml
vendored
@@ -27,6 +27,7 @@ jobs:
|
||||
run: |
|
||||
helm repo add bitnami https://charts.bitnami.com/bitnami
|
||||
helm repo add onyx-vespa https://onyx-dot-app.github.io/vespa-helm-charts
|
||||
helm repo add keda https://kedacore.github.io/charts
|
||||
helm repo update
|
||||
|
||||
- name: Build chart dependencies
|
||||
@@ -46,4 +47,4 @@ jobs:
|
||||
charts_dir: deployment/helm/charts
|
||||
branch: gh-pages
|
||||
commit_username: ${{ github.actor }}
|
||||
commit_email: ${{ github.actor }}@users.noreply.github.com
|
||||
commit_email: ${{ github.actor }}@users.noreply.github.com
|
||||
|
||||
124
.github/workflows/pr-backport-autotrigger.yml
vendored
124
.github/workflows/pr-backport-autotrigger.yml
vendored
@@ -1,124 +0,0 @@
|
||||
name: Backport on Merge
|
||||
|
||||
# Note this workflow does not trigger the builds, be sure to manually tag the branches to trigger the builds
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [closed] # Later we check for merge so only PRs that go in can get backported
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
actions: write
|
||||
|
||||
jobs:
|
||||
backport:
|
||||
if: github.event.pull_request.merged == true
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.YUHONG_GH_ACTIONS }}
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ssh-key: "${{ secrets.RKUO_DEPLOY_KEY }}"
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Git user
|
||||
run: |
|
||||
git config user.name "Richard Kuo [bot]"
|
||||
git config user.email "rkuo[bot]@onyx.app"
|
||||
git fetch --prune
|
||||
|
||||
- name: Check for Backport Checkbox
|
||||
id: checkbox-check
|
||||
run: |
|
||||
PR_BODY="${{ github.event.pull_request.body }}"
|
||||
if [[ "$PR_BODY" == *"[x] This PR should be backported"* ]]; then
|
||||
echo "backport=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "backport=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- name: List and sort release branches
|
||||
id: list-branches
|
||||
run: |
|
||||
git fetch --all --tags
|
||||
BRANCHES=$(git for-each-ref --format='%(refname:short)' refs/remotes/origin/release/* | sed 's|origin/release/||' | sort -Vr)
|
||||
BETA=$(echo "$BRANCHES" | head -n 1)
|
||||
STABLE=$(echo "$BRANCHES" | head -n 2 | tail -n 1)
|
||||
echo "beta=release/$BETA" >> $GITHUB_OUTPUT
|
||||
echo "stable=release/$STABLE" >> $GITHUB_OUTPUT
|
||||
# Fetch latest tags for beta and stable
|
||||
LATEST_BETA_TAG=$(git tag -l "v[0-9]*.[0-9]*.[0-9]*-beta.[0-9]*" | grep -E "^v[0-9]+\.[0-9]+\.[0-9]+-beta\.[0-9]+$" | grep -v -- "-cloud" | sort -Vr | head -n 1)
|
||||
LATEST_STABLE_TAG=$(git tag -l "v[0-9]*.[0-9]*.[0-9]*" | grep -E "^v[0-9]+\.[0-9]+\.[0-9]+$" | sort -Vr | head -n 1)
|
||||
|
||||
# Handle case where no beta tags exist
|
||||
if [[ -z "$LATEST_BETA_TAG" ]]; then
|
||||
NEW_BETA_TAG="v1.0.0-beta.1"
|
||||
else
|
||||
NEW_BETA_TAG=$(echo $LATEST_BETA_TAG | awk -F '[.-]' '{print $1 "." $2 "." $3 "-beta." ($NF+1)}')
|
||||
fi
|
||||
|
||||
# Increment latest stable tag
|
||||
NEW_STABLE_TAG=$(echo $LATEST_STABLE_TAG | awk -F '.' '{print $1 "." $2 "." ($3+1)}')
|
||||
echo "latest_beta_tag=$LATEST_BETA_TAG" >> $GITHUB_OUTPUT
|
||||
echo "latest_stable_tag=$LATEST_STABLE_TAG" >> $GITHUB_OUTPUT
|
||||
echo "new_beta_tag=$NEW_BETA_TAG" >> $GITHUB_OUTPUT
|
||||
echo "new_stable_tag=$NEW_STABLE_TAG" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Echo branch and tag information
|
||||
run: |
|
||||
echo "Beta branch: ${{ steps.list-branches.outputs.beta }}"
|
||||
echo "Stable branch: ${{ steps.list-branches.outputs.stable }}"
|
||||
echo "Latest beta tag: ${{ steps.list-branches.outputs.latest_beta_tag }}"
|
||||
echo "Latest stable tag: ${{ steps.list-branches.outputs.latest_stable_tag }}"
|
||||
echo "New beta tag: ${{ steps.list-branches.outputs.new_beta_tag }}"
|
||||
echo "New stable tag: ${{ steps.list-branches.outputs.new_stable_tag }}"
|
||||
|
||||
- name: Trigger Backport
|
||||
if: steps.checkbox-check.outputs.backport == 'true'
|
||||
run: |
|
||||
set -e
|
||||
echo "Backporting to beta ${{ steps.list-branches.outputs.beta }} and stable ${{ steps.list-branches.outputs.stable }}"
|
||||
|
||||
# Echo the merge commit SHA
|
||||
echo "Merge commit SHA: ${{ github.event.pull_request.merge_commit_sha }}"
|
||||
|
||||
# Fetch all history for all branches and tags
|
||||
git fetch --prune
|
||||
|
||||
# Reset and prepare the beta branch
|
||||
git checkout ${{ steps.list-branches.outputs.beta }}
|
||||
echo "Last 5 commits on beta branch:"
|
||||
git log -n 5 --pretty=format:"%H"
|
||||
echo "" # Newline for formatting
|
||||
|
||||
# Cherry-pick the merge commit from the merged PR
|
||||
git cherry-pick -m 1 ${{ github.event.pull_request.merge_commit_sha }} || {
|
||||
echo "Cherry-pick to beta failed due to conflicts."
|
||||
exit 1
|
||||
}
|
||||
|
||||
# Create new beta branch/tag
|
||||
git tag ${{ steps.list-branches.outputs.new_beta_tag }}
|
||||
# Push the changes and tag to the beta branch using PAT
|
||||
git push origin ${{ steps.list-branches.outputs.beta }}
|
||||
git push origin ${{ steps.list-branches.outputs.new_beta_tag }}
|
||||
|
||||
# Reset and prepare the stable branch
|
||||
git checkout ${{ steps.list-branches.outputs.stable }}
|
||||
echo "Last 5 commits on stable branch:"
|
||||
git log -n 5 --pretty=format:"%H"
|
||||
echo "" # Newline for formatting
|
||||
|
||||
# Cherry-pick the merge commit from the merged PR
|
||||
git cherry-pick -m 1 ${{ github.event.pull_request.merge_commit_sha }} || {
|
||||
echo "Cherry-pick to stable failed due to conflicts."
|
||||
exit 1
|
||||
}
|
||||
|
||||
# Create new stable branch/tag
|
||||
git tag ${{ steps.list-branches.outputs.new_stable_tag }}
|
||||
# Push the changes and tag to the stable branch using PAT
|
||||
git push origin ${{ steps.list-branches.outputs.stable }}
|
||||
git push origin ${{ steps.list-branches.outputs.new_stable_tag }}
|
||||
@@ -21,6 +21,10 @@ env:
|
||||
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
|
||||
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
|
||||
|
||||
# LLMs
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
|
||||
jobs:
|
||||
discover-test-dirs:
|
||||
runs-on: ubuntu-latest
|
||||
@@ -39,8 +43,8 @@ jobs:
|
||||
|
||||
external-dependency-unit-tests:
|
||||
needs: discover-test-dirs
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}"]
|
||||
# Use larger runner with more resources for Vespa
|
||||
runs-on: [runs-on, runner=16cpu-linux-x64, "run-id=${{ github.run_id }}"]
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
@@ -49,6 +53,7 @@ jobs:
|
||||
|
||||
env:
|
||||
PYTHONPATH: ./backend
|
||||
MODEL_SERVER_HOST: "disabled"
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
@@ -74,19 +79,30 @@ jobs:
|
||||
- name: Set up Standard Dependencies
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p onyx-stack up -d minio relational_db cache index
|
||||
docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d minio relational_db cache index
|
||||
|
||||
- name: Wait for services
|
||||
run: |
|
||||
echo "Waiting for services to be ready..."
|
||||
sleep 30
|
||||
|
||||
# Wait for Vespa specifically
|
||||
echo "Waiting for Vespa to be ready..."
|
||||
timeout 300 bash -c 'until curl -f -s http://localhost:8081/ApplicationStatus > /dev/null 2>&1; do echo "Vespa not ready, waiting..."; sleep 10; done' || echo "Vespa timeout - continuing anyway"
|
||||
|
||||
echo "Services should be ready now"
|
||||
|
||||
- name: Run migrations
|
||||
run: |
|
||||
cd backend
|
||||
# Run migrations to head
|
||||
alembic upgrade head
|
||||
alembic heads --verbose
|
||||
|
||||
- name: Run Tests for ${{ matrix.test-dir }}
|
||||
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
|
||||
run: |
|
||||
py.test \
|
||||
-n 8 \
|
||||
--dist loadfile \
|
||||
--durations=8 \
|
||||
-o junit_family=xunit2 \
|
||||
-xv \
|
||||
|
||||
168
.github/workflows/pr-helm-chart-testing.yml
vendored
168
.github/workflows/pr-helm-chart-testing.yml
vendored
@@ -53,27 +53,155 @@ jobs:
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
uses: helm/kind-action@v1.12.0
|
||||
|
||||
- name: Run chart-testing (install)
|
||||
- name: Pre-install cluster status check
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: ct install --all \
|
||||
--helm-extra-set-args="\
|
||||
--set=nginx.enabled=false \
|
||||
--set=postgresql.enabled=false \
|
||||
--set=redis.enabled=false \
|
||||
--set=minio.enabled=false \
|
||||
--set=vespa.enabled=false \
|
||||
--set=slackbot.enabled=false \
|
||||
--set=api.replicaCount=0 \
|
||||
--set=inferenceCapability.replicaCount=0 \
|
||||
--set=indexCapability.replicaCount=0 \
|
||||
--set=celery_beat.replicaCount=0 \
|
||||
--set=celery_worker_heavy.replicaCount=0 \
|
||||
--set=celery_worker_docprocessing.replicaCount=0 \
|
||||
--set=celery_worker_light.replicaCount=0 \
|
||||
--set=celery_worker_monitoring.replicaCount=0 \
|
||||
--set=celery_worker_primary.replicaCount=0 \
|
||||
--set=celery_worker_user_files_indexing.replicaCount=0" \
|
||||
--debug --config ct.yaml
|
||||
run: |
|
||||
echo "=== Pre-install Cluster Status ==="
|
||||
kubectl get nodes -o wide
|
||||
kubectl get pods --all-namespaces
|
||||
kubectl get storageclass
|
||||
|
||||
- name: Add Helm repositories and update
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Adding Helm repositories ==="
|
||||
helm repo add bitnami https://charts.bitnami.com/bitnami
|
||||
helm repo add vespa https://onyx-dot-app.github.io/vespa-helm-charts
|
||||
helm repo update
|
||||
|
||||
- name: Pre-pull critical images
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Pre-pulling critical images to avoid timeout ==="
|
||||
# Get kind cluster name
|
||||
KIND_CLUSTER=$(kubectl config current-context | sed 's/kind-//')
|
||||
echo "Kind cluster: $KIND_CLUSTER"
|
||||
|
||||
# Pre-pull images that are likely to be used
|
||||
echo "Pre-pulling PostgreSQL image..."
|
||||
docker pull postgres:15-alpine || echo "Failed to pull postgres:15-alpine"
|
||||
kind load docker-image postgres:15-alpine --name $KIND_CLUSTER || echo "Failed to load postgres image"
|
||||
|
||||
echo "Pre-pulling Redis image..."
|
||||
docker pull redis:7-alpine || echo "Failed to pull redis:7-alpine"
|
||||
kind load docker-image redis:7-alpine --name $KIND_CLUSTER || echo "Failed to load redis image"
|
||||
|
||||
echo "Pre-pulling Onyx images..."
|
||||
docker pull docker.io/onyxdotapp/onyx-web-server:latest || echo "Failed to pull onyx web server"
|
||||
docker pull docker.io/onyxdotapp/onyx-backend:latest || echo "Failed to pull onyx backend"
|
||||
kind load docker-image docker.io/onyxdotapp/onyx-web-server:latest --name $KIND_CLUSTER || echo "Failed to load onyx web server"
|
||||
kind load docker-image docker.io/onyxdotapp/onyx-backend:latest --name $KIND_CLUSTER || echo "Failed to load onyx backend"
|
||||
|
||||
echo "=== Images loaded into Kind cluster ==="
|
||||
docker exec $KIND_CLUSTER-control-plane crictl images | grep -E "(postgres|redis|onyx)" || echo "Some images may still be loading..."
|
||||
|
||||
- name: Validate chart dependencies
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Validating chart dependencies ==="
|
||||
cd deployment/helm/charts/onyx
|
||||
helm dependency update
|
||||
helm lint .
|
||||
|
||||
- name: Run chart-testing (install) with enhanced monitoring
|
||||
timeout-minutes: 25
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Starting chart installation with monitoring ==="
|
||||
|
||||
# Function to monitor cluster state
|
||||
monitor_cluster() {
|
||||
while true; do
|
||||
echo "=== Cluster Status Check at $(date) ==="
|
||||
# Only show non-running pods to reduce noise
|
||||
NON_RUNNING_PODS=$(kubectl get pods --all-namespaces --field-selector=status.phase!=Running,status.phase!=Succeeded --no-headers 2>/dev/null | wc -l)
|
||||
if [ "$NON_RUNNING_PODS" -gt 0 ]; then
|
||||
echo "Non-running pods:"
|
||||
kubectl get pods --all-namespaces --field-selector=status.phase!=Running,status.phase!=Succeeded
|
||||
else
|
||||
echo "All pods running successfully"
|
||||
fi
|
||||
# Only show recent events if there are issues
|
||||
RECENT_EVENTS=$(kubectl get events --sort-by=.lastTimestamp --all-namespaces --field-selector=type!=Normal 2>/dev/null | tail -5)
|
||||
if [ -n "$RECENT_EVENTS" ]; then
|
||||
echo "Recent warnings/errors:"
|
||||
echo "$RECENT_EVENTS"
|
||||
fi
|
||||
sleep 60
|
||||
done
|
||||
}
|
||||
|
||||
# Start monitoring in background
|
||||
monitor_cluster &
|
||||
MONITOR_PID=$!
|
||||
|
||||
# Set up cleanup
|
||||
cleanup() {
|
||||
echo "=== Cleaning up monitoring process ==="
|
||||
kill $MONITOR_PID 2>/dev/null || true
|
||||
echo "=== Final cluster state ==="
|
||||
kubectl get pods --all-namespaces
|
||||
kubectl get events --all-namespaces --sort-by=.lastTimestamp | tail -20
|
||||
}
|
||||
|
||||
# Trap cleanup on exit
|
||||
trap cleanup EXIT
|
||||
|
||||
# Run the actual installation with detailed logging
|
||||
echo "=== Starting ct install ==="
|
||||
ct install --all \
|
||||
--helm-extra-set-args="\
|
||||
--set=nginx.enabled=false \
|
||||
--set=minio.enabled=false \
|
||||
--set=vespa.enabled=false \
|
||||
--set=slackbot.enabled=false \
|
||||
--set=postgresql.enabled=true \
|
||||
--set=postgresql.primary.persistence.enabled=false \
|
||||
--set=redis.enabled=true \
|
||||
--set=webserver.replicaCount=1 \
|
||||
--set=api.replicaCount=0 \
|
||||
--set=inferenceCapability.replicaCount=0 \
|
||||
--set=indexCapability.replicaCount=0 \
|
||||
--set=celery_beat.replicaCount=0 \
|
||||
--set=celery_worker_heavy.replicaCount=0 \
|
||||
--set=celery_worker_docfetching.replicaCount=0 \
|
||||
--set=celery_worker_docprocessing.replicaCount=0 \
|
||||
--set=celery_worker_light.replicaCount=0 \
|
||||
--set=celery_worker_monitoring.replicaCount=0 \
|
||||
--set=celery_worker_primary.replicaCount=0 \
|
||||
--set=celery_worker_user_file_processing.replicaCount=0 \
|
||||
--set=celery_worker_user_files_indexing.replicaCount=0" \
|
||||
--helm-extra-args="--timeout 900s --debug" \
|
||||
--debug --config ct.yaml
|
||||
|
||||
echo "=== Installation completed successfully ==="
|
||||
kubectl get pods --all-namespaces
|
||||
|
||||
- name: Post-install verification
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Post-install verification ==="
|
||||
kubectl get pods --all-namespaces
|
||||
kubectl get services --all-namespaces
|
||||
# Only show issues if they exist
|
||||
kubectl describe pods --all-namespaces | grep -A 5 -B 2 "Failed\|Error\|Warning" || echo "No pod issues found"
|
||||
|
||||
- name: Cleanup on failure
|
||||
if: failure() && steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Cleanup on failure ==="
|
||||
echo "=== Final cluster state ==="
|
||||
kubectl get pods --all-namespaces
|
||||
kubectl get events --all-namespaces --sort-by=.lastTimestamp | tail -10
|
||||
|
||||
echo "=== Pod descriptions for debugging ==="
|
||||
kubectl describe pods --all-namespaces | grep -A 10 -B 3 "Failed\|Error\|Warning\|Pending" || echo "No problematic pods found"
|
||||
|
||||
echo "=== Recent logs for debugging ==="
|
||||
kubectl logs --all-namespaces --tail=50 | grep -i "error\|timeout\|failed\|pull" || echo "No error logs found"
|
||||
|
||||
echo "=== Helm releases ==="
|
||||
helm list --all-namespaces
|
||||
# the following would install only changed charts, but we only have one chart so
|
||||
# don't worry about that for now
|
||||
# run: ct install --target-branch ${{ github.event.repository.default_branch }}
|
||||
|
||||
560
.github/workflows/pr-integration-tests.yml
vendored
560
.github/workflows/pr-integration-tests.yml
vendored
@@ -11,6 +11,12 @@ on:
|
||||
- "release/**"
|
||||
|
||||
env:
|
||||
# Private Registry Configuration
|
||||
PRIVATE_REGISTRY: experimental-registry.blacksmith.sh:5000
|
||||
PRIVATE_REGISTRY_USERNAME: ${{ secrets.PRIVATE_REGISTRY_USERNAME }}
|
||||
PRIVATE_REGISTRY_PASSWORD: ${{ secrets.PRIVATE_REGISTRY_PASSWORD }}
|
||||
|
||||
# Test Environment Variables
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
|
||||
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
|
||||
@@ -23,18 +29,38 @@ env:
|
||||
PERM_SYNC_SHAREPOINT_PRIVATE_KEY: ${{ secrets.PERM_SYNC_SHAREPOINT_PRIVATE_KEY }}
|
||||
PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD: ${{ secrets.PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD }}
|
||||
PERM_SYNC_SHAREPOINT_DIRECTORY_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_DIRECTORY_ID }}
|
||||
PLATFORM_PAIR: linux-amd64
|
||||
|
||||
jobs:
|
||||
integration-tests:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=32cpu-linux-x64,
|
||||
disk=large,
|
||||
"run-id=${{ github.run_id }}",
|
||||
]
|
||||
discover-test-dirs:
|
||||
runs-on: blacksmith-2vcpu-ubuntu-2404-arm
|
||||
outputs:
|
||||
test-dirs: ${{ steps.set-matrix.outputs.test-dirs }}
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Discover test directories
|
||||
id: set-matrix
|
||||
run: |
|
||||
# Find all leaf-level directories in both test directories
|
||||
tests_dirs=$(find backend/tests/integration/tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" -exec basename {} \; | sort)
|
||||
connector_dirs=$(find backend/tests/integration/connector_job_tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" -exec basename {} \; | sort)
|
||||
|
||||
# Create JSON array with directory info
|
||||
all_dirs=""
|
||||
for dir in $tests_dirs; do
|
||||
all_dirs="$all_dirs{\"path\":\"tests/$dir\",\"name\":\"tests-$dir\"},"
|
||||
done
|
||||
for dir in $connector_dirs; do
|
||||
all_dirs="$all_dirs{\"path\":\"connector_job_tests/$dir\",\"name\":\"connector-$dir\"},"
|
||||
done
|
||||
|
||||
# Remove trailing comma and wrap in array
|
||||
all_dirs="[${all_dirs%,}]"
|
||||
echo "test-dirs=$all_dirs" >> $GITHUB_OUTPUT
|
||||
|
||||
prepare-build:
|
||||
runs-on: blacksmith-2vcpu-ubuntu-2404-arm
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
@@ -47,12 +73,12 @@ jobs:
|
||||
cache-dependency-path: |
|
||||
backend/requirements/default.txt
|
||||
backend/requirements/dev.txt
|
||||
backend/requirements/ee.txt
|
||||
- run: |
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/ee.txt
|
||||
|
||||
- name: Generate OpenAPI schema
|
||||
working-directory: ./backend
|
||||
@@ -70,132 +96,157 @@ jobs:
|
||||
-i /local/openapi.json \
|
||||
-g python \
|
||||
-o /local/onyx_openapi_client \
|
||||
--package-name onyx_openapi_client
|
||||
--package-name onyx_openapi_client \
|
||||
--skip-validate-spec \
|
||||
--openapi-normalizer "SIMPLIFY_ONEOF_ANYOF=true,SET_OAS3_NULLABLE=true"
|
||||
|
||||
- name: Upload OpenAPI artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: openapi-artifacts
|
||||
path: backend/generated/
|
||||
|
||||
build-backend-image:
|
||||
runs-on: blacksmith-16vcpu-ubuntu-2404-arm
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Login to Private Registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ${{ env.PRIVATE_REGISTRY }}
|
||||
username: ${{ env.PRIVATE_REGISTRY_USERNAME }}
|
||||
password: ${{ env.PRIVATE_REGISTRY_PASSWORD }}
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
uses: useblacksmith/setup-docker-builder@v1
|
||||
|
||||
- name: Build and push Backend Docker image
|
||||
uses: useblacksmith/build-push-action@v2
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
platforms: linux/arm64
|
||||
tags: ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-backend:test-${{ github.run_id }}
|
||||
push: true
|
||||
outputs: type=registry
|
||||
|
||||
build-model-server-image:
|
||||
runs-on: blacksmith-16vcpu-ubuntu-2404-arm
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Login to Private Registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ${{ env.PRIVATE_REGISTRY }}
|
||||
username: ${{ env.PRIVATE_REGISTRY_USERNAME }}
|
||||
password: ${{ env.PRIVATE_REGISTRY_PASSWORD }}
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: useblacksmith/setup-docker-builder@v1
|
||||
|
||||
- name: Build and push Model Server Docker image
|
||||
uses: useblacksmith/build-push-action@v2
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile.model_server
|
||||
platforms: linux/arm64
|
||||
tags: ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-model-server:test-${{ github.run_id }}
|
||||
push: true
|
||||
outputs: type=registry
|
||||
provenance: false
|
||||
|
||||
build-integration-image:
|
||||
needs: prepare-build
|
||||
runs-on: blacksmith-16vcpu-ubuntu-2404-arm
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Login to Private Registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ${{ env.PRIVATE_REGISTRY }}
|
||||
username: ${{ env.PRIVATE_REGISTRY_USERNAME }}
|
||||
password: ${{ env.PRIVATE_REGISTRY_PASSWORD }}
|
||||
|
||||
- name: Download OpenAPI artifacts
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: openapi-artifacts
|
||||
path: backend/generated/
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: useblacksmith/setup-docker-builder@v1
|
||||
|
||||
- name: Build and push integration test Docker image
|
||||
uses: useblacksmith/build-push-action@v2
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/tests/integration/Dockerfile
|
||||
platforms: linux/arm64
|
||||
tags: ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-integration:test-${{ github.run_id }}
|
||||
push: true
|
||||
outputs: type=registry
|
||||
|
||||
integration-tests:
|
||||
needs:
|
||||
[
|
||||
discover-test-dirs,
|
||||
build-backend-image,
|
||||
build-model-server-image,
|
||||
build-integration-image,
|
||||
]
|
||||
runs-on: blacksmith-8vcpu-ubuntu-2404-arm
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
test-dir: ${{ fromJson(needs.discover-test-dirs.outputs.test-dirs) }}
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Login to Private Registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ${{ env.PRIVATE_REGISTRY }}
|
||||
username: ${{ env.PRIVATE_REGISTRY_USERNAME }}
|
||||
password: ${{ env.PRIVATE_REGISTRY_PASSWORD }}
|
||||
|
||||
# needed for pulling Vespa, Redis, Postgres, and Minio images
|
||||
# otherwise, we hit the "Unauthenticated users" limit
|
||||
# https://docs.docker.com/docker-hub/usage/
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
# tag every docker image with "test" so that we can spin up the correct set
|
||||
# of images during testing
|
||||
|
||||
# We don't need to build the Web Docker image since it's not yet used
|
||||
# in the integration tests. We have a separate action to verify that it builds
|
||||
# successfully.
|
||||
- name: Pull Web Docker image
|
||||
- name: Pull Docker images
|
||||
run: |
|
||||
docker pull onyxdotapp/onyx-web-server:latest
|
||||
docker tag onyxdotapp/onyx-web-server:latest onyxdotapp/onyx-web-server:test
|
||||
# Pull all images from registry in parallel
|
||||
echo "Pulling Docker images in parallel..."
|
||||
# Pull images from private registry
|
||||
(docker pull --platform linux/arm64 ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-backend:test-${{ github.run_id }}) &
|
||||
(docker pull --platform linux/arm64 ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-model-server:test-${{ github.run_id }}) &
|
||||
(docker pull --platform linux/arm64 ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-integration:test-${{ github.run_id }}) &
|
||||
|
||||
# we use the runs-on cache for docker builds
|
||||
# in conjunction with runs-on runners, it has better speed and unlimited caching
|
||||
# https://runs-on.com/caching/s3-cache-for-github-actions/
|
||||
# https://runs-on.com/caching/docker/
|
||||
# https://github.com/moby/buildkit#s3-cache-experimental
|
||||
# Wait for all background jobs to complete
|
||||
wait
|
||||
echo "All Docker images pulled successfully"
|
||||
|
||||
# images are built and run locally for testing purposes. Not pushed.
|
||||
- name: Build Backend Docker image
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
platforms: linux/amd64
|
||||
tags: onyxdotapp/onyx-backend:test
|
||||
push: false
|
||||
load: true
|
||||
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
|
||||
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
|
||||
|
||||
- name: Build Model Server Docker image
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile.model_server
|
||||
platforms: linux/amd64
|
||||
tags: onyxdotapp/onyx-model-server:test
|
||||
push: false
|
||||
load: true
|
||||
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
|
||||
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
|
||||
|
||||
- name: Build integration test Docker image
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/tests/integration/Dockerfile
|
||||
platforms: linux/amd64
|
||||
tags: onyxdotapp/onyx-integration:test
|
||||
push: false
|
||||
load: true
|
||||
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
|
||||
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
|
||||
|
||||
# Start containers for multi-tenant tests
|
||||
- name: Start Docker containers for multi-tenant tests
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
|
||||
MULTI_TENANT=true \
|
||||
AUTH_TYPE=cloud \
|
||||
REQUIRE_EMAIL_VERIFICATION=false \
|
||||
DISABLE_TELEMETRY=true \
|
||||
IMAGE_TAG=test \
|
||||
DEV_MODE=true \
|
||||
docker compose -f docker-compose.multitenant-dev.yml -p onyx-stack up -d
|
||||
id: start_docker_multi_tenant
|
||||
|
||||
# In practice, `cloud` Auth type would require OAUTH credentials to be set.
|
||||
- name: Run Multi-Tenant Integration Tests
|
||||
run: |
|
||||
echo "Waiting for 3 minutes to ensure API server is ready..."
|
||||
sleep 180
|
||||
echo "Running integration tests..."
|
||||
docker run --rm --network onyx-stack_default \
|
||||
--name test-runner \
|
||||
-e POSTGRES_HOST=relational_db \
|
||||
-e POSTGRES_USER=postgres \
|
||||
-e POSTGRES_PASSWORD=password \
|
||||
-e DB_READONLY_USER=db_readonly_user \
|
||||
-e DB_READONLY_PASSWORD=password \
|
||||
-e POSTGRES_DB=postgres \
|
||||
-e POSTGRES_USE_NULL_POOL=true \
|
||||
-e VESPA_HOST=index \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
-e AUTH_TYPE=cloud \
|
||||
-e MULTI_TENANT=true \
|
||||
-e REQUIRE_EMAIL_VERIFICATION=false \
|
||||
-e DISABLE_TELEMETRY=true \
|
||||
-e IMAGE_TAG=test \
|
||||
-e DEV_MODE=true \
|
||||
onyxdotapp/onyx-integration:test \
|
||||
/app/tests/integration/multitenant_tests
|
||||
continue-on-error: true
|
||||
id: run_multitenant_tests
|
||||
|
||||
- name: Check multi-tenant test results
|
||||
run: |
|
||||
if [ ${{ steps.run_multitenant_tests.outcome }} == 'failure' ]; then
|
||||
echo "Multi-tenant integration tests failed. Exiting with error."
|
||||
exit 1
|
||||
else
|
||||
echo "All multi-tenant integration tests passed successfully."
|
||||
fi
|
||||
|
||||
- name: Stop multi-tenant Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.multitenant-dev.yml -p onyx-stack down -v
|
||||
# Re-tag to remove registry prefix for docker-compose
|
||||
docker tag ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-backend:test-${{ github.run_id }} onyxdotapp/onyx-backend:test
|
||||
docker tag ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-model-server:test-${{ github.run_id }} onyxdotapp/onyx-model-server:test
|
||||
docker tag ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-integration:test-${{ github.run_id }} onyxdotapp/onyx-integration:test
|
||||
|
||||
# NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
|
||||
# NOTE: don't need web server for integration tests
|
||||
- name: Start Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
@@ -208,14 +259,23 @@ jobs:
|
||||
IMAGE_TAG=test \
|
||||
INTEGRATION_TESTS_MODE=true \
|
||||
CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS=0.001 \
|
||||
docker compose -f docker-compose.dev.yml -p onyx-stack up -d
|
||||
docker compose -f docker-compose.yml -f docker-compose.dev.yml up \
|
||||
relational_db \
|
||||
index \
|
||||
cache \
|
||||
minio \
|
||||
api_server \
|
||||
inference_model_server \
|
||||
indexing_model_server \
|
||||
background \
|
||||
-d
|
||||
id: start_docker
|
||||
|
||||
- name: Wait for service to be ready
|
||||
run: |
|
||||
echo "Starting wait-for-service script..."
|
||||
|
||||
docker logs -f onyx-stack-api_server-1 &
|
||||
docker logs -f onyx-api_server-1 &
|
||||
|
||||
start_time=$(date +%s)
|
||||
timeout=300 # 5 minutes in seconds
|
||||
@@ -251,52 +311,44 @@ jobs:
|
||||
docker compose -f docker-compose.mock-it-services.yml \
|
||||
-p mock-it-services-stack up -d
|
||||
|
||||
# NOTE: Use pre-ping/null to reduce flakiness due to dropped connections
|
||||
- name: Run Standard Integration Tests
|
||||
run: |
|
||||
echo "Running integration tests..."
|
||||
docker run --rm --network onyx-stack_default \
|
||||
--name test-runner \
|
||||
-e POSTGRES_HOST=relational_db \
|
||||
-e POSTGRES_USER=postgres \
|
||||
-e POSTGRES_PASSWORD=password \
|
||||
-e DB_READONLY_USER=db_readonly_user \
|
||||
-e DB_READONLY_PASSWORD=password \
|
||||
-e POSTGRES_DB=postgres \
|
||||
-e POSTGRES_POOL_PRE_PING=true \
|
||||
-e POSTGRES_USE_NULL_POOL=true \
|
||||
-e VESPA_HOST=index \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
|
||||
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
|
||||
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
|
||||
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
|
||||
-e JIRA_BASE_URL=${JIRA_BASE_URL} \
|
||||
-e JIRA_USER_EMAIL=${JIRA_USER_EMAIL} \
|
||||
-e JIRA_API_TOKEN=${JIRA_API_TOKEN} \
|
||||
-e PERM_SYNC_SHAREPOINT_CLIENT_ID=${PERM_SYNC_SHAREPOINT_CLIENT_ID} \
|
||||
-e PERM_SYNC_SHAREPOINT_PRIVATE_KEY="${PERM_SYNC_SHAREPOINT_PRIVATE_KEY}" \
|
||||
-e PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD=${PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD} \
|
||||
-e PERM_SYNC_SHAREPOINT_DIRECTORY_ID=${PERM_SYNC_SHAREPOINT_DIRECTORY_ID} \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
|
||||
-e MOCK_CONNECTOR_SERVER_PORT=8001 \
|
||||
onyxdotapp/onyx-integration:test \
|
||||
/app/tests/integration/tests \
|
||||
/app/tests/integration/connector_job_tests
|
||||
continue-on-error: true
|
||||
id: run_tests
|
||||
|
||||
- name: Check test results
|
||||
run: |
|
||||
if [ ${{ steps.run_tests.outcome }} == 'failure' ]; then
|
||||
echo "Integration tests failed. Exiting with error."
|
||||
exit 1
|
||||
else
|
||||
echo "All integration tests passed successfully."
|
||||
fi
|
||||
- name: Run Integration Tests for ${{ matrix.test-dir.name }}
|
||||
uses: nick-fields/retry@v3
|
||||
with:
|
||||
timeout_minutes: 20
|
||||
max_attempts: 3
|
||||
retry_wait_seconds: 10
|
||||
command: |
|
||||
echo "Running integration tests for ${{ matrix.test-dir.path }}..."
|
||||
docker run --rm --network onyx_default \
|
||||
--name test-runner \
|
||||
-e POSTGRES_HOST=relational_db \
|
||||
-e POSTGRES_USER=postgres \
|
||||
-e POSTGRES_PASSWORD=password \
|
||||
-e POSTGRES_DB=postgres \
|
||||
-e DB_READONLY_USER=db_readonly_user \
|
||||
-e DB_READONLY_PASSWORD=password \
|
||||
-e POSTGRES_POOL_PRE_PING=true \
|
||||
-e POSTGRES_USE_NULL_POOL=true \
|
||||
-e VESPA_HOST=index \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
|
||||
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
|
||||
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
|
||||
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
|
||||
-e JIRA_BASE_URL=${JIRA_BASE_URL} \
|
||||
-e JIRA_USER_EMAIL=${JIRA_USER_EMAIL} \
|
||||
-e JIRA_API_TOKEN=${JIRA_API_TOKEN} \
|
||||
-e PERM_SYNC_SHAREPOINT_CLIENT_ID=${PERM_SYNC_SHAREPOINT_CLIENT_ID} \
|
||||
-e PERM_SYNC_SHAREPOINT_PRIVATE_KEY="${PERM_SYNC_SHAREPOINT_PRIVATE_KEY}" \
|
||||
-e PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD=${PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD} \
|
||||
-e PERM_SYNC_SHAREPOINT_DIRECTORY_ID=${PERM_SYNC_SHAREPOINT_DIRECTORY_ID} \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
|
||||
-e MOCK_CONNECTOR_SERVER_PORT=8001 \
|
||||
onyxdotapp/onyx-integration:test \
|
||||
/app/tests/integration/${{ matrix.test-dir.path }}
|
||||
|
||||
# ------------------------------------------------------------
|
||||
# Always gather logs BEFORE "down":
|
||||
@@ -304,19 +356,19 @@ jobs:
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p onyx-stack logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true
|
||||
docker compose logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true
|
||||
|
||||
- name: Dump all-container logs (optional)
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p onyx-stack logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
|
||||
docker compose logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
|
||||
|
||||
- name: Upload logs
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: docker-all-logs
|
||||
name: docker-all-logs-${{ matrix.test-dir.name }}
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
# ------------------------------------------------------------
|
||||
|
||||
@@ -324,4 +376,158 @@ jobs:
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p onyx-stack down -v
|
||||
docker compose down -v
|
||||
|
||||
|
||||
multitenant-tests:
|
||||
needs:
|
||||
[
|
||||
build-backend-image,
|
||||
build-model-server-image,
|
||||
build-integration-image,
|
||||
]
|
||||
runs-on: blacksmith-8vcpu-ubuntu-2404-arm
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Login to Private Registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ${{ env.PRIVATE_REGISTRY }}
|
||||
username: ${{ env.PRIVATE_REGISTRY_USERNAME }}
|
||||
password: ${{ env.PRIVATE_REGISTRY_PASSWORD }}
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Pull Docker images
|
||||
run: |
|
||||
(docker pull --platform linux/arm64 ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-backend:test-${{ github.run_id }}) &
|
||||
(docker pull --platform linux/arm64 ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-model-server:test-${{ github.run_id }}) &
|
||||
(docker pull --platform linux/arm64 ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-integration:test-${{ github.run_id }}) &
|
||||
wait
|
||||
docker tag ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-backend:test-${{ github.run_id }} onyxdotapp/onyx-backend:test
|
||||
docker tag ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-model-server:test-${{ github.run_id }} onyxdotapp/onyx-model-server:test
|
||||
docker tag ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-integration:test-${{ github.run_id }} onyxdotapp/onyx-integration:test
|
||||
|
||||
- name: Start Docker containers for multi-tenant tests
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
|
||||
MULTI_TENANT=true \
|
||||
AUTH_TYPE=cloud \
|
||||
REQUIRE_EMAIL_VERIFICATION=false \
|
||||
DISABLE_TELEMETRY=true \
|
||||
IMAGE_TAG=test \
|
||||
DEV_MODE=true \
|
||||
docker compose -f docker-compose.multitenant-dev.yml up \
|
||||
relational_db \
|
||||
index \
|
||||
cache \
|
||||
minio \
|
||||
api_server \
|
||||
inference_model_server \
|
||||
indexing_model_server \
|
||||
background \
|
||||
-d
|
||||
id: start_docker_multi_tenant
|
||||
|
||||
- name: Wait for service to be ready (multi-tenant)
|
||||
run: |
|
||||
echo "Starting wait-for-service script for multi-tenant..."
|
||||
docker logs -f onyx-api_server-1 &
|
||||
start_time=$(date +%s)
|
||||
timeout=300
|
||||
while true; do
|
||||
current_time=$(date +%s)
|
||||
elapsed_time=$((current_time - start_time))
|
||||
if [ $elapsed_time -ge $timeout ]; then
|
||||
echo "Timeout reached. Service did not become ready in 5 minutes."
|
||||
exit 1
|
||||
fi
|
||||
response=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:8080/health || echo "curl_error")
|
||||
if [ "$response" = "200" ]; then
|
||||
echo "Service is ready!"
|
||||
break
|
||||
elif [ "$response" = "curl_error" ]; then
|
||||
echo "Curl encountered an error; retrying..."
|
||||
else
|
||||
echo "Service not ready yet (HTTP $response). Retrying in 5 seconds..."
|
||||
fi
|
||||
sleep 5
|
||||
done
|
||||
echo "Finished waiting for service."
|
||||
|
||||
- name: Run Multi-Tenant Integration Tests
|
||||
run: |
|
||||
echo "Running multi-tenant integration tests..."
|
||||
docker run --rm --network onyx_default \
|
||||
--name test-runner \
|
||||
-e POSTGRES_HOST=relational_db \
|
||||
-e POSTGRES_USER=postgres \
|
||||
-e POSTGRES_PASSWORD=password \
|
||||
-e DB_READONLY_USER=db_readonly_user \
|
||||
-e DB_READONLY_PASSWORD=password \
|
||||
-e POSTGRES_DB=postgres \
|
||||
-e POSTGRES_USE_NULL_POOL=true \
|
||||
-e VESPA_HOST=index \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
-e AUTH_TYPE=cloud \
|
||||
-e MULTI_TENANT=true \
|
||||
-e SKIP_RESET=true \
|
||||
-e REQUIRE_EMAIL_VERIFICATION=false \
|
||||
-e DISABLE_TELEMETRY=true \
|
||||
-e IMAGE_TAG=test \
|
||||
-e DEV_MODE=true \
|
||||
onyxdotapp/onyx-integration:test \
|
||||
/app/tests/integration/multitenant_tests
|
||||
|
||||
- name: Dump API server logs (multi-tenant)
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.multitenant-dev.yml logs --no-color api_server > $GITHUB_WORKSPACE/api_server_multitenant.log || true
|
||||
|
||||
- name: Dump all-container logs (multi-tenant)
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.multitenant-dev.yml logs --no-color > $GITHUB_WORKSPACE/docker-compose-multitenant.log || true
|
||||
|
||||
- name: Upload logs (multi-tenant)
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: docker-all-logs-multitenant
|
||||
path: ${{ github.workspace }}/docker-compose-multitenant.log
|
||||
|
||||
- name: Stop multi-tenant Docker containers
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.multitenant-dev.yml down -v
|
||||
|
||||
required:
|
||||
runs-on: blacksmith-2vcpu-ubuntu-2404-arm
|
||||
needs: [integration-tests, multitenant-tests]
|
||||
if: ${{ always() }}
|
||||
steps:
|
||||
- uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
const needs = ${{ toJSON(needs) }};
|
||||
const failed = Object.values(needs).some(n => n.result !== 'success');
|
||||
if (failed) {
|
||||
core.setFailed('One or more upstream jobs failed or were cancelled.');
|
||||
} else {
|
||||
core.notice('All required jobs succeeded.');
|
||||
}
|
||||
|
||||
373
.github/workflows/pr-mit-integration-tests.yml
vendored
373
.github/workflows/pr-mit-integration-tests.yml
vendored
@@ -5,12 +5,15 @@ concurrency:
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
- "release/**"
|
||||
types: [checks_requested]
|
||||
|
||||
env:
|
||||
# Private Registry Configuration
|
||||
PRIVATE_REGISTRY: experimental-registry.blacksmith.sh:5000
|
||||
PRIVATE_REGISTRY_USERNAME: ${{ secrets.PRIVATE_REGISTRY_USERNAME }}
|
||||
PRIVATE_REGISTRY_PASSWORD: ${{ secrets.PRIVATE_REGISTRY_PASSWORD }}
|
||||
|
||||
# Test Environment Variables
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
|
||||
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
|
||||
@@ -23,21 +26,42 @@ env:
|
||||
PERM_SYNC_SHAREPOINT_PRIVATE_KEY: ${{ secrets.PERM_SYNC_SHAREPOINT_PRIVATE_KEY }}
|
||||
PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD: ${{ secrets.PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD }}
|
||||
PERM_SYNC_SHAREPOINT_DIRECTORY_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_DIRECTORY_ID }}
|
||||
PLATFORM_PAIR: linux-amd64
|
||||
|
||||
jobs:
|
||||
integration-tests-mit:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=32cpu-linux-x64,
|
||||
disk=large,
|
||||
"run-id=${{ github.run_id }}",
|
||||
]
|
||||
discover-test-dirs:
|
||||
runs-on: blacksmith-2vcpu-ubuntu-2404-arm
|
||||
outputs:
|
||||
test-dirs: ${{ steps.set-matrix.outputs.test-dirs }}
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
|
||||
- name: Discover test directories
|
||||
id: set-matrix
|
||||
run: |
|
||||
# Find all leaf-level directories in both test directories
|
||||
tests_dirs=$(find backend/tests/integration/tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" -exec basename {} \; | sort)
|
||||
connector_dirs=$(find backend/tests/integration/connector_job_tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" -exec basename {} \; | sort)
|
||||
|
||||
# Create JSON array with directory info
|
||||
all_dirs=""
|
||||
for dir in $tests_dirs; do
|
||||
all_dirs="$all_dirs{\"path\":\"tests/$dir\",\"name\":\"tests-$dir\"},"
|
||||
done
|
||||
for dir in $connector_dirs; do
|
||||
all_dirs="$all_dirs{\"path\":\"connector_job_tests/$dir\",\"name\":\"connector-$dir\"},"
|
||||
done
|
||||
|
||||
# Remove trailing comma and wrap in array
|
||||
all_dirs="[${all_dirs%,}]"
|
||||
echo "test-dirs=$all_dirs" >> $GITHUB_OUTPUT
|
||||
|
||||
prepare-build:
|
||||
runs-on: blacksmith-2vcpu-ubuntu-2404-arm
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
@@ -46,7 +70,9 @@ jobs:
|
||||
cache-dependency-path: |
|
||||
backend/requirements/default.txt
|
||||
backend/requirements/dev.txt
|
||||
- run: |
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
|
||||
@@ -67,72 +93,158 @@ jobs:
|
||||
-i /local/openapi.json \
|
||||
-g python \
|
||||
-o /local/onyx_openapi_client \
|
||||
--package-name onyx_openapi_client
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
--package-name onyx_openapi_client \
|
||||
--skip-validate-spec \
|
||||
--openapi-normalizer "SIMPLIFY_ONEOF_ANYOF=true,SET_OAS3_NULLABLE=true"
|
||||
|
||||
- name: Upload OpenAPI artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: openapi-artifacts
|
||||
path: backend/generated/
|
||||
|
||||
build-backend-image:
|
||||
runs-on: blacksmith-16vcpu-ubuntu-2404-arm
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Login to Private Registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ${{ env.PRIVATE_REGISTRY }}
|
||||
username: ${{ env.PRIVATE_REGISTRY_USERNAME }}
|
||||
password: ${{ env.PRIVATE_REGISTRY_PASSWORD }}
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: useblacksmith/setup-docker-builder@v1
|
||||
|
||||
- name: Build and push Backend Docker image
|
||||
uses: useblacksmith/build-push-action@v2
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
platforms: linux/arm64
|
||||
tags: ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-backend:test-${{ github.run_id }}
|
||||
push: true
|
||||
outputs: type=registry
|
||||
|
||||
build-model-server-image:
|
||||
runs-on: blacksmith-16vcpu-ubuntu-2404-arm
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Login to Private Registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ${{ env.PRIVATE_REGISTRY }}
|
||||
username: ${{ env.PRIVATE_REGISTRY_USERNAME }}
|
||||
password: ${{ env.PRIVATE_REGISTRY_PASSWORD }}
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: useblacksmith/setup-docker-builder@v1
|
||||
|
||||
- name: Build and push Model Server Docker image
|
||||
uses: useblacksmith/build-push-action@v2
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile.model_server
|
||||
platforms: linux/arm64
|
||||
tags: ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-model-server:test-${{ github.run_id }}
|
||||
push: true
|
||||
outputs: type=registry
|
||||
provenance: false
|
||||
|
||||
build-integration-image:
|
||||
needs: prepare-build
|
||||
runs-on: blacksmith-16vcpu-ubuntu-2404-arm
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Login to Private Registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ${{ env.PRIVATE_REGISTRY }}
|
||||
username: ${{ env.PRIVATE_REGISTRY_USERNAME }}
|
||||
password: ${{ env.PRIVATE_REGISTRY_PASSWORD }}
|
||||
|
||||
- name: Download OpenAPI artifacts
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: openapi-artifacts
|
||||
path: backend/generated/
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: useblacksmith/setup-docker-builder@v1
|
||||
|
||||
- name: Build and push integration test Docker image
|
||||
uses: useblacksmith/build-push-action@v2
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/tests/integration/Dockerfile
|
||||
platforms: linux/arm64
|
||||
tags: ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-integration:test-${{ github.run_id }}
|
||||
push: true
|
||||
outputs: type=registry
|
||||
|
||||
integration-tests-mit:
|
||||
needs:
|
||||
[
|
||||
discover-test-dirs,
|
||||
build-backend-image,
|
||||
build-model-server-image,
|
||||
build-integration-image,
|
||||
]
|
||||
# See https://docs.blacksmith.sh/blacksmith-runners/overview
|
||||
runs-on: blacksmith-8vcpu-ubuntu-2404-arm
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
test-dir: ${{ fromJson(needs.discover-test-dirs.outputs.test-dirs) }}
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Login to Private Registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ${{ env.PRIVATE_REGISTRY }}
|
||||
username: ${{ env.PRIVATE_REGISTRY_USERNAME }}
|
||||
password: ${{ env.PRIVATE_REGISTRY_PASSWORD }}
|
||||
|
||||
# needed for pulling Vespa, Redis, Postgres, and Minio images
|
||||
# otherwise, we hit the "Unauthenticated users" limit
|
||||
# https://docs.docker.com/docker-hub/usage/
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
# tag every docker image with "test" so that we can spin up the correct set
|
||||
# of images during testing
|
||||
|
||||
# We don't need to build the Web Docker image since it's not yet used
|
||||
# in the integration tests. We have a separate action to verify that it builds
|
||||
# successfully.
|
||||
- name: Pull Web Docker image
|
||||
- name: Pull Docker images
|
||||
run: |
|
||||
docker pull onyxdotapp/onyx-web-server:latest
|
||||
docker tag onyxdotapp/onyx-web-server:latest onyxdotapp/onyx-web-server:test
|
||||
# Pull all images from registry in parallel
|
||||
echo "Pulling Docker images in parallel..."
|
||||
# Pull images from private registry
|
||||
(docker pull --platform linux/arm64 ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-backend:test-${{ github.run_id }}) &
|
||||
(docker pull --platform linux/arm64 ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-model-server:test-${{ github.run_id }}) &
|
||||
(docker pull --platform linux/arm64 ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-integration:test-${{ github.run_id }}) &
|
||||
|
||||
# we use the runs-on cache for docker builds
|
||||
# in conjunction with runs-on runners, it has better speed and unlimited caching
|
||||
# https://runs-on.com/caching/s3-cache-for-github-actions/
|
||||
# https://runs-on.com/caching/docker/
|
||||
# https://github.com/moby/buildkit#s3-cache-experimental
|
||||
# Wait for all background jobs to complete
|
||||
wait
|
||||
echo "All Docker images pulled successfully"
|
||||
|
||||
# images are built and run locally for testing purposes. Not pushed.
|
||||
- name: Build Backend Docker image
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
platforms: linux/amd64
|
||||
tags: onyxdotapp/onyx-backend:test
|
||||
push: false
|
||||
load: true
|
||||
cache-from: type=s3,prefix=cache/${{ github.repository }}/mit-integration-tests/backend-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
|
||||
cache-to: type=s3,prefix=cache/${{ github.repository }}/mit-integration-tests/backend-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
|
||||
|
||||
- name: Build Model Server Docker image
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile.model_server
|
||||
platforms: linux/amd64
|
||||
tags: onyxdotapp/onyx-model-server:test
|
||||
push: false
|
||||
load: true
|
||||
cache-from: type=s3,prefix=cache/${{ github.repository }}/mit-integration-tests/model-server-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
|
||||
cache-to: type=s3,prefix=cache/${{ github.repository }}/mit-integration-tests/model-server-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
|
||||
|
||||
- name: Build integration test Docker image
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/tests/integration/Dockerfile
|
||||
platforms: linux/amd64
|
||||
tags: onyxdotapp/onyx-integration:test
|
||||
push: false
|
||||
load: true
|
||||
cache-from: type=s3,prefix=cache/${{ github.repository }}/mit-integration-tests/integration-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
|
||||
cache-to: type=s3,prefix=cache/${{ github.repository }}/mit-integration-tests/integration-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
|
||||
# Re-tag to remove registry prefix for docker-compose
|
||||
docker tag ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-backend:test-${{ github.run_id }} onyxdotapp/onyx-backend:test
|
||||
docker tag ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-model-server:test-${{ github.run_id }} onyxdotapp/onyx-model-server:test
|
||||
docker tag ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-integration:test-${{ github.run_id }} onyxdotapp/onyx-integration:test
|
||||
|
||||
# NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
|
||||
# NOTE: don't need web server for integration tests
|
||||
- name: Start Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
@@ -143,14 +255,23 @@ jobs:
|
||||
DISABLE_TELEMETRY=true \
|
||||
IMAGE_TAG=test \
|
||||
INTEGRATION_TESTS_MODE=true \
|
||||
docker compose -f docker-compose.dev.yml -p onyx-stack up -d
|
||||
docker compose -f docker-compose.yml -f docker-compose.dev.yml up \
|
||||
relational_db \
|
||||
index \
|
||||
cache \
|
||||
minio \
|
||||
api_server \
|
||||
inference_model_server \
|
||||
indexing_model_server \
|
||||
background \
|
||||
-d
|
||||
id: start_docker
|
||||
|
||||
- name: Wait for service to be ready
|
||||
run: |
|
||||
echo "Starting wait-for-service script..."
|
||||
|
||||
docker logs -f onyx-stack-api_server-1 &
|
||||
docker logs -f onyx-api_server-1 &
|
||||
|
||||
start_time=$(date +%s)
|
||||
timeout=300 # 5 minutes in seconds
|
||||
@@ -187,51 +308,44 @@ jobs:
|
||||
-p mock-it-services-stack up -d
|
||||
|
||||
# NOTE: Use pre-ping/null to reduce flakiness due to dropped connections
|
||||
- name: Run Standard Integration Tests
|
||||
run: |
|
||||
echo "Running integration tests..."
|
||||
docker run --rm --network onyx-stack_default \
|
||||
--name test-runner \
|
||||
-e POSTGRES_HOST=relational_db \
|
||||
-e POSTGRES_USER=postgres \
|
||||
-e POSTGRES_PASSWORD=password \
|
||||
-e POSTGRES_DB=postgres \
|
||||
-e DB_READONLY_USER=db_readonly_user \
|
||||
-e DB_READONLY_PASSWORD=password \
|
||||
-e POSTGRES_POOL_PRE_PING=true \
|
||||
-e POSTGRES_USE_NULL_POOL=true \
|
||||
-e VESPA_HOST=index \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
|
||||
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
|
||||
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
|
||||
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
|
||||
-e JIRA_BASE_URL=${JIRA_BASE_URL} \
|
||||
-e JIRA_USER_EMAIL=${JIRA_USER_EMAIL} \
|
||||
-e JIRA_API_TOKEN=${JIRA_API_TOKEN} \
|
||||
-e PERM_SYNC_SHAREPOINT_CLIENT_ID=${PERM_SYNC_SHAREPOINT_CLIENT_ID} \
|
||||
-e PERM_SYNC_SHAREPOINT_PRIVATE_KEY="${PERM_SYNC_SHAREPOINT_PRIVATE_KEY}" \
|
||||
-e PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD=${PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD} \
|
||||
-e PERM_SYNC_SHAREPOINT_DIRECTORY_ID=${PERM_SYNC_SHAREPOINT_DIRECTORY_ID} \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
|
||||
-e MOCK_CONNECTOR_SERVER_PORT=8001 \
|
||||
onyxdotapp/onyx-integration:test \
|
||||
/app/tests/integration/tests \
|
||||
/app/tests/integration/connector_job_tests
|
||||
continue-on-error: true
|
||||
id: run_tests
|
||||
|
||||
- name: Check test results
|
||||
run: |
|
||||
if [ ${{ steps.run_tests.outcome }} == 'failure' ]; then
|
||||
echo "Integration tests failed. Exiting with error."
|
||||
exit 1
|
||||
else
|
||||
echo "All integration tests passed successfully."
|
||||
fi
|
||||
- name: Run Integration Tests for ${{ matrix.test-dir.name }}
|
||||
uses: nick-fields/retry@v3
|
||||
with:
|
||||
timeout_minutes: 20
|
||||
max_attempts: 3
|
||||
retry_wait_seconds: 10
|
||||
command: |
|
||||
echo "Running integration tests for ${{ matrix.test-dir.path }}..."
|
||||
docker run --rm --network onyx_default \
|
||||
--name test-runner \
|
||||
-e POSTGRES_HOST=relational_db \
|
||||
-e POSTGRES_USER=postgres \
|
||||
-e POSTGRES_PASSWORD=password \
|
||||
-e POSTGRES_DB=postgres \
|
||||
-e DB_READONLY_USER=db_readonly_user \
|
||||
-e DB_READONLY_PASSWORD=password \
|
||||
-e POSTGRES_POOL_PRE_PING=true \
|
||||
-e POSTGRES_USE_NULL_POOL=true \
|
||||
-e VESPA_HOST=index \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
|
||||
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
|
||||
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
|
||||
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
|
||||
-e JIRA_BASE_URL=${JIRA_BASE_URL} \
|
||||
-e JIRA_USER_EMAIL=${JIRA_USER_EMAIL} \
|
||||
-e JIRA_API_TOKEN=${JIRA_API_TOKEN} \
|
||||
-e PERM_SYNC_SHAREPOINT_CLIENT_ID=${PERM_SYNC_SHAREPOINT_CLIENT_ID} \
|
||||
-e PERM_SYNC_SHAREPOINT_PRIVATE_KEY="${PERM_SYNC_SHAREPOINT_PRIVATE_KEY}" \
|
||||
-e PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD=${PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD} \
|
||||
-e PERM_SYNC_SHAREPOINT_DIRECTORY_ID=${PERM_SYNC_SHAREPOINT_DIRECTORY_ID} \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
|
||||
-e MOCK_CONNECTOR_SERVER_PORT=8001 \
|
||||
onyxdotapp/onyx-integration:test \
|
||||
/app/tests/integration/${{ matrix.test-dir.path }}
|
||||
|
||||
# ------------------------------------------------------------
|
||||
# Always gather logs BEFORE "down":
|
||||
@@ -239,19 +353,19 @@ jobs:
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p onyx-stack logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true
|
||||
docker compose logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true
|
||||
|
||||
- name: Dump all-container logs (optional)
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p onyx-stack logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
|
||||
docker compose logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
|
||||
|
||||
- name: Upload logs
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: docker-all-logs
|
||||
name: docker-all-logs-${{ matrix.test-dir.name }}
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
# ------------------------------------------------------------
|
||||
|
||||
@@ -259,4 +373,21 @@ jobs:
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p onyx-stack down -v
|
||||
docker compose down -v
|
||||
|
||||
|
||||
required:
|
||||
runs-on: blacksmith-2vcpu-ubuntu-2404-arm
|
||||
needs: [integration-tests-mit]
|
||||
if: ${{ always() }}
|
||||
steps:
|
||||
- uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
const needs = ${{ toJSON(needs) }};
|
||||
const failed = Object.values(needs).some(n => n.result !== 'success');
|
||||
if (failed) {
|
||||
core.setFailed('One or more upstream jobs failed or were cancelled.');
|
||||
} else {
|
||||
core.notice('All required jobs succeeded.');
|
||||
}
|
||||
|
||||
271
.github/workflows/pr-playwright-tests.yml
vendored
271
.github/workflows/pr-playwright-tests.yml
vendored
@@ -6,44 +6,171 @@ concurrency:
|
||||
on: push
|
||||
|
||||
env:
|
||||
# AWS ECR Configuration
|
||||
AWS_REGION: ${{ secrets.AWS_REGION || 'us-west-2' }}
|
||||
ECR_REGISTRY: ${{ secrets.ECR_REGISTRY }}
|
||||
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID_ECR }}
|
||||
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY_ECR }}
|
||||
BUILDX_NO_DEFAULT_ATTESTATIONS: 1
|
||||
|
||||
# Test Environment Variables
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
|
||||
GEN_AI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
EXA_API_KEY: ${{ secrets.EXA_API_KEY }}
|
||||
|
||||
# for federated slack tests
|
||||
SLACK_CLIENT_ID: ${{ secrets.SLACK_CLIENT_ID }}
|
||||
SLACK_CLIENT_SECRET: ${{ secrets.SLACK_CLIENT_SECRET }}
|
||||
|
||||
MOCK_LLM_RESPONSE: true
|
||||
PYTEST_PLAYWRIGHT_SKIP_INITIAL_RESET: true
|
||||
|
||||
jobs:
|
||||
playwright-tests:
|
||||
name: Playwright Tests
|
||||
build-web-image:
|
||||
runs-on: blacksmith-8vcpu-ubuntu-2404-arm
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=32cpu-linux-x64,
|
||||
disk=large,
|
||||
"run-id=${{ github.run_id }}",
|
||||
]
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@v4
|
||||
with:
|
||||
aws-access-key-id: ${{ env.AWS_ACCESS_KEY_ID }}
|
||||
aws-secret-access-key: ${{ env.AWS_SECRET_ACCESS_KEY }}
|
||||
aws-region: ${{ env.AWS_REGION }}
|
||||
|
||||
- name: Login to Amazon ECR
|
||||
id: login-ecr
|
||||
uses: aws-actions/amazon-ecr-login@v2
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: useblacksmith/setup-docker-builder@v1
|
||||
|
||||
- name: Build and push Web Docker image
|
||||
uses: useblacksmith/build-push-action@v2
|
||||
with:
|
||||
context: ./web
|
||||
file: ./web/Dockerfile
|
||||
platforms: linux/arm64
|
||||
tags: ${{ env.ECR_REGISTRY }}/integration-test-onyx-web-server:playwright-test-${{ github.run_id }}
|
||||
provenance: false
|
||||
sbom: false
|
||||
push: true
|
||||
outputs: type=registry
|
||||
# no-cache: true
|
||||
|
||||
build-backend-image:
|
||||
runs-on: blacksmith-8vcpu-ubuntu-2404-arm
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@v4
|
||||
with:
|
||||
aws-access-key-id: ${{ env.AWS_ACCESS_KEY_ID }}
|
||||
aws-secret-access-key: ${{ env.AWS_SECRET_ACCESS_KEY }}
|
||||
aws-region: ${{ env.AWS_REGION }}
|
||||
|
||||
- name: Login to Amazon ECR
|
||||
id: login-ecr
|
||||
uses: aws-actions/amazon-ecr-login@v2
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: useblacksmith/setup-docker-builder@v1
|
||||
|
||||
- name: Build and push Backend Docker image
|
||||
uses: useblacksmith/build-push-action@v2
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
platforms: linux/arm64
|
||||
tags: ${{ env.ECR_REGISTRY }}/integration-test-onyx-backend:playwright-test-${{ github.run_id }}
|
||||
provenance: false
|
||||
sbom: false
|
||||
push: true
|
||||
outputs: type=registry
|
||||
# no-cache: true
|
||||
|
||||
build-model-server-image:
|
||||
runs-on: blacksmith-8vcpu-ubuntu-2404-arm
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@v4
|
||||
with:
|
||||
aws-access-key-id: ${{ env.AWS_ACCESS_KEY_ID }}
|
||||
aws-secret-access-key: ${{ env.AWS_SECRET_ACCESS_KEY }}
|
||||
aws-region: ${{ env.AWS_REGION }}
|
||||
|
||||
- name: Login to Amazon ECR
|
||||
id: login-ecr
|
||||
uses: aws-actions/amazon-ecr-login@v2
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: useblacksmith/setup-docker-builder@v1
|
||||
|
||||
- name: Build and push Model Server Docker image
|
||||
uses: useblacksmith/build-push-action@v2
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile.model_server
|
||||
platforms: linux/arm64
|
||||
tags: ${{ env.ECR_REGISTRY }}/integration-test-onyx-model-server:playwright-test-${{ github.run_id }}
|
||||
provenance: false
|
||||
sbom: false
|
||||
push: true
|
||||
outputs: type=registry
|
||||
# no-cache: true
|
||||
|
||||
playwright-tests:
|
||||
needs: [build-web-image, build-backend-image, build-model-server-image]
|
||||
name: Playwright Tests
|
||||
runs-on: blacksmith-8vcpu-ubuntu-2404-arm
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@v4
|
||||
with:
|
||||
python-version: "3.11"
|
||||
cache: "pip"
|
||||
cache-dependency-path: |
|
||||
backend/requirements/default.txt
|
||||
backend/requirements/dev.txt
|
||||
backend/requirements/model_server.txt
|
||||
- run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
|
||||
aws-access-key-id: ${{ env.AWS_ACCESS_KEY_ID }}
|
||||
aws-secret-access-key: ${{ env.AWS_SECRET_ACCESS_KEY }}
|
||||
aws-region: ${{ env.AWS_REGION }}
|
||||
|
||||
- name: Login to Amazon ECR
|
||||
id: login-ecr
|
||||
uses: aws-actions/amazon-ecr-login@v2
|
||||
|
||||
# needed for pulling Vespa, Redis, Postgres, and Minio images
|
||||
# otherwise, we hit the "Unauthenticated users" limit
|
||||
# https://docs.docker.com/docker-hub/usage/
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Pull Docker images
|
||||
run: |
|
||||
# Pull all images from ECR in parallel
|
||||
echo "Pulling Docker images in parallel..."
|
||||
(docker pull ${{ env.ECR_REGISTRY }}/integration-test-onyx-web-server:playwright-test-${{ github.run_id }}) &
|
||||
(docker pull ${{ env.ECR_REGISTRY }}/integration-test-onyx-backend:playwright-test-${{ github.run_id }}) &
|
||||
(docker pull ${{ env.ECR_REGISTRY }}/integration-test-onyx-model-server:playwright-test-${{ github.run_id }}) &
|
||||
|
||||
# Wait for all background jobs to complete
|
||||
wait
|
||||
echo "All Docker images pulled successfully"
|
||||
|
||||
# Re-tag with expected names for docker-compose
|
||||
docker tag ${{ env.ECR_REGISTRY }}/integration-test-onyx-web-server:playwright-test-${{ github.run_id }} onyxdotapp/onyx-web-server:test
|
||||
docker tag ${{ env.ECR_REGISTRY }}/integration-test-onyx-backend:playwright-test-${{ github.run_id }} onyxdotapp/onyx-backend:test
|
||||
docker tag ${{ env.ECR_REGISTRY }}/integration-test-onyx-model-server:playwright-test-${{ github.run_id }} onyxdotapp/onyx-model-server:test
|
||||
|
||||
- name: Setup node
|
||||
uses: actions/setup-node@v4
|
||||
@@ -58,79 +185,29 @@ jobs:
|
||||
working-directory: ./web
|
||||
run: npx playwright install --with-deps
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
# tag every docker image with "test" so that we can spin up the correct set
|
||||
# of images during testing
|
||||
|
||||
# we use the runs-on cache for docker builds
|
||||
# in conjunction with runs-on runners, it has better speed and unlimited caching
|
||||
# https://runs-on.com/caching/s3-cache-for-github-actions/
|
||||
# https://runs-on.com/caching/docker/
|
||||
# https://github.com/moby/buildkit#s3-cache-experimental
|
||||
|
||||
# images are built and run locally for testing purposes. Not pushed.
|
||||
|
||||
- name: Build Web Docker image
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
context: ./web
|
||||
file: ./web/Dockerfile
|
||||
platforms: linux/amd64
|
||||
tags: onyxdotapp/onyx-web-server:test
|
||||
push: false
|
||||
load: true
|
||||
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/web-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
|
||||
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/web-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
|
||||
|
||||
- name: Build Backend Docker image
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
platforms: linux/amd64
|
||||
tags: onyxdotapp/onyx-backend:test
|
||||
push: false
|
||||
load: true
|
||||
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
|
||||
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
|
||||
|
||||
- name: Build Model Server Docker image
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile.model_server
|
||||
platforms: linux/amd64
|
||||
tags: onyxdotapp/onyx-model-server:test
|
||||
push: false
|
||||
load: true
|
||||
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
|
||||
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
|
||||
- name: Create .env file for Docker Compose
|
||||
run: |
|
||||
cat <<EOF > deployment/docker_compose/.env
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
|
||||
AUTH_TYPE=basic
|
||||
GEN_AI_API_KEY=${{ env.OPENAI_API_KEY }}
|
||||
EXA_API_KEY=${{ env.EXA_API_KEY }}
|
||||
REQUIRE_EMAIL_VERIFICATION=false
|
||||
DISABLE_TELEMETRY=true
|
||||
IMAGE_TAG=test
|
||||
EOF
|
||||
|
||||
- name: Start Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
|
||||
AUTH_TYPE=basic \
|
||||
GEN_AI_API_KEY=${{ secrets.OPENAI_API_KEY }} \
|
||||
REQUIRE_EMAIL_VERIFICATION=false \
|
||||
DISABLE_TELEMETRY=true \
|
||||
IMAGE_TAG=test \
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack up -d
|
||||
docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d
|
||||
id: start_docker
|
||||
|
||||
- name: Wait for service to be ready
|
||||
run: |
|
||||
echo "Starting wait-for-service script..."
|
||||
|
||||
docker logs -f danswer-stack-api_server-1 &
|
||||
docker logs -f onyx-api_server-1 &
|
||||
|
||||
start_time=$(date +%s)
|
||||
timeout=300 # 5 minutes in seconds
|
||||
@@ -160,22 +237,18 @@ jobs:
|
||||
done
|
||||
echo "Finished waiting for service."
|
||||
|
||||
- name: Run pytest playwright test init
|
||||
working-directory: ./backend
|
||||
env:
|
||||
PYTEST_IGNORE_SKIP: true
|
||||
run: pytest -s tests/integration/tests/playwright/test_playwright.py
|
||||
|
||||
- name: Run Playwright tests
|
||||
working-directory: ./web
|
||||
run: npx playwright test
|
||||
run: |
|
||||
# Create test-results directory to ensure it exists for artifact upload
|
||||
mkdir -p test-results
|
||||
npx playwright test
|
||||
|
||||
- uses: actions/upload-artifact@v4
|
||||
if: always()
|
||||
with:
|
||||
# Chromatic automatically defaults to the test-results directory.
|
||||
# Replace with the path to your custom directory and adjust the CHROMATIC_ARCHIVE_LOCATION environment variable accordingly.
|
||||
name: test-results
|
||||
# Includes test results and debug screenshots
|
||||
name: playwright-test-results-${{ github.run_id }}
|
||||
path: ./web/test-results
|
||||
retention-days: 30
|
||||
|
||||
@@ -184,7 +257,7 @@ jobs:
|
||||
if: success() || failure()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack logs > docker-compose.log
|
||||
docker compose logs > docker-compose.log
|
||||
mv docker-compose.log ${{ github.workspace }}/docker-compose.log
|
||||
|
||||
- name: Upload logs
|
||||
@@ -197,7 +270,7 @@ jobs:
|
||||
- name: Stop Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
|
||||
docker compose down -v
|
||||
|
||||
# NOTE: Chromatic UI diff testing is currently disabled.
|
||||
# We are using Playwright for local and CI testing without visual regression checks.
|
||||
|
||||
2
.github/workflows/pr-python-checks.yml
vendored
2
.github/workflows/pr-python-checks.yml
vendored
@@ -48,6 +48,8 @@ jobs:
|
||||
-g python \
|
||||
-o /local/onyx_openapi_client \
|
||||
--package-name onyx_openapi_client \
|
||||
--skip-validate-spec \
|
||||
--openapi-normalizer "SIMPLIFY_ONEOF_ANYOF=true,SET_OAS3_NULLABLE=true"
|
||||
|
||||
- name: Run MyPy
|
||||
run: |
|
||||
|
||||
@@ -96,6 +96,13 @@ env:
|
||||
TEAMS_DIRECTORY_ID: ${{ secrets.TEAMS_DIRECTORY_ID }}
|
||||
TEAMS_SECRET: ${{ secrets.TEAMS_SECRET }}
|
||||
|
||||
# Bitbucket
|
||||
BITBUCKET_WORKSPACE: ${{ secrets.BITBUCKET_WORKSPACE }}
|
||||
BITBUCKET_REPOSITORIES: ${{ secrets.BITBUCKET_REPOSITORIES }}
|
||||
BITBUCKET_PROJECTS: ${{ secrets.BITBUCKET_PROJECTS }}
|
||||
BITBUCKET_EMAIL: ${{ secrets.BITBUCKET_EMAIL }}
|
||||
BITBUCKET_API_TOKEN: ${{ secrets.BITBUCKET_API_TOKEN }}
|
||||
|
||||
jobs:
|
||||
connectors-check:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
|
||||
6
.github/workflows/pr-python-model-tests.yml
vendored
6
.github/workflows/pr-python-model-tests.yml
vendored
@@ -77,7 +77,7 @@ jobs:
|
||||
REQUIRE_EMAIL_VERIFICATION=false \
|
||||
DISABLE_TELEMETRY=true \
|
||||
IMAGE_TAG=test \
|
||||
docker compose -f docker-compose.model-server-test.yml -p onyx-stack up -d indexing_model_server
|
||||
docker compose -f docker-compose.model-server-test.yml up -d indexing_model_server
|
||||
id: start_docker
|
||||
|
||||
- name: Wait for service to be ready
|
||||
@@ -132,7 +132,7 @@ jobs:
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.model-server-test.yml -p onyx-stack logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
|
||||
docker compose -f docker-compose.model-server-test.yml logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
|
||||
|
||||
- name: Upload logs
|
||||
if: always()
|
||||
@@ -145,5 +145,5 @@ jobs:
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.model-server-test.yml -p onyx-stack down -v
|
||||
docker compose -f docker-compose.model-server-test.yml down -v
|
||||
|
||||
|
||||
2
.github/workflows/pr-python-tests.yml
vendored
2
.github/workflows/pr-python-tests.yml
vendored
@@ -31,12 +31,14 @@ jobs:
|
||||
cache-dependency-path: |
|
||||
backend/requirements/default.txt
|
||||
backend/requirements/dev.txt
|
||||
backend/requirements/model_server.txt
|
||||
|
||||
- name: Install Dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
|
||||
|
||||
- name: Run Tests
|
||||
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
|
||||
|
||||
14
.gitignore
vendored
14
.gitignore
vendored
@@ -17,12 +17,26 @@ backend/tests/regression/answer_quality/test_data.json
|
||||
backend/tests/regression/search_quality/eval-*
|
||||
backend/tests/regression/search_quality/search_eval_config.yaml
|
||||
backend/tests/regression/search_quality/*.json
|
||||
backend/onyx/evals/data/
|
||||
*.log
|
||||
|
||||
# secret files
|
||||
.env
|
||||
jira_test_env
|
||||
settings.json
|
||||
|
||||
# others
|
||||
/deployment/data/nginx/app.conf
|
||||
*.sw?
|
||||
/backend/tests/regression/answer_quality/search_test_config.yaml
|
||||
*.egg-info
|
||||
|
||||
# Local .terraform directories
|
||||
**/.terraform/*
|
||||
|
||||
# Local .tfstate files
|
||||
*.tfstate
|
||||
*.tfstate.*
|
||||
|
||||
# Local .terraform.lock.hcl file
|
||||
.terraform.lock.hcl
|
||||
|
||||
8
.mcp.json.template
Normal file
8
.mcp.json.template
Normal file
@@ -0,0 +1,8 @@
|
||||
{
|
||||
"mcpServers": {
|
||||
"onyx-mcp": {
|
||||
"type": "http",
|
||||
"url": "http://localhost:8000/mcp"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -37,6 +37,15 @@ repos:
|
||||
additional_dependencies:
|
||||
- prettier
|
||||
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: check-lazy-imports
|
||||
name: Check lazy imports are not directly imported
|
||||
entry: python3 backend/scripts/check_lazy_imports.py
|
||||
language: system
|
||||
files: ^backend/.*\.py$
|
||||
pass_filenames: false
|
||||
|
||||
# We would like to have a mypy pre-commit hook, but due to the fact that
|
||||
# pre-commit runs in it's own isolated environment, we would need to install
|
||||
# and keep in sync all dependencies so mypy has access to the appropriate type
|
||||
|
||||
19
.vscode/env_template.txt
vendored
19
.vscode/env_template.txt
vendored
@@ -10,7 +10,7 @@ SKIP_WARM_UP=True
|
||||
|
||||
# Always keep these on for Dev
|
||||
# Logs all model prompts to stdout
|
||||
LOG_DANSWER_MODEL_INTERACTIONS=True
|
||||
LOG_ONYX_MODEL_INTERACTIONS=True
|
||||
# More verbose logging
|
||||
LOG_LEVEL=debug
|
||||
|
||||
@@ -23,6 +23,9 @@ DISABLE_LLM_DOC_RELEVANCE=False
|
||||
# Useful if you want to toggle auth on/off (google_oauth/OIDC specifically)
|
||||
OAUTH_CLIENT_ID=<REPLACE THIS>
|
||||
OAUTH_CLIENT_SECRET=<REPLACE THIS>
|
||||
OPENID_CONFIG_URL=<REPLACE THIS>
|
||||
SAML_CONF_DIR=/<ABSOLUTE PATH TO ONYX>/onyx/backend/ee/onyx/configs/saml_config
|
||||
|
||||
# Generally not useful for dev, we don't generally want to set up an SMTP server for dev
|
||||
REQUIRE_EMAIL_VERIFICATION=False
|
||||
|
||||
@@ -36,8 +39,8 @@ FAST_GEN_AI_MODEL_VERSION=gpt-4o
|
||||
|
||||
# For Danswer Slack Bot, overrides the UI values so no need to set this up via UI every time
|
||||
# Only needed if using DanswerBot
|
||||
#DANSWER_BOT_SLACK_APP_TOKEN=<REPLACE THIS>
|
||||
#DANSWER_BOT_SLACK_BOT_TOKEN=<REPLACE THIS>
|
||||
#ONYX_BOT_SLACK_APP_TOKEN=<REPLACE THIS>
|
||||
#ONYX_BOT_SLACK_BOT_TOKEN=<REPLACE THIS>
|
||||
|
||||
|
||||
# Python stuff
|
||||
@@ -46,7 +49,6 @@ PYTHONUNBUFFERED=1
|
||||
|
||||
|
||||
# Internet Search
|
||||
BING_API_KEY=<REPLACE THIS>
|
||||
EXA_API_KEY=<REPLACE THIS>
|
||||
|
||||
|
||||
@@ -65,3 +67,12 @@ S3_ENDPOINT_URL=http://localhost:9004
|
||||
S3_FILE_STORE_BUCKET_NAME=onyx-file-store-bucket
|
||||
S3_AWS_ACCESS_KEY_ID=minioadmin
|
||||
S3_AWS_SECRET_ACCESS_KEY=minioadmin
|
||||
|
||||
# Show extra/uncommon connectors
|
||||
SHOW_EXTRA_CONNECTORS=True
|
||||
|
||||
# Local langsmith tracing
|
||||
LANGSMITH_TRACING="true"
|
||||
LANGSMITH_ENDPOINT="https://api.smith.langchain.com"
|
||||
LANGSMITH_API_KEY=<REPLACE_THIS>
|
||||
LANGSMITH_PROJECT=<REPLACE_THIS>
|
||||
877
.vscode/launch.template.jsonc
vendored
877
.vscode/launch.template.jsonc
vendored
@@ -1,419 +1,468 @@
|
||||
/* Copy this file into '.vscode/launch.json' or merge its contents into your existing configurations. */
|
||||
|
||||
{
|
||||
// Use IntelliSense to learn about possible attributes.
|
||||
// Hover to view descriptions of existing attributes.
|
||||
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
||||
"version": "0.2.0",
|
||||
"compounds": [
|
||||
{
|
||||
// Dummy entry used to label the group
|
||||
"name": "--- Compound ---",
|
||||
"configurations": ["--- Individual ---"],
|
||||
"presentation": {
|
||||
"group": "1"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Run All Onyx Services",
|
||||
"configurations": [
|
||||
"Web Server",
|
||||
"Model Server",
|
||||
"API Server",
|
||||
"Slack Bot",
|
||||
"Celery primary",
|
||||
"Celery light",
|
||||
"Celery heavy",
|
||||
"Celery docfetching",
|
||||
"Celery docprocessing",
|
||||
"Celery beat",
|
||||
"Celery monitoring"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "1"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Web / Model / API",
|
||||
"configurations": ["Web Server", "Model Server", "API Server"],
|
||||
"presentation": {
|
||||
"group": "1"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Celery (all)",
|
||||
"configurations": [
|
||||
"Celery primary",
|
||||
"Celery light",
|
||||
"Celery heavy",
|
||||
"Celery docfetching",
|
||||
"Celery docprocessing",
|
||||
"Celery beat",
|
||||
"Celery monitoring"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "1"
|
||||
}
|
||||
// Use IntelliSense to learn about possible attributes.
|
||||
// Hover to view descriptions of existing attributes.
|
||||
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
||||
"version": "0.2.0",
|
||||
"compounds": [
|
||||
{
|
||||
// Dummy entry used to label the group
|
||||
"name": "--- Compound ---",
|
||||
"configurations": ["--- Individual ---"],
|
||||
"presentation": {
|
||||
"group": "1"
|
||||
}
|
||||
],
|
||||
"configurations": [
|
||||
{
|
||||
// Dummy entry used to label the group
|
||||
"name": "--- Individual ---",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
"order": 0
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Web Server",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"cwd": "${workspaceRoot}/web",
|
||||
"runtimeExecutable": "npm",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"runtimeArgs": ["run", "dev"],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"console": "integratedTerminal",
|
||||
"consoleTitle": "Web Server Console"
|
||||
},
|
||||
{
|
||||
"name": "Model Server",
|
||||
"consoleName": "Model Server",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "uvicorn",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1"
|
||||
},
|
||||
"args": ["model_server.main:app", "--reload", "--port", "9000"],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Model Server Console"
|
||||
},
|
||||
{
|
||||
"name": "API Server",
|
||||
"consoleName": "API Server",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "uvicorn",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1"
|
||||
},
|
||||
"args": ["onyx.main:app", "--reload", "--port", "8080"],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "API Server Console"
|
||||
},
|
||||
// For the listener to access the Slack API,
|
||||
// DANSWER_BOT_SLACK_APP_TOKEN & DANSWER_BOT_SLACK_BOT_TOKEN need to be set in .env file located in the root of the project
|
||||
{
|
||||
"name": "Slack Bot",
|
||||
"consoleName": "Slack Bot",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "onyx/onyxbot/slack/listener.py",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Slack Bot Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery primary",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "INFO",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.primary",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=4",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=primary@%n",
|
||||
"-Q",
|
||||
"celery"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery primary Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery light",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "INFO",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.light",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=64",
|
||||
"--prefetch-multiplier=8",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=light@%n",
|
||||
"-Q",
|
||||
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery light Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery heavy",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "INFO",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.heavy",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=4",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=heavy@%n",
|
||||
"-Q",
|
||||
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery heavy Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery docfetching",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.docfetching",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=1",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=docfetching@%n",
|
||||
"-Q",
|
||||
"connector_doc_fetching,user_files_indexing"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery docfetching Console",
|
||||
"justMyCode": false
|
||||
},
|
||||
{
|
||||
"name": "Celery docprocessing",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"ENABLE_MULTIPASS_INDEXING": "false",
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.docprocessing",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=6",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=docprocessing@%n",
|
||||
"-Q",
|
||||
"docprocessing"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery docprocessing Console",
|
||||
"justMyCode": false
|
||||
"name": "Run All Onyx Services",
|
||||
"configurations": [
|
||||
"Web Server",
|
||||
"Model Server",
|
||||
"API Server",
|
||||
"Slack Bot",
|
||||
"Celery primary",
|
||||
"Celery light",
|
||||
"Celery heavy",
|
||||
"Celery docfetching",
|
||||
"Celery docprocessing",
|
||||
"Celery beat",
|
||||
"Celery monitoring",
|
||||
"Celery user file processing"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "1"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Celery monitoring",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.monitoring",
|
||||
"worker",
|
||||
"--pool=solo",
|
||||
"--concurrency=1",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=monitoring@%n",
|
||||
"-Q",
|
||||
"monitoring"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery monitoring Console"
|
||||
{
|
||||
"name": "Web / Model / API",
|
||||
"configurations": ["Web Server", "Model Server", "API Server"],
|
||||
"presentation": {
|
||||
"group": "1"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Celery (all)",
|
||||
"configurations": [
|
||||
"Celery primary",
|
||||
"Celery light",
|
||||
"Celery heavy",
|
||||
"Celery docfetching",
|
||||
"Celery docprocessing",
|
||||
"Celery beat",
|
||||
"Celery monitoring",
|
||||
"Celery user file processing"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "1"
|
||||
},
|
||||
{
|
||||
"name": "Celery beat",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.beat",
|
||||
"beat",
|
||||
"--loglevel=INFO"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery beat Console"
|
||||
"stopAll": true
|
||||
}
|
||||
],
|
||||
"configurations": [
|
||||
{
|
||||
// Dummy entry used to label the group
|
||||
"name": "--- Individual ---",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
"order": 0
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Web Server",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"cwd": "${workspaceRoot}/web",
|
||||
"runtimeExecutable": "npm",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"runtimeArgs": ["run", "dev"],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
{
|
||||
"name": "Pytest",
|
||||
"consoleName": "Pytest",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "pytest",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-v"
|
||||
// Specify a sepcific module/test to run or provide nothing to run all tests
|
||||
//"tests/unit/onyx/llm/answering/test_prune_and_merge.py"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Pytest Console"
|
||||
"console": "integratedTerminal",
|
||||
"consoleTitle": "Web Server Console"
|
||||
},
|
||||
{
|
||||
"name": "Model Server",
|
||||
"consoleName": "Model Server",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "uvicorn",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1"
|
||||
},
|
||||
{
|
||||
// Dummy entry used to label the group
|
||||
"name": "--- Tasks ---",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"presentation": {
|
||||
"group": "3",
|
||||
"order": 0
|
||||
}
|
||||
"args": ["model_server.main:app", "--reload", "--port", "9000"],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
{
|
||||
"name": "Clear and Restart External Volumes and Containers",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"runtimeExecutable": "bash",
|
||||
"runtimeArgs": [
|
||||
"${workspaceFolder}/backend/scripts/restart_containers.sh"
|
||||
],
|
||||
"cwd": "${workspaceFolder}",
|
||||
"console": "integratedTerminal",
|
||||
"stopOnEntry": true,
|
||||
"presentation": {
|
||||
"group": "3"
|
||||
}
|
||||
"consoleTitle": "Model Server Console"
|
||||
},
|
||||
{
|
||||
"name": "API Server",
|
||||
"consoleName": "API Server",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "uvicorn",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_ONYX_MODEL_INTERACTIONS": "True",
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1"
|
||||
},
|
||||
{
|
||||
// Celery jobs launched through a single background script (legacy)
|
||||
// Recommend using the "Celery (all)" compound launch instead.
|
||||
"name": "Background Jobs",
|
||||
"consoleName": "Background Jobs",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "scripts/dev_run_background_jobs.py",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
}
|
||||
"args": ["onyx.main:app", "--reload", "--port", "8080"],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
{
|
||||
"name": "Install Python Requirements",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"runtimeExecutable": "bash",
|
||||
"runtimeArgs": [
|
||||
"-c",
|
||||
"pip install -r backend/requirements/default.txt && pip install -r backend/requirements/dev.txt && pip install -r backend/requirements/ee.txt && pip install -r backend/requirements/model_server.txt"
|
||||
],
|
||||
"cwd": "${workspaceFolder}",
|
||||
"console": "integratedTerminal",
|
||||
"presentation": {
|
||||
"group": "3"
|
||||
}
|
||||
"consoleTitle": "API Server Console"
|
||||
},
|
||||
// For the listener to access the Slack API,
|
||||
// ONYX_BOT_SLACK_APP_TOKEN & ONYX_BOT_SLACK_BOT_TOKEN need to be set in .env file located in the root of the project
|
||||
{
|
||||
"name": "Slack Bot",
|
||||
"consoleName": "Slack Bot",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "onyx/onyxbot/slack/listener.py",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Slack Bot Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery primary",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "INFO",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.primary",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=4",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=primary@%n",
|
||||
"-Q",
|
||||
"celery"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery primary Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery light",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "INFO",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.light",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=64",
|
||||
"--prefetch-multiplier=8",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=light@%n",
|
||||
"-Q",
|
||||
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert,index_attempt_cleanup"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery light Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery heavy",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "INFO",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.heavy",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=4",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=heavy@%n",
|
||||
"-Q",
|
||||
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery heavy Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery docfetching",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.docfetching",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=1",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=docfetching@%n",
|
||||
"-Q",
|
||||
"connector_doc_fetching,user_files_indexing"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery docfetching Console",
|
||||
"justMyCode": false
|
||||
},
|
||||
{
|
||||
"name": "Celery docprocessing",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"ENABLE_MULTIPASS_INDEXING": "false",
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.docprocessing",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=6",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=docprocessing@%n",
|
||||
"-Q",
|
||||
"docprocessing"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery docprocessing Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery beat",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.beat",
|
||||
"beat",
|
||||
"--loglevel=INFO"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery beat Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery monitoring",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.monitoring",
|
||||
"worker",
|
||||
"--pool=solo",
|
||||
"--concurrency=1",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=monitoring@%n",
|
||||
"-Q",
|
||||
"monitoring"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery monitoring Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery user file processing",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.user_file_processing",
|
||||
"worker",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=user_file_processing@%n",
|
||||
"--pool=threads",
|
||||
"-Q",
|
||||
"user_file_processing,user_file_project_sync"
|
||||
],
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery user file processing Console"
|
||||
},
|
||||
{
|
||||
"name": "Pytest",
|
||||
"consoleName": "Pytest",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "pytest",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-v"
|
||||
// Specify a specific module/test to run or provide nothing to run all tests
|
||||
// "tests/unit/onyx/llm/answering/test_prune_and_merge.py"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Pytest Console"
|
||||
},
|
||||
{
|
||||
// Dummy entry used to label the group
|
||||
"name": "--- Tasks ---",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"presentation": {
|
||||
"group": "3",
|
||||
"order": 0
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Clear and Restart External Volumes and Containers",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"runtimeExecutable": "bash",
|
||||
"runtimeArgs": [
|
||||
"${workspaceFolder}/backend/scripts/restart_containers.sh"
|
||||
],
|
||||
"cwd": "${workspaceFolder}",
|
||||
"console": "integratedTerminal",
|
||||
"stopOnEntry": true,
|
||||
"presentation": {
|
||||
"group": "3"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Eval CLI",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "${workspaceFolder}/backend/onyx/evals/eval_cli.py",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"console": "integratedTerminal",
|
||||
"justMyCode": false,
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"presentation": {
|
||||
"group": "3"
|
||||
},
|
||||
"env": {
|
||||
"LOG_LEVEL": "INFO",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": ["--verbose"],
|
||||
"consoleTitle": "Eval CLI Console"
|
||||
},
|
||||
{
|
||||
// Celery jobs launched through a single background script (legacy)
|
||||
// Recommend using the "Celery (all)" compound launch instead.
|
||||
"name": "Background Jobs",
|
||||
"consoleName": "Background Jobs",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "scripts/dev_run_background_jobs.py",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_ONYX_MODEL_INTERACTIONS": "True",
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Install Python Requirements",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"runtimeExecutable": "bash",
|
||||
"runtimeArgs": [
|
||||
"-c",
|
||||
"pip install -r backend/requirements/default.txt && pip install -r backend/requirements/dev.txt && pip install -r backend/requirements/ee.txt && pip install -r backend/requirements/model_server.txt"
|
||||
],
|
||||
"cwd": "${workspaceFolder}",
|
||||
"console": "integratedTerminal",
|
||||
"presentation": {
|
||||
"group": "3"
|
||||
}
|
||||
},
|
||||
{
|
||||
// script to generate the openapi schema
|
||||
"name": "Onyx OpenAPI Schema Generator",
|
||||
@@ -426,10 +475,7 @@
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"--filename",
|
||||
"generated/openapi.json"
|
||||
]
|
||||
"args": ["--filename", "generated/openapi.json"]
|
||||
},
|
||||
{
|
||||
// script to debug multi tenant db issues
|
||||
@@ -454,13 +500,12 @@
|
||||
"generated/tenants_by_num_docs.csv"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Debug React Web App in Chrome",
|
||||
"type": "chrome",
|
||||
"request": "launch",
|
||||
"url": "http://localhost:3000",
|
||||
"webRoot": "${workspaceFolder}/web"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
{
|
||||
"name": "Debug React Web App in Chrome",
|
||||
"type": "chrome",
|
||||
"request": "launch",
|
||||
"url": "http://localhost:3000",
|
||||
"webRoot": "${workspaceFolder}/web"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
295
AGENTS.md
Normal file
295
AGENTS.md
Normal file
@@ -0,0 +1,295 @@
|
||||
# AGENTS.md
|
||||
|
||||
This file provides guidance to Codex when working with code in this repository.
|
||||
|
||||
## KEY NOTES
|
||||
|
||||
- If you run into any missing python dependency errors, try running your command with `source backend/.venv/bin/activate` \
|
||||
to assume the python venv.
|
||||
- To make tests work, check the `.env` file at the root of the project to find an OpenAI key.
|
||||
- If using `playwright` to explore the frontend, you can usually log in with username `a@test.com` and password
|
||||
`a`. The app can be accessed at `http://localhost:3000`.
|
||||
- You should assume that all Onyx services are running. To verify, you can check the `backend/log` directory to
|
||||
make sure we see logs coming out from the relevant service.
|
||||
- To connect to the Postgres database, use: `docker exec -it onyx-relational_db-1 psql -U postgres -c "<SQL>"`
|
||||
- When making calls to the backend, always go through the frontend. E.g. make a call to `http://localhost:3000/api/persona` not `http://localhost:8080/api/persona`
|
||||
- Put ALL db operations under the `backend/onyx/db` / `backend/ee/onyx/db` directories. Don't run queries
|
||||
outside of those directories.
|
||||
|
||||
## Project Overview
|
||||
|
||||
**Onyx** (formerly Danswer) is an open-source Gen-AI and Enterprise Search platform that connects to company documents, apps, and people. It features a modular architecture with both Community Edition (MIT licensed) and Enterprise Edition offerings.
|
||||
|
||||
|
||||
### Background Workers (Celery)
|
||||
|
||||
Onyx uses Celery for asynchronous task processing with multiple specialized workers:
|
||||
|
||||
#### Worker Types
|
||||
|
||||
1. **Primary Worker** (`celery_app.py`)
|
||||
- Coordinates core background tasks and system-wide operations
|
||||
- Handles connector management, document sync, pruning, and periodic checks
|
||||
- Runs with 4 threads concurrency
|
||||
- Tasks: connector deletion, vespa sync, pruning, LLM model updates, user file sync
|
||||
|
||||
2. **Docfetching Worker** (`docfetching`)
|
||||
- Fetches documents from external data sources (connectors)
|
||||
- Spawns docprocessing tasks for each document batch
|
||||
- Implements watchdog monitoring for stuck connectors
|
||||
- Configurable concurrency (default from env)
|
||||
|
||||
3. **Docprocessing Worker** (`docprocessing`)
|
||||
- Processes fetched documents through the indexing pipeline:
|
||||
- Upserts documents to PostgreSQL
|
||||
- Chunks documents and adds contextual information
|
||||
- Embeds chunks via model server
|
||||
- Writes chunks to Vespa vector database
|
||||
- Updates document metadata
|
||||
- Configurable concurrency (default from env)
|
||||
|
||||
4. **Light Worker** (`light`)
|
||||
- Handles lightweight, fast operations
|
||||
- Tasks: vespa operations, document permissions sync, external group sync
|
||||
- Higher concurrency for quick tasks
|
||||
|
||||
5. **Heavy Worker** (`heavy`)
|
||||
- Handles resource-intensive operations
|
||||
- Primary task: document pruning operations
|
||||
- Runs with 4 threads concurrency
|
||||
|
||||
6. **KG Processing Worker** (`kg_processing`)
|
||||
- Handles Knowledge Graph processing and clustering
|
||||
- Builds relationships between documents
|
||||
- Runs clustering algorithms
|
||||
- Configurable concurrency
|
||||
|
||||
7. **Monitoring Worker** (`monitoring`)
|
||||
- System health monitoring and metrics collection
|
||||
- Monitors Celery queues, process memory, and system status
|
||||
- Single thread (monitoring doesn't need parallelism)
|
||||
- Cloud-specific monitoring tasks
|
||||
|
||||
8. **Beat Worker** (`beat`)
|
||||
- Celery's scheduler for periodic tasks
|
||||
- Uses DynamicTenantScheduler for multi-tenant support
|
||||
- Schedules tasks like:
|
||||
- Indexing checks (every 15 seconds)
|
||||
- Connector deletion checks (every 20 seconds)
|
||||
- Vespa sync checks (every 20 seconds)
|
||||
- Pruning checks (every 20 seconds)
|
||||
- KG processing (every 60 seconds)
|
||||
- Monitoring tasks (every 5 minutes)
|
||||
- Cleanup tasks (hourly)
|
||||
|
||||
#### Key Features
|
||||
|
||||
- **Thread-based Workers**: All workers use thread pools (not processes) for stability
|
||||
- **Tenant Awareness**: Multi-tenant support with per-tenant task isolation. There is a
|
||||
middleware layer that automatically finds the appropriate tenant ID when sending tasks
|
||||
via Celery Beat.
|
||||
- **Task Prioritization**: High, Medium, Low priority queues
|
||||
- **Monitoring**: Built-in heartbeat and liveness checking
|
||||
- **Failure Handling**: Automatic retry and failure recovery mechanisms
|
||||
- **Redis Coordination**: Inter-process communication via Redis
|
||||
- **PostgreSQL State**: Task state and metadata stored in PostgreSQL
|
||||
|
||||
|
||||
#### Important Notes
|
||||
|
||||
**Defining Tasks**:
|
||||
- Always use `@shared_task` rather than `@celery_app`
|
||||
- Put tasks under `background/celery/tasks/` or `ee/background/celery/tasks`
|
||||
|
||||
**Defining APIs**:
|
||||
When creating new FastAPI APIs, do NOT use the `response_model` field. Instead, just type the
|
||||
function.
|
||||
|
||||
**Testing Updates**:
|
||||
If you make any updates to a celery worker and you want to test these changes, you will need
|
||||
to ask me to restart the celery worker. There is no auto-restart on code-change mechanism.
|
||||
|
||||
### Code Quality
|
||||
```bash
|
||||
# Install and run pre-commit hooks
|
||||
pre-commit install
|
||||
pre-commit run --all-files
|
||||
```
|
||||
|
||||
NOTE: Always make sure everything is strictly typed (both in Python and Typescript).
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
### Technology Stack
|
||||
- **Backend**: Python 3.11, FastAPI, SQLAlchemy, Alembic, Celery
|
||||
- **Frontend**: Next.js 15+, React 18, TypeScript, Tailwind CSS
|
||||
- **Database**: PostgreSQL with Redis caching
|
||||
- **Search**: Vespa vector database
|
||||
- **Auth**: OAuth2, SAML, multi-provider support
|
||||
- **AI/ML**: LangChain, LiteLLM, multiple embedding models
|
||||
|
||||
### Directory Structure
|
||||
|
||||
```
|
||||
backend/
|
||||
├── onyx/
|
||||
│ ├── auth/ # Authentication & authorization
|
||||
│ ├── chat/ # Chat functionality & LLM interactions
|
||||
│ ├── connectors/ # Data source connectors
|
||||
│ ├── db/ # Database models & operations
|
||||
│ ├── document_index/ # Vespa integration
|
||||
│ ├── federated_connectors/ # External search connectors
|
||||
│ ├── llm/ # LLM provider integrations
|
||||
│ └── server/ # API endpoints & routers
|
||||
├── ee/ # Enterprise Edition features
|
||||
├── alembic/ # Database migrations
|
||||
└── tests/ # Test suites
|
||||
|
||||
web/
|
||||
├── src/app/ # Next.js app router pages
|
||||
├── src/components/ # Reusable React components
|
||||
└── src/lib/ # Utilities & business logic
|
||||
```
|
||||
|
||||
## Database & Migrations
|
||||
|
||||
### Running Migrations
|
||||
```bash
|
||||
# Standard migrations
|
||||
alembic upgrade head
|
||||
|
||||
# Multi-tenant (Enterprise)
|
||||
alembic -n schema_private upgrade head
|
||||
```
|
||||
|
||||
### Creating Migrations
|
||||
```bash
|
||||
# Auto-generate migration
|
||||
alembic revision --autogenerate -m "description"
|
||||
|
||||
# Multi-tenant migration
|
||||
alembic -n schema_private revision --autogenerate -m "description"
|
||||
```
|
||||
|
||||
## Testing Strategy
|
||||
|
||||
There are 4 main types of tests within Onyx:
|
||||
|
||||
### Unit Tests
|
||||
These should not assume any Onyx/external services are available to be called.
|
||||
Interactions with the outside world should be mocked using `unittest.mock`. Generally, only
|
||||
write these for complex, isolated modules e.g. `citation_processing.py`.
|
||||
|
||||
To run them:
|
||||
|
||||
```bash
|
||||
python -m dotenv -f .vscode/.env run -- pytest -xv backend/tests/unit
|
||||
```
|
||||
|
||||
### External Dependency Unit Tests
|
||||
These tests assume that all external dependencies of Onyx are available and callable (e.g. Postgres, Redis,
|
||||
MinIO/S3, Vespa are running + OpenAI can be called + any request to the internet is fine + etc.).
|
||||
|
||||
However, the actual Onyx containers are not running and with these tests we call the function to test directly.
|
||||
We can also mock components/calls at will.
|
||||
|
||||
The goal with these tests are to minimize mocking while giving some flexibility to mock things that are flakey,
|
||||
need strictly controlled behavior, or need to have their internal behavior validated (e.g. verify a function is called
|
||||
with certain args, something that would be impossible with proper integration tests).
|
||||
|
||||
A great example of this type of test is `backend/tests/external_dependency_unit/connectors/confluence/test_confluence_group_sync.py`.
|
||||
|
||||
To run them:
|
||||
|
||||
```bash
|
||||
python -m dotenv -f .vscode/.env run -- pytest backend/tests/external_dependency_unit
|
||||
```
|
||||
|
||||
### Integration Tests
|
||||
Standard integration tests. Every test in `backend/tests/integration` runs against a real Onyx deployment. We cannot
|
||||
mock anything in these tests. Prefer writing integration tests (or External Dependency Unit Tests if mocking/internal
|
||||
verification is necessary) over any other type of test.
|
||||
|
||||
Tests are parallelized at a directory level.
|
||||
|
||||
When writing integration tests, make sure to check the root `conftest.py` for useful fixtures + the `backend/tests/integration/common_utils` directory for utilities. Prefer (if one exists), calling the appropriate Manager
|
||||
class in the utils over directly calling the APIs with a library like `requests`. Prefer using fixtures rather than
|
||||
calling the utilities directly (e.g. do NOT create admin users with
|
||||
`admin_user = UserManager.create(name="admin_user")`, instead use the `admin_user` fixture).
|
||||
|
||||
A great example of this type of test is `backend/tests/integration/dev_apis/test_simple_chat_api.py`.
|
||||
|
||||
To run them:
|
||||
|
||||
```bash
|
||||
python -m dotenv -f .vscode/.env run -- pytest backend/tests/integration
|
||||
```
|
||||
|
||||
### Playwright (E2E) Tests
|
||||
These tests are an even more complete version of the Integration Tests mentioned above. Has all services of Onyx
|
||||
running, *including* the Web Server.
|
||||
|
||||
Use these tests for anything that requires significant frontend <-> backend coordination.
|
||||
|
||||
Tests are located at `web/tests/e2e`. Tests are written in TypeScript.
|
||||
|
||||
To run them:
|
||||
|
||||
```bash
|
||||
npx playwright test <TEST_NAME>
|
||||
```
|
||||
|
||||
|
||||
## Logs
|
||||
|
||||
When (1) writing integration tests or (2) doing live tests (e.g. curl / playwright) you can get access
|
||||
to logs via the `backend/log/<service_name>_debug.log` file. All Onyx services (api_server, web_server, celery_X)
|
||||
will be tailing their logs to this file.
|
||||
|
||||
|
||||
## Security Considerations
|
||||
|
||||
- Never commit API keys or secrets to repository
|
||||
- Use encrypted credential storage for connector credentials
|
||||
- Follow RBAC patterns for new features
|
||||
- Implement proper input validation with Pydantic models
|
||||
- Use parameterized queries to prevent SQL injection
|
||||
|
||||
## AI/LLM Integration
|
||||
|
||||
- Multiple LLM providers supported via LiteLLM
|
||||
- Configurable models per feature (chat, search, embeddings)
|
||||
- Streaming support for real-time responses
|
||||
- Token management and rate limiting
|
||||
- Custom prompts and agent actions
|
||||
|
||||
## UI/UX Patterns
|
||||
|
||||
- Tailwind CSS with design system in `web/src/components/ui/`
|
||||
- Radix UI and Headless UI for accessible components
|
||||
- SWR for data fetching and caching
|
||||
- Form validation with react-hook-form
|
||||
- Error handling with popup notifications
|
||||
|
||||
## Creating a Plan
|
||||
When creating a plan in the `plans` directory, make sure to include at least these elements:
|
||||
|
||||
**Issues to Address**
|
||||
What the change is meant to do.
|
||||
|
||||
**Important Notes**
|
||||
Things you come across in your research that are important to the implementation.
|
||||
|
||||
**Implementation strategy**
|
||||
How you are going to make the changes happen. High level approach.
|
||||
|
||||
**Tests**
|
||||
What unit (use rarely), external dependency unit, integration, and playwright tests you plan to write to
|
||||
verify the correct behavior. Don't overtest. Usually, a given change only needs one type of test.
|
||||
|
||||
Do NOT include these: *Timeline*, *Rollback plan*
|
||||
|
||||
This is a minimal list - feel free to include more. Do NOT write code as part of your plan.
|
||||
Keep it high level. You can reference certain files or functions though.
|
||||
|
||||
Before writing your plan, make sure to do research. Explore the relevant sections in the codebase.
|
||||
295
CLAUDE.md
Normal file
295
CLAUDE.md
Normal file
@@ -0,0 +1,295 @@
|
||||
# CLAUDE.md
|
||||
|
||||
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
||||
|
||||
## KEY NOTES
|
||||
|
||||
- If you run into any missing python dependency errors, try running your command with `source backend/.venv/bin/activate` \
|
||||
to assume the python venv.
|
||||
- To make tests work, check the `.env` file at the root of the project to find an OpenAI key.
|
||||
- If using `playwright` to explore the frontend, you can usually log in with username `a@test.com` and password
|
||||
`a`. The app can be accessed at `http://localhost:3000`.
|
||||
- You should assume that all Onyx services are running. To verify, you can check the `backend/log` directory to
|
||||
make sure we see logs coming out from the relevant service.
|
||||
- To connect to the Postgres database, use: `docker exec -it onyx-relational_db-1 psql -U postgres -c "<SQL>"`
|
||||
- When making calls to the backend, always go through the frontend. E.g. make a call to `http://localhost:3000/api/persona` not `http://localhost:8080/api/persona`
|
||||
- Put ALL db operations under the `backend/onyx/db` / `backend/ee/onyx/db` directories. Don't run queries
|
||||
outside of those directories.
|
||||
|
||||
## Project Overview
|
||||
|
||||
**Onyx** (formerly Danswer) is an open-source Gen-AI and Enterprise Search platform that connects to company documents, apps, and people. It features a modular architecture with both Community Edition (MIT licensed) and Enterprise Edition offerings.
|
||||
|
||||
|
||||
### Background Workers (Celery)
|
||||
|
||||
Onyx uses Celery for asynchronous task processing with multiple specialized workers:
|
||||
|
||||
#### Worker Types
|
||||
|
||||
1. **Primary Worker** (`celery_app.py`)
|
||||
- Coordinates core background tasks and system-wide operations
|
||||
- Handles connector management, document sync, pruning, and periodic checks
|
||||
- Runs with 4 threads concurrency
|
||||
- Tasks: connector deletion, vespa sync, pruning, LLM model updates, user file sync
|
||||
|
||||
2. **Docfetching Worker** (`docfetching`)
|
||||
- Fetches documents from external data sources (connectors)
|
||||
- Spawns docprocessing tasks for each document batch
|
||||
- Implements watchdog monitoring for stuck connectors
|
||||
- Configurable concurrency (default from env)
|
||||
|
||||
3. **Docprocessing Worker** (`docprocessing`)
|
||||
- Processes fetched documents through the indexing pipeline:
|
||||
- Upserts documents to PostgreSQL
|
||||
- Chunks documents and adds contextual information
|
||||
- Embeds chunks via model server
|
||||
- Writes chunks to Vespa vector database
|
||||
- Updates document metadata
|
||||
- Configurable concurrency (default from env)
|
||||
|
||||
4. **Light Worker** (`light`)
|
||||
- Handles lightweight, fast operations
|
||||
- Tasks: vespa operations, document permissions sync, external group sync
|
||||
- Higher concurrency for quick tasks
|
||||
|
||||
5. **Heavy Worker** (`heavy`)
|
||||
- Handles resource-intensive operations
|
||||
- Primary task: document pruning operations
|
||||
- Runs with 4 threads concurrency
|
||||
|
||||
6. **KG Processing Worker** (`kg_processing`)
|
||||
- Handles Knowledge Graph processing and clustering
|
||||
- Builds relationships between documents
|
||||
- Runs clustering algorithms
|
||||
- Configurable concurrency
|
||||
|
||||
7. **Monitoring Worker** (`monitoring`)
|
||||
- System health monitoring and metrics collection
|
||||
- Monitors Celery queues, process memory, and system status
|
||||
- Single thread (monitoring doesn't need parallelism)
|
||||
- Cloud-specific monitoring tasks
|
||||
|
||||
8. **Beat Worker** (`beat`)
|
||||
- Celery's scheduler for periodic tasks
|
||||
- Uses DynamicTenantScheduler for multi-tenant support
|
||||
- Schedules tasks like:
|
||||
- Indexing checks (every 15 seconds)
|
||||
- Connector deletion checks (every 20 seconds)
|
||||
- Vespa sync checks (every 20 seconds)
|
||||
- Pruning checks (every 20 seconds)
|
||||
- KG processing (every 60 seconds)
|
||||
- Monitoring tasks (every 5 minutes)
|
||||
- Cleanup tasks (hourly)
|
||||
|
||||
#### Key Features
|
||||
|
||||
- **Thread-based Workers**: All workers use thread pools (not processes) for stability
|
||||
- **Tenant Awareness**: Multi-tenant support with per-tenant task isolation. There is a
|
||||
middleware layer that automatically finds the appropriate tenant ID when sending tasks
|
||||
via Celery Beat.
|
||||
- **Task Prioritization**: High, Medium, Low priority queues
|
||||
- **Monitoring**: Built-in heartbeat and liveness checking
|
||||
- **Failure Handling**: Automatic retry and failure recovery mechanisms
|
||||
- **Redis Coordination**: Inter-process communication via Redis
|
||||
- **PostgreSQL State**: Task state and metadata stored in PostgreSQL
|
||||
|
||||
|
||||
#### Important Notes
|
||||
|
||||
**Defining Tasks**:
|
||||
- Always use `@shared_task` rather than `@celery_app`
|
||||
- Put tasks under `background/celery/tasks/` or `ee/background/celery/tasks`
|
||||
|
||||
**Defining APIs**:
|
||||
When creating new FastAPI APIs, do NOT use the `response_model` field. Instead, just type the
|
||||
function.
|
||||
|
||||
**Testing Updates**:
|
||||
If you make any updates to a celery worker and you want to test these changes, you will need
|
||||
to ask me to restart the celery worker. There is no auto-restart on code-change mechanism.
|
||||
|
||||
### Code Quality
|
||||
```bash
|
||||
# Install and run pre-commit hooks
|
||||
pre-commit install
|
||||
pre-commit run --all-files
|
||||
```
|
||||
|
||||
NOTE: Always make sure everything is strictly typed (both in Python and Typescript).
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
### Technology Stack
|
||||
- **Backend**: Python 3.11, FastAPI, SQLAlchemy, Alembic, Celery
|
||||
- **Frontend**: Next.js 15+, React 18, TypeScript, Tailwind CSS
|
||||
- **Database**: PostgreSQL with Redis caching
|
||||
- **Search**: Vespa vector database
|
||||
- **Auth**: OAuth2, SAML, multi-provider support
|
||||
- **AI/ML**: LangChain, LiteLLM, multiple embedding models
|
||||
|
||||
### Directory Structure
|
||||
|
||||
```
|
||||
backend/
|
||||
├── onyx/
|
||||
│ ├── auth/ # Authentication & authorization
|
||||
│ ├── chat/ # Chat functionality & LLM interactions
|
||||
│ ├── connectors/ # Data source connectors
|
||||
│ ├── db/ # Database models & operations
|
||||
│ ├── document_index/ # Vespa integration
|
||||
│ ├── federated_connectors/ # External search connectors
|
||||
│ ├── llm/ # LLM provider integrations
|
||||
│ └── server/ # API endpoints & routers
|
||||
├── ee/ # Enterprise Edition features
|
||||
├── alembic/ # Database migrations
|
||||
└── tests/ # Test suites
|
||||
|
||||
web/
|
||||
├── src/app/ # Next.js app router pages
|
||||
├── src/components/ # Reusable React components
|
||||
└── src/lib/ # Utilities & business logic
|
||||
```
|
||||
|
||||
## Database & Migrations
|
||||
|
||||
### Running Migrations
|
||||
```bash
|
||||
# Standard migrations
|
||||
alembic upgrade head
|
||||
|
||||
# Multi-tenant (Enterprise)
|
||||
alembic -n schema_private upgrade head
|
||||
```
|
||||
|
||||
### Creating Migrations
|
||||
```bash
|
||||
# Auto-generate migration
|
||||
alembic revision --autogenerate -m "description"
|
||||
|
||||
# Multi-tenant migration
|
||||
alembic -n schema_private revision --autogenerate -m "description"
|
||||
```
|
||||
|
||||
## Testing Strategy
|
||||
|
||||
There are 4 main types of tests within Onyx:
|
||||
|
||||
### Unit Tests
|
||||
These should not assume any Onyx/external services are available to be called.
|
||||
Interactions with the outside world should be mocked using `unittest.mock`. Generally, only
|
||||
write these for complex, isolated modules e.g. `citation_processing.py`.
|
||||
|
||||
To run them:
|
||||
|
||||
```bash
|
||||
python -m dotenv -f .vscode/.env run -- pytest -xv backend/tests/unit
|
||||
```
|
||||
|
||||
### External Dependency Unit Tests
|
||||
These tests assume that all external dependencies of Onyx are available and callable (e.g. Postgres, Redis,
|
||||
MinIO/S3, Vespa are running + OpenAI can be called + any request to the internet is fine + etc.).
|
||||
|
||||
However, the actual Onyx containers are not running and with these tests we call the function to test directly.
|
||||
We can also mock components/calls at will.
|
||||
|
||||
The goal with these tests are to minimize mocking while giving some flexibility to mock things that are flakey,
|
||||
need strictly controlled behavior, or need to have their internal behavior validated (e.g. verify a function is called
|
||||
with certain args, something that would be impossible with proper integration tests).
|
||||
|
||||
A great example of this type of test is `backend/tests/external_dependency_unit/connectors/confluence/test_confluence_group_sync.py`.
|
||||
|
||||
To run them:
|
||||
|
||||
```bash
|
||||
python -m dotenv -f .vscode/.env run -- pytest backend/tests/external_dependency_unit
|
||||
```
|
||||
|
||||
### Integration Tests
|
||||
Standard integration tests. Every test in `backend/tests/integration` runs against a real Onyx deployment. We cannot
|
||||
mock anything in these tests. Prefer writing integration tests (or External Dependency Unit Tests if mocking/internal
|
||||
verification is necessary) over any other type of test.
|
||||
|
||||
Tests are parallelized at a directory level.
|
||||
|
||||
When writing integration tests, make sure to check the root `conftest.py` for useful fixtures + the `backend/tests/integration/common_utils` directory for utilities. Prefer (if one exists), calling the appropriate Manager
|
||||
class in the utils over directly calling the APIs with a library like `requests`. Prefer using fixtures rather than
|
||||
calling the utilities directly (e.g. do NOT create admin users with
|
||||
`admin_user = UserManager.create(name="admin_user")`, instead use the `admin_user` fixture).
|
||||
|
||||
A great example of this type of test is `backend/tests/integration/dev_apis/test_simple_chat_api.py`.
|
||||
|
||||
To run them:
|
||||
|
||||
```bash
|
||||
python -m dotenv -f .vscode/.env run -- pytest backend/tests/integration
|
||||
```
|
||||
|
||||
### Playwright (E2E) Tests
|
||||
These tests are an even more complete version of the Integration Tests mentioned above. Has all services of Onyx
|
||||
running, *including* the Web Server.
|
||||
|
||||
Use these tests for anything that requires significant frontend <-> backend coordination.
|
||||
|
||||
Tests are located at `web/tests/e2e`. Tests are written in TypeScript.
|
||||
|
||||
To run them:
|
||||
|
||||
```bash
|
||||
npx playwright test <TEST_NAME>
|
||||
```
|
||||
|
||||
|
||||
## Logs
|
||||
|
||||
When (1) writing integration tests or (2) doing live tests (e.g. curl / playwright) you can get access
|
||||
to logs via the `backend/log/<service_name>_debug.log` file. All Onyx services (api_server, web_server, celery_X)
|
||||
will be tailing their logs to this file.
|
||||
|
||||
|
||||
## Security Considerations
|
||||
|
||||
- Never commit API keys or secrets to repository
|
||||
- Use encrypted credential storage for connector credentials
|
||||
- Follow RBAC patterns for new features
|
||||
- Implement proper input validation with Pydantic models
|
||||
- Use parameterized queries to prevent SQL injection
|
||||
|
||||
## AI/LLM Integration
|
||||
|
||||
- Multiple LLM providers supported via LiteLLM
|
||||
- Configurable models per feature (chat, search, embeddings)
|
||||
- Streaming support for real-time responses
|
||||
- Token management and rate limiting
|
||||
- Custom prompts and agent actions
|
||||
|
||||
## UI/UX Patterns
|
||||
|
||||
- Tailwind CSS with design system in `web/src/components/ui/`
|
||||
- Radix UI and Headless UI for accessible components
|
||||
- SWR for data fetching and caching
|
||||
- Form validation with react-hook-form
|
||||
- Error handling with popup notifications
|
||||
|
||||
## Creating a Plan
|
||||
When creating a plan in the `plans` directory, make sure to include at least these elements:
|
||||
|
||||
**Issues to Address**
|
||||
What the change is meant to do.
|
||||
|
||||
**Important Notes**
|
||||
Things you come across in your research that are important to the implementation.
|
||||
|
||||
**Implementation strategy**
|
||||
How you are going to make the changes happen. High level approach.
|
||||
|
||||
**Tests**
|
||||
What unit (use rarely), external dependency unit, integration, and playwright tests you plan to write to
|
||||
verify the correct behavior. Don't overtest. Usually, a given change only needs one type of test.
|
||||
|
||||
Do NOT include these: *Timeline*, *Rollback plan*
|
||||
|
||||
This is a minimal list - feel free to include more. Do NOT write code as part of your plan.
|
||||
Keep it high level. You can reference certain files or functions though.
|
||||
|
||||
Before writing your plan, make sure to do research. Explore the relevant sections in the codebase.
|
||||
@@ -84,10 +84,6 @@ python -m venv .venv
|
||||
source .venv/bin/activate
|
||||
```
|
||||
|
||||
> **Note:**
|
||||
> This virtual environment MUST NOT be set up WITHIN the onyx directory if you plan on using mypy within certain IDEs.
|
||||
> For simplicity, we recommend setting up the virtual environment outside of the onyx directory.
|
||||
|
||||
_For Windows, activate the virtual environment using Command Prompt:_
|
||||
|
||||
```bash
|
||||
@@ -103,10 +99,10 @@ If using PowerShell, the command slightly differs:
|
||||
Install the required python dependencies:
|
||||
|
||||
```bash
|
||||
pip install -r onyx/backend/requirements/default.txt
|
||||
pip install -r onyx/backend/requirements/dev.txt
|
||||
pip install -r onyx/backend/requirements/ee.txt
|
||||
pip install -r onyx/backend/requirements/model_server.txt
|
||||
pip install -r backend/requirements/default.txt
|
||||
pip install -r backend/requirements/dev.txt
|
||||
pip install -r backend/requirements/ee.txt
|
||||
pip install -r backend/requirements/model_server.txt
|
||||
```
|
||||
|
||||
Install Playwright for Python (headless browser required by the Web Connector)
|
||||
@@ -175,7 +171,7 @@ You will need Docker installed to run these containers.
|
||||
First navigate to `onyx/deployment/docker_compose`, then start up Postgres/Vespa/Redis/MinIO with:
|
||||
|
||||
```bash
|
||||
docker compose -f docker-compose.dev.yml -p onyx-stack up -d index relational_db cache minio
|
||||
docker compose up -d index relational_db cache minio
|
||||
```
|
||||
|
||||
(index refers to Vespa, relational_db refers to Postgres, and cache refers to Redis)
|
||||
@@ -257,7 +253,7 @@ You can run the full Onyx application stack from pre-built images including all
|
||||
Navigate to `onyx/deployment/docker_compose` and run:
|
||||
|
||||
```bash
|
||||
docker compose -f docker-compose.dev.yml -p onyx-stack up -d
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
After Docker pulls and starts these containers, navigate to `http://localhost:3000` to use Onyx.
|
||||
@@ -265,7 +261,7 @@ After Docker pulls and starts these containers, navigate to `http://localhost:30
|
||||
If you want to make changes to Onyx and run those changes in Docker, you can also build a local version of the Onyx container images that incorporates your changes like so:
|
||||
|
||||
```bash
|
||||
docker compose -f docker-compose.dev.yml -p onyx-stack up -d --build
|
||||
docker compose up -d --build
|
||||
```
|
||||
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ This guide explains how to set up and use VSCode's debugging capabilities with t
|
||||
## Initial Setup
|
||||
|
||||
1. **Environment Setup**:
|
||||
- Copy `.vscode/.env.template` to `.vscode/.env`
|
||||
- Copy `.vscode/env_template.txt` to `.vscode/.env`
|
||||
- Fill in the necessary environment variables in `.vscode/.env`
|
||||
2. **launch.json**:
|
||||
- Copy `.vscode/launch.template.jsonc` to `.vscode/launch.json`
|
||||
@@ -17,10 +17,9 @@ Before starting, make sure the Docker Daemon is running.
|
||||
1. Open the Debug view in VSCode (Cmd+Shift+D on macOS)
|
||||
2. From the dropdown at the top, select "Clear and Restart External Volumes and Containers" and press the green play button
|
||||
3. From the dropdown at the top, select "Run All Onyx Services" and press the green play button
|
||||
4. CD into web, run "npm i" followed by npm run dev.
|
||||
5. Now, you can navigate to onyx in your browser (default is http://localhost:3000) and start using the app
|
||||
6. You can set breakpoints by clicking to the left of line numbers to help debug while the app is running
|
||||
7. Use the debug toolbar to step through code, inspect variables, etc.
|
||||
4. Now, you can navigate to onyx in your browser (default is http://localhost:3000) and start using the app
|
||||
5. You can set breakpoints by clicking to the left of line numbers to help debug while the app is running
|
||||
6. Use the debug toolbar to step through code, inspect variables, etc.
|
||||
|
||||
## Features
|
||||
|
||||
|
||||
134
README.md
134
README.md
@@ -1,117 +1,103 @@
|
||||
<!-- ONYX_METADATA={"link": "https://github.com/onyx-dot-app/onyx/blob/main/README.md"} -->
|
||||
|
||||
<a name="readme-top"></a>
|
||||
|
||||
<h2 align="center">
|
||||
<a href="https://www.onyx.app/"> <img width="50%" src="https://github.com/onyx-dot-app/onyx/blob/logo/OnyxLogoCropped.jpg?raw=true)" /></a>
|
||||
<a href="https://www.onyx.app/"> <img width="50%" src="https://github.com/onyx-dot-app/onyx/blob/logo/OnyxLogoCropped.jpg?raw=true)" /></a>
|
||||
</h2>
|
||||
|
||||
<p align="center">
|
||||
<p align="center">Open Source Gen-AI + Enterprise Search.</p>
|
||||
<p align="center">Open Source AI Platform</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://docs.onyx.app/" target="_blank">
|
||||
<img src="https://img.shields.io/badge/docs-view-blue" alt="Documentation">
|
||||
</a>
|
||||
<a href="https://join.slack.com/t/onyx-dot-app/shared_invite/zt-34lu4m7xg-TsKGO6h8PDvR5W27zTdyhA" target="_blank">
|
||||
<img src="https://img.shields.io/badge/slack-join-blue.svg?logo=slack" alt="Slack">
|
||||
</a>
|
||||
<a href="https://discord.gg/TDJ59cGV2X" target="_blank">
|
||||
<img src="https://img.shields.io/badge/discord-join-blue.svg?logo=discord&logoColor=white" alt="Discord">
|
||||
</a>
|
||||
<a href="https://github.com/onyx-dot-app/onyx/blob/main/README.md" target="_blank">
|
||||
<img src="https://img.shields.io/static/v1?label=license&message=MIT&color=blue" alt="License">
|
||||
</a>
|
||||
<a href="https://discord.gg/TDJ59cGV2X" target="_blank">
|
||||
<img src="https://img.shields.io/badge/discord-join-blue.svg?logo=discord&logoColor=white" alt="Discord">
|
||||
</a>
|
||||
<a href="https://docs.onyx.app/" target="_blank">
|
||||
<img src="https://img.shields.io/badge/docs-view-blue" alt="Documentation">
|
||||
</a>
|
||||
<a href="https://docs.onyx.app/" target="_blank">
|
||||
<img src="https://img.shields.io/website?url=https://www.onyx.app&up_message=visit&up_color=blue" alt="Documentation">
|
||||
</a>
|
||||
<a href="https://github.com/onyx-dot-app/onyx/blob/main/LICENSE" target="_blank">
|
||||
<img src="https://img.shields.io/static/v1?label=license&message=MIT&color=blue" alt="License">
|
||||
</a>
|
||||
</p>
|
||||
|
||||
<strong>[Onyx](https://www.onyx.app/)</strong> (formerly Danswer) is the AI platform connected to your company's docs, apps, and people.
|
||||
Onyx provides a feature rich Chat interface and plugs into any LLM of your choice.
|
||||
Keep knowledge and access controls sync-ed across over 40 connectors like Google Drive, Slack, Confluence, Salesforce, etc.
|
||||
Create custom AI agents with unique prompts, knowledge, and actions that the agents can take.
|
||||
Onyx can be deployed securely anywhere and for any scale - on a laptop, on-premise, or to cloud.
|
||||
|
||||
|
||||
<h3>Feature Highlights</h3>
|
||||
**[Onyx](https://www.onyx.app/)** is a feature-rich, self-hostable Chat UI that works with any LLM. It is easy to deploy and can run in a completely airgapped environment.
|
||||
|
||||
**Deep research over your team's knowledge:**
|
||||
Onyx comes loaded with advanced features like Agents, Web Search, RAG, MCP, Deep Research, Connectors to 40+ knowledge sources, and more.
|
||||
|
||||
https://private-user-images.githubusercontent.com/32520769/414509312-48392e83-95d0-4fb5-8650-a396e05e0a32.mp4?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3Mzk5Mjg2MzYsIm5iZiI6MTczOTkyODMzNiwicGF0aCI6Ii8zMjUyMDc2OS80MTQ1MDkzMTItNDgzOTJlODMtOTVkMC00ZmI1LTg2NTAtYTM5NmUwNWUwYTMyLm1wND9YLUFtei1BbGdvcml0aG09QVdTNC1ITUFDLVNIQTI1NiZYLUFtei1DcmVkZW50aWFsPUFLSUFWQ09EWUxTQTUzUFFLNFpBJTJGMjAyNTAyMTklMkZ1cy1lYXN0LTElMkZzMyUyRmF3czRfcmVxdWVzdCZYLUFtei1EYXRlPTIwMjUwMjE5VDAxMjUzNlomWC1BbXotRXhwaXJlcz0zMDAmWC1BbXotU2lnbmF0dXJlPWFhMzk5Njg2Y2Y5YjFmNDNiYTQ2YzM5ZTg5YWJiYTU2NWMyY2YwNmUyODE2NWUxMDRiMWQxZWJmODI4YTA0MTUmWC1BbXotU2lnbmVkSGVhZGVycz1ob3N0In0.a9D8A0sgKE9AoaoE-mfFbJ6_OKYeqaf7TZ4Han2JfW8
|
||||
> [!TIP]
|
||||
> Run Onyx with one command (or see deployment section below):
|
||||
> ```
|
||||
> curl -fsSL https://raw.githubusercontent.com/onyx-dot-app/onyx/main/deployment/docker_compose/install.sh > install.sh && chmod +x install.sh && ./install.sh
|
||||
> ```
|
||||
|
||||
|
||||
**Use Onyx as a secure AI Chat with any LLM:**
|
||||
****
|
||||
|
||||

|
||||
|
||||
|
||||
**Easily set up connectors to your apps:**
|
||||
|
||||

|
||||
## ⭐ Features
|
||||
- **🤖 Custom Agents:** Build AI Agents with unique instructions, knowledge and actions.
|
||||
- **🌍 Web Search:** Browse the web with Google PSE, Exa, and Serper as well as an in-house scraper or Firecrawl.
|
||||
- **🔍 RAG:** Best in class hybrid-search + knowledge graph for uploaded files and ingested documents from connectors.
|
||||
- **🔄 Connectors:** Pull knowledge, metadata, and access information from over 40 applications.
|
||||
- **🔬 Deep Research:** Get in depth answers with an agentic multi-step search.
|
||||
- **▶️ Actions & MCP:** Give AI Agents the ability to interact with external systems.
|
||||
- **💻 Code Interpreter:** Execute code to analyze data, render graphs and create files.
|
||||
- **🎨 Image Generation:** Generate images based on user prompts.
|
||||
- **👥 Collaboration:** Chat sharing, feedback gathering, user management, usage analytics, and more.
|
||||
|
||||
Onyx works with all LLMs (like OpenAI, Anthropic, Gemini, etc.) and self-hosted LLMs (like Ollama, vLLM, etc.)
|
||||
|
||||
To learn more about the features, check out our [documentation](https://docs.onyx.app/welcome)!
|
||||
|
||||
|
||||
**Access Onyx where your team already works:**
|
||||
|
||||

|
||||
## 🚀 Deployment
|
||||
Onyx supports deployments in Docker, Kubernetes, Terraform, along with guides for major cloud providers.
|
||||
|
||||
See guides below:
|
||||
- [Docker](https://docs.onyx.app/deployment/local/docker) or [Quickstart](https://docs.onyx.app/deployment/getting_started/quickstart) (best for most users)
|
||||
- [Kubernetes](https://docs.onyx.app/deployment/local/kubernetes) (best for large teams)
|
||||
- [Terraform](https://docs.onyx.app/deployment/local/terraform) (best for teams already using Terraform)
|
||||
- Cloud specific guides (best if specifically using [AWS EKS](https://docs.onyx.app/deployment/cloud/aws/eks), [Azure VMs](https://docs.onyx.app/deployment/cloud/azure), etc.)
|
||||
|
||||
> [!TIP]
|
||||
> **To try Onyx for free without deploying, check out [Onyx Cloud](https://cloud.onyx.app/signup)**.
|
||||
|
||||
|
||||
## Deployment
|
||||
**To try it out for free and get started in seconds, check out [Onyx Cloud](https://cloud.onyx.app/signup)**.
|
||||
|
||||
Onyx can also be run locally (even on a laptop) or deployed on a virtual machine with a single
|
||||
`docker compose` command. Checkout our [docs](https://docs.onyx.app/quickstart) to learn more.
|
||||
## 🔍 Other Notable Benefits
|
||||
Onyx is built for teams of all sizes, from individual users to the largest global enterprises.
|
||||
|
||||
We also have built-in support for high-availability/scalable deployment on Kubernetes.
|
||||
References [here](https://github.com/onyx-dot-app/onyx/tree/main/deployment).
|
||||
- **Enterprise Search**: far more than simple RAG, Onyx has custom indexing and retrieval that remains performant and accurate for scales of up to tens of millions of documents.
|
||||
- **Security**: SSO (OIDC/SAML/OAuth2), RBAC, encryption of credentials, etc.
|
||||
- **Management UI**: different user roles such as basic, curator, and admin.
|
||||
- **Document Permissioning**: mirrors user access from external apps for RAG use cases.
|
||||
|
||||
|
||||
## 🔍 Other Notable Benefits of Onyx
|
||||
- Custom deep learning models for indexing and inference time, only through Onyx + learning from user feedback.
|
||||
- Flexible security features like SSO (OIDC/SAML/OAuth2), RBAC, encryption of credentials, etc.
|
||||
- Knowledge curation features like document-sets, query history, usage analytics, etc.
|
||||
- Scalable deployment options tested up to many tens of thousands users and hundreds of millions of documents.
|
||||
|
||||
|
||||
## 🚧 Roadmap
|
||||
- New methods in information retrieval (StructRAG, LightGraphRAG, etc.)
|
||||
- Personalized Search
|
||||
- Organizational understanding and ability to locate and suggest experts from your team.
|
||||
- Code Search
|
||||
- SQL and Structured Query Language
|
||||
To see ongoing and upcoming projects, check out our [roadmap](https://github.com/orgs/onyx-dot-app/projects/2)!
|
||||
|
||||
|
||||
## 🔌 Connectors
|
||||
Keep knowledge and access up to sync across 40+ connectors:
|
||||
|
||||
- Google Drive
|
||||
- Confluence
|
||||
- Slack
|
||||
- Gmail
|
||||
- Salesforce
|
||||
- Microsoft Sharepoint
|
||||
- Github
|
||||
- Jira
|
||||
- Zendesk
|
||||
- Gong
|
||||
- Microsoft Teams
|
||||
- Dropbox
|
||||
- Local Files
|
||||
- Websites
|
||||
- And more ...
|
||||
|
||||
See the full list [here](https://docs.onyx.app/connectors).
|
||||
|
||||
|
||||
## 📚 Licensing
|
||||
There are two editions of Onyx:
|
||||
|
||||
- Onyx Community Edition (CE) is available freely under the MIT Expat license. Simply follow the Deployment guide above.
|
||||
- Onyx Community Edition (CE) is available freely under the MIT license.
|
||||
- Onyx Enterprise Edition (EE) includes extra features that are primarily useful for larger organizations.
|
||||
For feature details, check out [our website](https://www.onyx.app/pricing).
|
||||
|
||||
To try the Onyx Enterprise Edition:
|
||||
1. Checkout [Onyx Cloud](https://cloud.onyx.app/signup).
|
||||
2. For self-hosting the Enterprise Edition, contact us at [founders@onyx.app](mailto:founders@onyx.app) or book a call with us on our [Cal](https://cal.com/team/onyx/founders).
|
||||
|
||||
|
||||
## 👪 Community
|
||||
Join our open source community on **[Discord](https://discord.gg/TDJ59cGV2X)**!
|
||||
|
||||
|
||||
|
||||
## 💡 Contributing
|
||||
Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md) for more details.
|
||||
|
||||
|
||||
@@ -12,7 +12,8 @@ ARG ONYX_VERSION=0.0.0-dev
|
||||
# DO_NOT_TRACK is used to disable telemetry for Unstructured
|
||||
ENV ONYX_VERSION=${ONYX_VERSION} \
|
||||
DANSWER_RUNNING_IN_DOCKER="true" \
|
||||
DO_NOT_TRACK="true"
|
||||
DO_NOT_TRACK="true" \
|
||||
PLAYWRIGHT_BROWSERS_PATH="/app/.cache/ms-playwright"
|
||||
|
||||
|
||||
RUN echo "ONYX_VERSION: ${ONYX_VERSION}"
|
||||
@@ -116,6 +117,14 @@ COPY ./assets /app/assets
|
||||
|
||||
ENV PYTHONPATH=/app
|
||||
|
||||
# Create non-root user for security best practices
|
||||
RUN groupadd -g 1001 onyx && \
|
||||
useradd -u 1001 -g onyx -m -s /bin/bash onyx && \
|
||||
chown -R onyx:onyx /app && \
|
||||
mkdir -p /var/log/onyx && \
|
||||
chmod 755 /var/log/onyx && \
|
||||
chown onyx:onyx /var/log/onyx
|
||||
|
||||
# Default command which does nothing
|
||||
# This container is used by api server and background which specify their own CMD
|
||||
CMD ["tail", "-f", "/dev/null"]
|
||||
|
||||
@@ -9,11 +9,36 @@ visit https://github.com/onyx-dot-app/onyx."
|
||||
# Default ONYX_VERSION, typically overriden during builds by GitHub Actions.
|
||||
ARG ONYX_VERSION=0.0.0-dev
|
||||
ENV ONYX_VERSION=${ONYX_VERSION} \
|
||||
DANSWER_RUNNING_IN_DOCKER="true"
|
||||
|
||||
DANSWER_RUNNING_IN_DOCKER="true" \
|
||||
HF_HOME=/app/.cache/huggingface
|
||||
|
||||
RUN echo "ONYX_VERSION: ${ONYX_VERSION}"
|
||||
|
||||
# Create non-root user for security best practices
|
||||
RUN mkdir -p /app && \
|
||||
groupadd -g 1001 onyx && \
|
||||
useradd -u 1001 -g onyx -m -s /bin/bash onyx && \
|
||||
chown -R onyx:onyx /app && \
|
||||
mkdir -p /var/log/onyx && \
|
||||
chmod 755 /var/log/onyx && \
|
||||
chown onyx:onyx /var/log/onyx
|
||||
|
||||
# --- add toolchain needed for Rust/Python builds (fastuuid) ---
|
||||
ENV RUSTUP_HOME=/usr/local/rustup \
|
||||
CARGO_HOME=/usr/local/cargo \
|
||||
PATH=/usr/local/cargo/bin:$PATH
|
||||
|
||||
RUN set -eux; \
|
||||
apt-get update && apt-get install -y --no-install-recommends \
|
||||
build-essential \
|
||||
pkg-config \
|
||||
curl \
|
||||
ca-certificates \
|
||||
&& rm -rf /var/lib/apt/lists/* \
|
||||
# Install latest stable Rust (supports Cargo.lock v4)
|
||||
&& curl -sSf https://sh.rustup.rs | sh -s -- -y --profile minimal --default-toolchain stable \
|
||||
&& rustc --version && cargo --version
|
||||
|
||||
COPY ./requirements/model_server.txt /tmp/requirements.txt
|
||||
RUN pip install --no-cache-dir --upgrade \
|
||||
--retries 5 \
|
||||
@@ -38,9 +63,11 @@ snapshot_download('mixedbread-ai/mxbai-rerank-xsmall-v1'); \
|
||||
from sentence_transformers import SentenceTransformer; \
|
||||
SentenceTransformer(model_name_or_path='nomic-ai/nomic-embed-text-v1', trust_remote_code=True);"
|
||||
|
||||
# In case the user has volumes mounted to /root/.cache/huggingface that they've downloaded while
|
||||
# running Onyx, don't overwrite it with the built in cache folder
|
||||
RUN mv /root/.cache/huggingface /root/.cache/temp_huggingface
|
||||
# In case the user has volumes mounted to /app/.cache/huggingface that they've downloaded while
|
||||
# running Onyx, move the current contents of the cache folder to a temporary location to ensure
|
||||
# it's preserved in order to combine with the user's cache contents
|
||||
RUN mv /app/.cache/huggingface /app/.cache/temp_huggingface && \
|
||||
chown -R onyx:onyx /app
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
|
||||
@@ -0,0 +1,389 @@
|
||||
"""Migration 2: User file data preparation and backfill
|
||||
|
||||
Revision ID: 0cd424f32b1d
|
||||
Revises: 9b66d3156fc6
|
||||
Create Date: 2025-09-22 09:44:42.727034
|
||||
|
||||
This migration populates the new columns added in migration 1.
|
||||
It prepares data for the UUID transition and relationship migration.
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import text
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger("alembic.runtime.migration")
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "0cd424f32b1d"
|
||||
down_revision = "9b66d3156fc6"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Populate new columns with data."""
|
||||
|
||||
bind = op.get_bind()
|
||||
inspector = sa.inspect(bind)
|
||||
|
||||
# === Step 1: Populate user_file.new_id ===
|
||||
user_file_columns = [col["name"] for col in inspector.get_columns("user_file")]
|
||||
has_new_id = "new_id" in user_file_columns
|
||||
|
||||
if has_new_id:
|
||||
logger.info("Populating user_file.new_id with UUIDs...")
|
||||
|
||||
# Count rows needing UUIDs
|
||||
null_count = bind.execute(
|
||||
text("SELECT COUNT(*) FROM user_file WHERE new_id IS NULL")
|
||||
).scalar_one()
|
||||
|
||||
if null_count > 0:
|
||||
logger.info(f"Generating UUIDs for {null_count} user_file records...")
|
||||
|
||||
# Populate in batches to avoid long locks
|
||||
batch_size = 10000
|
||||
total_updated = 0
|
||||
|
||||
while True:
|
||||
result = bind.execute(
|
||||
text(
|
||||
"""
|
||||
UPDATE user_file
|
||||
SET new_id = gen_random_uuid()
|
||||
WHERE new_id IS NULL
|
||||
AND id IN (
|
||||
SELECT id FROM user_file
|
||||
WHERE new_id IS NULL
|
||||
LIMIT :batch_size
|
||||
)
|
||||
"""
|
||||
),
|
||||
{"batch_size": batch_size},
|
||||
)
|
||||
|
||||
updated = result.rowcount
|
||||
total_updated += updated
|
||||
|
||||
if updated < batch_size:
|
||||
break
|
||||
|
||||
logger.info(f" Updated {total_updated}/{null_count} records...")
|
||||
|
||||
logger.info(f"Generated UUIDs for {total_updated} user_file records")
|
||||
|
||||
# Verify all records have UUIDs
|
||||
remaining_null = bind.execute(
|
||||
text("SELECT COUNT(*) FROM user_file WHERE new_id IS NULL")
|
||||
).scalar_one()
|
||||
|
||||
if remaining_null > 0:
|
||||
raise Exception(
|
||||
f"Failed to populate all user_file.new_id values ({remaining_null} NULL)"
|
||||
)
|
||||
|
||||
# Lock down the column
|
||||
op.alter_column("user_file", "new_id", nullable=False)
|
||||
op.alter_column("user_file", "new_id", server_default=None)
|
||||
logger.info("Locked down user_file.new_id column")
|
||||
|
||||
# === Step 2: Populate persona__user_file.user_file_id_uuid ===
|
||||
persona_user_file_columns = [
|
||||
col["name"] for col in inspector.get_columns("persona__user_file")
|
||||
]
|
||||
|
||||
if has_new_id and "user_file_id_uuid" in persona_user_file_columns:
|
||||
logger.info("Populating persona__user_file.user_file_id_uuid...")
|
||||
|
||||
# Count rows needing update
|
||||
null_count = bind.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT COUNT(*) FROM persona__user_file
|
||||
WHERE user_file_id IS NOT NULL AND user_file_id_uuid IS NULL
|
||||
"""
|
||||
)
|
||||
).scalar_one()
|
||||
|
||||
if null_count > 0:
|
||||
logger.info(f"Updating {null_count} persona__user_file records...")
|
||||
|
||||
# Update in batches
|
||||
batch_size = 10000
|
||||
total_updated = 0
|
||||
|
||||
while True:
|
||||
result = bind.execute(
|
||||
text(
|
||||
"""
|
||||
UPDATE persona__user_file p
|
||||
SET user_file_id_uuid = uf.new_id
|
||||
FROM user_file uf
|
||||
WHERE p.user_file_id = uf.id
|
||||
AND p.user_file_id_uuid IS NULL
|
||||
AND p.persona_id IN (
|
||||
SELECT persona_id
|
||||
FROM persona__user_file
|
||||
WHERE user_file_id_uuid IS NULL
|
||||
LIMIT :batch_size
|
||||
)
|
||||
"""
|
||||
),
|
||||
{"batch_size": batch_size},
|
||||
)
|
||||
|
||||
updated = result.rowcount
|
||||
total_updated += updated
|
||||
|
||||
if updated < batch_size:
|
||||
break
|
||||
|
||||
logger.info(f" Updated {total_updated}/{null_count} records...")
|
||||
|
||||
logger.info(f"Updated {total_updated} persona__user_file records")
|
||||
|
||||
# Verify all records are populated
|
||||
remaining_null = bind.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT COUNT(*) FROM persona__user_file
|
||||
WHERE user_file_id IS NOT NULL AND user_file_id_uuid IS NULL
|
||||
"""
|
||||
)
|
||||
).scalar_one()
|
||||
|
||||
if remaining_null > 0:
|
||||
raise Exception(
|
||||
f"Failed to populate all persona__user_file.user_file_id_uuid values ({remaining_null} NULL)"
|
||||
)
|
||||
|
||||
op.alter_column("persona__user_file", "user_file_id_uuid", nullable=False)
|
||||
logger.info("Locked down persona__user_file.user_file_id_uuid column")
|
||||
|
||||
# === Step 3: Create user_project records from chat_folder ===
|
||||
if "chat_folder" in inspector.get_table_names():
|
||||
logger.info("Creating user_project records from chat_folder...")
|
||||
|
||||
result = bind.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO user_project (user_id, name)
|
||||
SELECT cf.user_id, cf.name
|
||||
FROM chat_folder cf
|
||||
WHERE NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM user_project up
|
||||
WHERE up.user_id = cf.user_id AND up.name = cf.name
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(f"Created {result.rowcount} user_project records from chat_folder")
|
||||
|
||||
# === Step 4: Populate chat_session.project_id ===
|
||||
chat_session_columns = [
|
||||
col["name"] for col in inspector.get_columns("chat_session")
|
||||
]
|
||||
|
||||
if "folder_id" in chat_session_columns and "project_id" in chat_session_columns:
|
||||
logger.info("Populating chat_session.project_id...")
|
||||
|
||||
# Count sessions needing update
|
||||
null_count = bind.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT COUNT(*) FROM chat_session
|
||||
WHERE project_id IS NULL AND folder_id IS NOT NULL
|
||||
"""
|
||||
)
|
||||
).scalar_one()
|
||||
|
||||
if null_count > 0:
|
||||
logger.info(f"Updating {null_count} chat_session records...")
|
||||
|
||||
result = bind.execute(
|
||||
text(
|
||||
"""
|
||||
UPDATE chat_session cs
|
||||
SET project_id = up.id
|
||||
FROM chat_folder cf
|
||||
JOIN user_project up ON up.user_id = cf.user_id AND up.name = cf.name
|
||||
WHERE cs.folder_id = cf.id AND cs.project_id IS NULL
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(f"Updated {result.rowcount} chat_session records")
|
||||
|
||||
# Verify all records are populated
|
||||
remaining_null = bind.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT COUNT(*) FROM chat_session
|
||||
WHERE project_id IS NULL AND folder_id IS NOT NULL
|
||||
"""
|
||||
)
|
||||
).scalar_one()
|
||||
|
||||
if remaining_null > 0:
|
||||
logger.warning(
|
||||
f"Warning: {remaining_null} chat_session records could not be mapped to projects"
|
||||
)
|
||||
|
||||
# === Step 5: Update plaintext FileRecord IDs/display names to UUID scheme ===
|
||||
# Prior to UUID migration, plaintext cache files were stored with file_id like 'plain_text_<int_id>'.
|
||||
# After migration, we use 'plaintext_<uuid>' (note the name change to 'plaintext_').
|
||||
# This step remaps existing FileRecord rows to the new naming while preserving object_key/bucket.
|
||||
logger.info("Updating plaintext FileRecord ids and display names to UUID scheme...")
|
||||
|
||||
# Count legacy plaintext records that can be mapped to UUID user_file ids
|
||||
count_query = text(
|
||||
"""
|
||||
SELECT COUNT(*)
|
||||
FROM file_record fr
|
||||
JOIN user_file uf ON fr.file_id = CONCAT('plaintext_', uf.id::text)
|
||||
WHERE LOWER(fr.file_origin::text) = 'plaintext_cache'
|
||||
"""
|
||||
)
|
||||
legacy_count = bind.execute(count_query).scalar_one()
|
||||
|
||||
if legacy_count and legacy_count > 0:
|
||||
logger.info(f"Found {legacy_count} legacy plaintext file records to update")
|
||||
|
||||
# Update display_name first for readability (safe regardless of rename)
|
||||
bind.execute(
|
||||
text(
|
||||
"""
|
||||
UPDATE file_record fr
|
||||
SET display_name = CONCAT('Plaintext for user file ', uf.new_id::text)
|
||||
FROM user_file uf
|
||||
WHERE LOWER(fr.file_origin::text) = 'plaintext_cache'
|
||||
AND fr.file_id = CONCAT('plaintext_', uf.id::text)
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# Remap file_id from 'plaintext_<int>' -> 'plaintext_<uuid>' using transitional new_id
|
||||
# Use a single UPDATE ... WHERE file_id LIKE 'plain_text_%'
|
||||
# and ensure it aligns to existing user_file ids to avoid renaming unrelated rows
|
||||
result = bind.execute(
|
||||
text(
|
||||
"""
|
||||
UPDATE file_record fr
|
||||
SET file_id = CONCAT('plaintext_', uf.new_id::text)
|
||||
FROM user_file uf
|
||||
WHERE LOWER(fr.file_origin::text) = 'plaintext_cache'
|
||||
AND fr.file_id = CONCAT('plaintext_', uf.id::text)
|
||||
"""
|
||||
)
|
||||
)
|
||||
logger.info(
|
||||
f"Updated {result.rowcount} plaintext file_record ids to UUID scheme"
|
||||
)
|
||||
|
||||
# === Step 6: Ensure document_id_migrated default TRUE and backfill existing FALSE ===
|
||||
# New records should default to migrated=True so the migration task won't run for them.
|
||||
# Existing rows that had a legacy document_id should be marked as not migrated to be processed.
|
||||
|
||||
# Backfill existing records: if document_id is not null, set to FALSE
|
||||
bind.execute(
|
||||
text(
|
||||
"""
|
||||
UPDATE user_file
|
||||
SET document_id_migrated = FALSE
|
||||
WHERE document_id IS NOT NULL
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# === Step 7: Backfill user_file.status from index_attempt ===
|
||||
logger.info("Backfilling user_file.status from index_attempt...")
|
||||
|
||||
# Update user_file status based on latest index attempt
|
||||
# Using CTEs instead of temp tables for asyncpg compatibility
|
||||
result = bind.execute(
|
||||
text(
|
||||
"""
|
||||
WITH latest_attempt AS (
|
||||
SELECT DISTINCT ON (ia.connector_credential_pair_id)
|
||||
ia.connector_credential_pair_id,
|
||||
ia.status
|
||||
FROM index_attempt ia
|
||||
ORDER BY ia.connector_credential_pair_id, ia.time_updated DESC
|
||||
),
|
||||
uf_to_ccp AS (
|
||||
SELECT DISTINCT uf.id AS uf_id, ccp.id AS cc_pair_id
|
||||
FROM user_file uf
|
||||
JOIN document_by_connector_credential_pair dcc
|
||||
ON dcc.id = REPLACE(uf.document_id, 'USER_FILE_CONNECTOR__', 'FILE_CONNECTOR__')
|
||||
JOIN connector_credential_pair ccp
|
||||
ON ccp.connector_id = dcc.connector_id
|
||||
AND ccp.credential_id = dcc.credential_id
|
||||
)
|
||||
UPDATE user_file uf
|
||||
SET status = CASE
|
||||
WHEN la.status IN ('NOT_STARTED', 'IN_PROGRESS') THEN 'PROCESSING'
|
||||
WHEN la.status = 'SUCCESS' THEN 'COMPLETED'
|
||||
ELSE 'FAILED'
|
||||
END
|
||||
FROM uf_to_ccp ufc
|
||||
LEFT JOIN latest_attempt la
|
||||
ON la.connector_credential_pair_id = ufc.cc_pair_id
|
||||
WHERE uf.id = ufc.uf_id
|
||||
AND uf.status = 'PROCESSING'
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(f"Updated status for {result.rowcount} user_file records")
|
||||
|
||||
logger.info("Migration 2 (data preparation) completed successfully")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Reset populated data to allow clean downgrade of schema."""
|
||||
|
||||
bind = op.get_bind()
|
||||
inspector = sa.inspect(bind)
|
||||
|
||||
logger.info("Starting downgrade of data preparation...")
|
||||
|
||||
# Reset user_file columns to allow nulls before data removal
|
||||
if "user_file" in inspector.get_table_names():
|
||||
columns = [col["name"] for col in inspector.get_columns("user_file")]
|
||||
|
||||
if "new_id" in columns:
|
||||
op.alter_column(
|
||||
"user_file",
|
||||
"new_id",
|
||||
nullable=True,
|
||||
server_default=sa.text("gen_random_uuid()"),
|
||||
)
|
||||
# Optionally clear the data
|
||||
# bind.execute(text("UPDATE user_file SET new_id = NULL"))
|
||||
logger.info("Reset user_file.new_id to nullable")
|
||||
|
||||
# Reset persona__user_file.user_file_id_uuid
|
||||
if "persona__user_file" in inspector.get_table_names():
|
||||
columns = [col["name"] for col in inspector.get_columns("persona__user_file")]
|
||||
|
||||
if "user_file_id_uuid" in columns:
|
||||
op.alter_column("persona__user_file", "user_file_id_uuid", nullable=True)
|
||||
# Optionally clear the data
|
||||
# bind.execute(text("UPDATE persona__user_file SET user_file_id_uuid = NULL"))
|
||||
logger.info("Reset persona__user_file.user_file_id_uuid to nullable")
|
||||
|
||||
# Note: We don't delete user_project records or reset chat_session.project_id
|
||||
# as these might be in use and can be handled by the schema downgrade
|
||||
|
||||
# Reset user_file.status to default
|
||||
if "user_file" in inspector.get_table_names():
|
||||
columns = [col["name"] for col in inspector.get_columns("user_file")]
|
||||
if "status" in columns:
|
||||
bind.execute(text("UPDATE user_file SET status = 'PROCESSING'"))
|
||||
logger.info("Reset user_file.status to default")
|
||||
|
||||
logger.info("Downgrade completed successfully")
|
||||
@@ -0,0 +1,261 @@
|
||||
"""Migration 3: User file relationship migration
|
||||
|
||||
Revision ID: 16c37a30adf2
|
||||
Revises: 0cd424f32b1d
|
||||
Create Date: 2025-09-22 09:47:34.175596
|
||||
|
||||
This migration converts folder-based relationships to project-based relationships.
|
||||
It migrates persona__user_folder to persona__user_file and populates project__user_file.
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import text
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger("alembic.runtime.migration")
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "16c37a30adf2"
|
||||
down_revision = "0cd424f32b1d"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Migrate folder-based relationships to project-based relationships."""
|
||||
|
||||
bind = op.get_bind()
|
||||
inspector = sa.inspect(bind)
|
||||
|
||||
# === Step 1: Migrate persona__user_folder to persona__user_file ===
|
||||
table_names = inspector.get_table_names()
|
||||
|
||||
if "persona__user_folder" in table_names and "user_file" in table_names:
|
||||
user_file_columns = [col["name"] for col in inspector.get_columns("user_file")]
|
||||
has_new_id = "new_id" in user_file_columns
|
||||
|
||||
if has_new_id and "folder_id" in user_file_columns:
|
||||
logger.info(
|
||||
"Migrating persona__user_folder relationships to persona__user_file..."
|
||||
)
|
||||
|
||||
# Count relationships to migrate (asyncpg-compatible)
|
||||
count_query = text(
|
||||
"""
|
||||
SELECT COUNT(*)
|
||||
FROM (
|
||||
SELECT DISTINCT puf.persona_id, uf.id
|
||||
FROM persona__user_folder puf
|
||||
JOIN user_file uf ON uf.folder_id = puf.user_folder_id
|
||||
WHERE NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM persona__user_file p2
|
||||
WHERE p2.persona_id = puf.persona_id
|
||||
AND p2.user_file_id = uf.id
|
||||
)
|
||||
) AS distinct_pairs
|
||||
"""
|
||||
)
|
||||
to_migrate = bind.execute(count_query).scalar_one()
|
||||
|
||||
if to_migrate > 0:
|
||||
logger.info(f"Creating {to_migrate} persona-file relationships...")
|
||||
|
||||
# Migrate in batches to avoid memory issues
|
||||
batch_size = 10000
|
||||
total_inserted = 0
|
||||
|
||||
while True:
|
||||
# Insert batch directly using subquery (asyncpg compatible)
|
||||
result = bind.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO persona__user_file (persona_id, user_file_id, user_file_id_uuid)
|
||||
SELECT DISTINCT puf.persona_id, uf.id as file_id, uf.new_id
|
||||
FROM persona__user_folder puf
|
||||
JOIN user_file uf ON uf.folder_id = puf.user_folder_id
|
||||
WHERE NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM persona__user_file p2
|
||||
WHERE p2.persona_id = puf.persona_id
|
||||
AND p2.user_file_id = uf.id
|
||||
)
|
||||
LIMIT :batch_size
|
||||
"""
|
||||
),
|
||||
{"batch_size": batch_size},
|
||||
)
|
||||
|
||||
inserted = result.rowcount
|
||||
total_inserted += inserted
|
||||
|
||||
if inserted < batch_size:
|
||||
break
|
||||
|
||||
logger.info(
|
||||
f" Migrated {total_inserted}/{to_migrate} relationships..."
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Created {total_inserted} persona__user_file relationships"
|
||||
)
|
||||
|
||||
# === Step 2: Add foreign key for chat_session.project_id ===
|
||||
chat_session_fks = inspector.get_foreign_keys("chat_session")
|
||||
fk_exists = any(
|
||||
fk["name"] == "fk_chat_session_project_id" for fk in chat_session_fks
|
||||
)
|
||||
|
||||
if not fk_exists:
|
||||
logger.info("Adding foreign key constraint for chat_session.project_id...")
|
||||
op.create_foreign_key(
|
||||
"fk_chat_session_project_id",
|
||||
"chat_session",
|
||||
"user_project",
|
||||
["project_id"],
|
||||
["id"],
|
||||
)
|
||||
logger.info("Added foreign key constraint")
|
||||
|
||||
# === Step 3: Populate project__user_file from user_file.folder_id ===
|
||||
user_file_columns = [col["name"] for col in inspector.get_columns("user_file")]
|
||||
has_new_id = "new_id" in user_file_columns
|
||||
|
||||
if has_new_id and "folder_id" in user_file_columns:
|
||||
logger.info("Populating project__user_file from folder relationships...")
|
||||
|
||||
# Count relationships to create
|
||||
count_query = text(
|
||||
"""
|
||||
SELECT COUNT(*)
|
||||
FROM user_file uf
|
||||
WHERE uf.folder_id IS NOT NULL
|
||||
AND NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM project__user_file puf
|
||||
WHERE puf.project_id = uf.folder_id
|
||||
AND puf.user_file_id = uf.new_id
|
||||
)
|
||||
"""
|
||||
)
|
||||
to_create = bind.execute(count_query).scalar_one()
|
||||
|
||||
if to_create > 0:
|
||||
logger.info(f"Creating {to_create} project-file relationships...")
|
||||
|
||||
# Insert in batches
|
||||
batch_size = 10000
|
||||
total_inserted = 0
|
||||
|
||||
while True:
|
||||
result = bind.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO project__user_file (project_id, user_file_id)
|
||||
SELECT uf.folder_id, uf.new_id
|
||||
FROM user_file uf
|
||||
WHERE uf.folder_id IS NOT NULL
|
||||
AND NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM project__user_file puf
|
||||
WHERE puf.project_id = uf.folder_id
|
||||
AND puf.user_file_id = uf.new_id
|
||||
)
|
||||
LIMIT :batch_size
|
||||
ON CONFLICT (project_id, user_file_id) DO NOTHING
|
||||
"""
|
||||
),
|
||||
{"batch_size": batch_size},
|
||||
)
|
||||
|
||||
inserted = result.rowcount
|
||||
total_inserted += inserted
|
||||
|
||||
if inserted < batch_size:
|
||||
break
|
||||
|
||||
logger.info(f" Created {total_inserted}/{to_create} relationships...")
|
||||
|
||||
logger.info(f"Created {total_inserted} project__user_file relationships")
|
||||
|
||||
# === Step 4: Create index on chat_session.project_id ===
|
||||
try:
|
||||
indexes = [ix.get("name") for ix in inspector.get_indexes("chat_session")]
|
||||
except Exception:
|
||||
indexes = []
|
||||
|
||||
if "ix_chat_session_project_id" not in indexes:
|
||||
logger.info("Creating index on chat_session.project_id...")
|
||||
op.create_index(
|
||||
"ix_chat_session_project_id", "chat_session", ["project_id"], unique=False
|
||||
)
|
||||
logger.info("Created index")
|
||||
|
||||
logger.info("Migration 3 (relationship migration) completed successfully")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Remove migrated relationships and constraints."""
|
||||
|
||||
bind = op.get_bind()
|
||||
inspector = sa.inspect(bind)
|
||||
|
||||
logger.info("Starting downgrade of relationship migration...")
|
||||
|
||||
# Drop index on chat_session.project_id
|
||||
try:
|
||||
indexes = [ix.get("name") for ix in inspector.get_indexes("chat_session")]
|
||||
if "ix_chat_session_project_id" in indexes:
|
||||
op.drop_index("ix_chat_session_project_id", "chat_session")
|
||||
logger.info("Dropped index on chat_session.project_id")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Drop foreign key constraint
|
||||
try:
|
||||
chat_session_fks = inspector.get_foreign_keys("chat_session")
|
||||
fk_exists = any(
|
||||
fk["name"] == "fk_chat_session_project_id" for fk in chat_session_fks
|
||||
)
|
||||
if fk_exists:
|
||||
op.drop_constraint(
|
||||
"fk_chat_session_project_id", "chat_session", type_="foreignkey"
|
||||
)
|
||||
logger.info("Dropped foreign key constraint on chat_session.project_id")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Clear project__user_file relationships (but keep the table for migration 1 to handle)
|
||||
if "project__user_file" in inspector.get_table_names():
|
||||
result = bind.execute(text("DELETE FROM project__user_file"))
|
||||
logger.info(f"Cleared {result.rowcount} records from project__user_file")
|
||||
|
||||
# Remove migrated persona__user_file relationships
|
||||
# Only remove those that came from folder relationships
|
||||
if all(
|
||||
table in inspector.get_table_names()
|
||||
for table in ["persona__user_file", "persona__user_folder", "user_file"]
|
||||
):
|
||||
user_file_columns = [col["name"] for col in inspector.get_columns("user_file")]
|
||||
if "folder_id" in user_file_columns:
|
||||
result = bind.execute(
|
||||
text(
|
||||
"""
|
||||
DELETE FROM persona__user_file puf
|
||||
WHERE EXISTS (
|
||||
SELECT 1
|
||||
FROM user_file uf
|
||||
JOIN persona__user_folder puf2
|
||||
ON puf2.user_folder_id = uf.folder_id
|
||||
WHERE puf.persona_id = puf2.persona_id
|
||||
AND puf.user_file_id = uf.id
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
logger.info(
|
||||
f"Removed {result.rowcount} migrated persona__user_file relationships"
|
||||
)
|
||||
|
||||
logger.info("Downgrade completed successfully")
|
||||
@@ -0,0 +1,218 @@
|
||||
"""Migration 6: User file schema cleanup
|
||||
|
||||
Revision ID: 2b75d0a8ffcb
|
||||
Revises: 3a78dba1080a
|
||||
Create Date: 2025-09-22 10:09:26.375377
|
||||
|
||||
This migration removes legacy columns and tables after data migration is complete.
|
||||
It should only be run after verifying all data has been successfully migrated.
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import text
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger("alembic.runtime.migration")
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "2b75d0a8ffcb"
|
||||
down_revision = "3a78dba1080a"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Remove legacy columns and tables."""
|
||||
|
||||
bind = op.get_bind()
|
||||
inspector = sa.inspect(bind)
|
||||
|
||||
logger.info("Starting schema cleanup...")
|
||||
|
||||
# === Step 1: Verify data migration is complete ===
|
||||
logger.info("Verifying data migration completion...")
|
||||
|
||||
# Check if any chat sessions still have folder_id references
|
||||
chat_session_columns = [
|
||||
col["name"] for col in inspector.get_columns("chat_session")
|
||||
]
|
||||
if "folder_id" in chat_session_columns:
|
||||
orphaned_count = bind.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT COUNT(*) FROM chat_session
|
||||
WHERE folder_id IS NOT NULL AND project_id IS NULL
|
||||
"""
|
||||
)
|
||||
).scalar_one()
|
||||
|
||||
if orphaned_count > 0:
|
||||
logger.warning(
|
||||
f"WARNING: {orphaned_count} chat_session records still have "
|
||||
f"folder_id without project_id. Proceeding anyway."
|
||||
)
|
||||
|
||||
# === Step 2: Drop chat_session.folder_id ===
|
||||
if "folder_id" in chat_session_columns:
|
||||
logger.info("Dropping chat_session.folder_id...")
|
||||
|
||||
# Drop foreign key constraint first
|
||||
op.execute(
|
||||
"ALTER TABLE chat_session DROP CONSTRAINT IF EXISTS chat_session_folder_fk"
|
||||
)
|
||||
|
||||
# Drop the column
|
||||
op.drop_column("chat_session", "folder_id")
|
||||
logger.info("Dropped chat_session.folder_id")
|
||||
|
||||
# === Step 3: Drop persona__user_folder table ===
|
||||
if "persona__user_folder" in inspector.get_table_names():
|
||||
logger.info("Dropping persona__user_folder table...")
|
||||
|
||||
# Check for any remaining data
|
||||
remaining = bind.execute(
|
||||
text("SELECT COUNT(*) FROM persona__user_folder")
|
||||
).scalar_one()
|
||||
|
||||
if remaining > 0:
|
||||
logger.warning(
|
||||
f"WARNING: Dropping persona__user_folder with {remaining} records"
|
||||
)
|
||||
|
||||
op.drop_table("persona__user_folder")
|
||||
logger.info("Dropped persona__user_folder table")
|
||||
|
||||
# === Step 4: Drop chat_folder table ===
|
||||
if "chat_folder" in inspector.get_table_names():
|
||||
logger.info("Dropping chat_folder table...")
|
||||
|
||||
# Check for any remaining data
|
||||
remaining = bind.execute(text("SELECT COUNT(*) FROM chat_folder")).scalar_one()
|
||||
|
||||
if remaining > 0:
|
||||
logger.warning(f"WARNING: Dropping chat_folder with {remaining} records")
|
||||
|
||||
op.drop_table("chat_folder")
|
||||
logger.info("Dropped chat_folder table")
|
||||
|
||||
# === Step 5: Drop user_file legacy columns ===
|
||||
user_file_columns = [col["name"] for col in inspector.get_columns("user_file")]
|
||||
|
||||
# Drop folder_id
|
||||
if "folder_id" in user_file_columns:
|
||||
logger.info("Dropping user_file.folder_id...")
|
||||
op.drop_column("user_file", "folder_id")
|
||||
logger.info("Dropped user_file.folder_id")
|
||||
|
||||
# Drop cc_pair_id (already handled in migration 5, but be sure)
|
||||
if "cc_pair_id" in user_file_columns:
|
||||
logger.info("Dropping user_file.cc_pair_id...")
|
||||
|
||||
# Drop any remaining foreign key constraints
|
||||
bind.execute(
|
||||
text(
|
||||
"""
|
||||
DO $$
|
||||
DECLARE r RECORD;
|
||||
BEGIN
|
||||
FOR r IN (
|
||||
SELECT conname
|
||||
FROM pg_constraint c
|
||||
JOIN pg_class t ON c.conrelid = t.oid
|
||||
WHERE c.contype = 'f'
|
||||
AND t.relname = 'user_file'
|
||||
AND EXISTS (
|
||||
SELECT 1 FROM pg_attribute a
|
||||
WHERE a.attrelid = t.oid
|
||||
AND a.attname = 'cc_pair_id'
|
||||
)
|
||||
) LOOP
|
||||
EXECUTE format('ALTER TABLE user_file DROP CONSTRAINT IF EXISTS %I', r.conname);
|
||||
END LOOP;
|
||||
END$$;
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
op.drop_column("user_file", "cc_pair_id")
|
||||
logger.info("Dropped user_file.cc_pair_id")
|
||||
|
||||
# === Step 6: Clean up any remaining constraints ===
|
||||
logger.info("Cleaning up remaining constraints...")
|
||||
|
||||
# Drop any unique constraints on removed columns
|
||||
op.execute(
|
||||
"ALTER TABLE user_file DROP CONSTRAINT IF EXISTS user_file_cc_pair_id_key"
|
||||
)
|
||||
|
||||
logger.info("Migration 6 (schema cleanup) completed successfully")
|
||||
logger.info("Legacy schema has been fully removed")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Recreate dropped columns and tables (structure only, no data)."""
|
||||
|
||||
bind = op.get_bind()
|
||||
inspector = sa.inspect(bind)
|
||||
|
||||
logger.warning("Downgrading schema cleanup - recreating structure only, no data!")
|
||||
|
||||
# Recreate user_file columns
|
||||
if "user_file" in inspector.get_table_names():
|
||||
columns = [col["name"] for col in inspector.get_columns("user_file")]
|
||||
|
||||
if "cc_pair_id" not in columns:
|
||||
op.add_column(
|
||||
"user_file", sa.Column("cc_pair_id", sa.Integer(), nullable=True)
|
||||
)
|
||||
|
||||
if "folder_id" not in columns:
|
||||
op.add_column(
|
||||
"user_file", sa.Column("folder_id", sa.Integer(), nullable=True)
|
||||
)
|
||||
|
||||
# Recreate chat_folder table
|
||||
if "chat_folder" not in inspector.get_table_names():
|
||||
op.create_table(
|
||||
"chat_folder",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("user_id", sa.UUID(), nullable=False),
|
||||
sa.Column("name", sa.String(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"], ["user.id"], name="chat_folder_user_fk"
|
||||
),
|
||||
)
|
||||
|
||||
# Recreate persona__user_folder table
|
||||
if "persona__user_folder" not in inspector.get_table_names():
|
||||
op.create_table(
|
||||
"persona__user_folder",
|
||||
sa.Column("persona_id", sa.Integer(), nullable=False),
|
||||
sa.Column("user_folder_id", sa.Integer(), nullable=False),
|
||||
sa.PrimaryKeyConstraint("persona_id", "user_folder_id"),
|
||||
sa.ForeignKeyConstraint(["persona_id"], ["persona.id"]),
|
||||
sa.ForeignKeyConstraint(["user_folder_id"], ["user_project.id"]),
|
||||
)
|
||||
|
||||
# Add folder_id back to chat_session
|
||||
if "chat_session" in inspector.get_table_names():
|
||||
columns = [col["name"] for col in inspector.get_columns("chat_session")]
|
||||
if "folder_id" not in columns:
|
||||
op.add_column(
|
||||
"chat_session", sa.Column("folder_id", sa.Integer(), nullable=True)
|
||||
)
|
||||
|
||||
# Add foreign key if chat_folder exists
|
||||
if "chat_folder" in inspector.get_table_names():
|
||||
op.create_foreign_key(
|
||||
"chat_session_folder_fk",
|
||||
"chat_session",
|
||||
"chat_folder",
|
||||
["folder_id"],
|
||||
["id"],
|
||||
)
|
||||
|
||||
logger.info("Downgrade completed - structure recreated but data is lost")
|
||||
@@ -0,0 +1,298 @@
|
||||
"""Migration 5: User file legacy data cleanup
|
||||
|
||||
Revision ID: 3a78dba1080a
|
||||
Revises: 7cc3fcc116c1
|
||||
Create Date: 2025-09-22 10:04:27.986294
|
||||
|
||||
This migration removes legacy user-file documents and connector_credential_pairs.
|
||||
It performs bulk deletions of obsolete data after the UUID migration.
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql as psql
|
||||
from sqlalchemy import text
|
||||
import logging
|
||||
from typing import List
|
||||
import uuid
|
||||
|
||||
logger = logging.getLogger("alembic.runtime.migration")
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "3a78dba1080a"
|
||||
down_revision = "7cc3fcc116c1"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def batch_delete(
|
||||
bind: sa.engine.Connection,
|
||||
table_name: str,
|
||||
id_column: str,
|
||||
ids: List[str | int | uuid.UUID],
|
||||
batch_size: int = 1000,
|
||||
id_type: str = "int",
|
||||
) -> int:
|
||||
"""Delete records in batches to avoid memory issues and timeouts."""
|
||||
total_count = len(ids)
|
||||
if total_count == 0:
|
||||
return 0
|
||||
|
||||
logger.info(
|
||||
f"Starting batch deletion of {total_count} records from {table_name}..."
|
||||
)
|
||||
|
||||
# Determine appropriate ARRAY type
|
||||
if id_type == "uuid":
|
||||
array_type = psql.ARRAY(psql.UUID(as_uuid=True))
|
||||
elif id_type == "int":
|
||||
array_type = psql.ARRAY(sa.Integer())
|
||||
else:
|
||||
array_type = psql.ARRAY(sa.String())
|
||||
|
||||
total_deleted = 0
|
||||
failed_batches = []
|
||||
|
||||
for i in range(0, total_count, batch_size):
|
||||
batch_ids = ids[i : i + batch_size]
|
||||
try:
|
||||
stmt = text(
|
||||
f"DELETE FROM {table_name} WHERE {id_column} = ANY(:ids)"
|
||||
).bindparams(sa.bindparam("ids", value=batch_ids, type_=array_type))
|
||||
result = bind.execute(stmt)
|
||||
total_deleted += result.rowcount
|
||||
|
||||
# Log progress every 10 batches or at completion
|
||||
batch_num = (i // batch_size) + 1
|
||||
if batch_num % 10 == 0 or i + batch_size >= total_count:
|
||||
logger.info(
|
||||
f" Deleted {min(i + batch_size, total_count)}/{total_count} records "
|
||||
f"({total_deleted} actual) from {table_name}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete batch {(i // batch_size) + 1}: {e}")
|
||||
failed_batches.append((i, min(i + batch_size, total_count)))
|
||||
|
||||
if failed_batches:
|
||||
logger.warning(
|
||||
f"Failed to delete {len(failed_batches)} batches from {table_name}. "
|
||||
f"Total deleted: {total_deleted}/{total_count}"
|
||||
)
|
||||
# Fail the migration to avoid silently succeeding on partial cleanup
|
||||
raise RuntimeError(
|
||||
f"Batch deletion failed for {table_name}: "
|
||||
f"{len(failed_batches)} failed batches out of "
|
||||
f"{(total_count + batch_size - 1) // batch_size}."
|
||||
)
|
||||
|
||||
return total_deleted
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Remove legacy user-file documents and connector_credential_pairs."""
|
||||
|
||||
bind = op.get_bind()
|
||||
inspector = sa.inspect(bind)
|
||||
|
||||
logger.info("Starting legacy data cleanup...")
|
||||
|
||||
# === Step 1: Identify and delete user-file documents ===
|
||||
logger.info("Identifying user-file documents to delete...")
|
||||
|
||||
# Get document IDs to delete
|
||||
doc_rows = bind.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT DISTINCT dcc.id AS document_id
|
||||
FROM document_by_connector_credential_pair dcc
|
||||
JOIN connector_credential_pair u
|
||||
ON u.connector_id = dcc.connector_id
|
||||
AND u.credential_id = dcc.credential_id
|
||||
WHERE u.is_user_file IS TRUE
|
||||
"""
|
||||
)
|
||||
).fetchall()
|
||||
|
||||
doc_ids = [r[0] for r in doc_rows]
|
||||
|
||||
if doc_ids:
|
||||
logger.info(f"Found {len(doc_ids)} user-file documents to delete")
|
||||
|
||||
# Delete dependent rows first
|
||||
tables_to_clean = [
|
||||
("document_retrieval_feedback", "document_id"),
|
||||
("document__tag", "document_id"),
|
||||
("chunk_stats", "document_id"),
|
||||
]
|
||||
|
||||
for table_name, column_name in tables_to_clean:
|
||||
if table_name in inspector.get_table_names():
|
||||
# document_id is a string in these tables
|
||||
deleted = batch_delete(
|
||||
bind, table_name, column_name, doc_ids, id_type="str"
|
||||
)
|
||||
logger.info(f"Deleted {deleted} records from {table_name}")
|
||||
|
||||
# Delete document_by_connector_credential_pair entries
|
||||
deleted = batch_delete(
|
||||
bind, "document_by_connector_credential_pair", "id", doc_ids, id_type="str"
|
||||
)
|
||||
logger.info(f"Deleted {deleted} document_by_connector_credential_pair records")
|
||||
|
||||
# Delete documents themselves
|
||||
deleted = batch_delete(bind, "document", "id", doc_ids, id_type="str")
|
||||
logger.info(f"Deleted {deleted} document records")
|
||||
else:
|
||||
logger.info("No user-file documents found to delete")
|
||||
|
||||
# === Step 2: Clean up user-file connector_credential_pairs ===
|
||||
logger.info("Cleaning up user-file connector_credential_pairs...")
|
||||
|
||||
# Get cc_pair IDs
|
||||
cc_pair_rows = bind.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT id AS cc_pair_id
|
||||
FROM connector_credential_pair
|
||||
WHERE is_user_file IS TRUE
|
||||
"""
|
||||
)
|
||||
).fetchall()
|
||||
|
||||
cc_pair_ids = [r[0] for r in cc_pair_rows]
|
||||
|
||||
if cc_pair_ids:
|
||||
logger.info(
|
||||
f"Found {len(cc_pair_ids)} user-file connector_credential_pairs to clean up"
|
||||
)
|
||||
|
||||
# Delete related records
|
||||
# Clean child tables first to satisfy foreign key constraints,
|
||||
# then the parent tables
|
||||
tables_to_clean = [
|
||||
("index_attempt_errors", "connector_credential_pair_id"),
|
||||
("index_attempt", "connector_credential_pair_id"),
|
||||
("background_error", "cc_pair_id"),
|
||||
("document_set__connector_credential_pair", "connector_credential_pair_id"),
|
||||
("user_group__connector_credential_pair", "cc_pair_id"),
|
||||
]
|
||||
|
||||
for table_name, column_name in tables_to_clean:
|
||||
if table_name in inspector.get_table_names():
|
||||
deleted = batch_delete(
|
||||
bind, table_name, column_name, cc_pair_ids, id_type="int"
|
||||
)
|
||||
logger.info(f"Deleted {deleted} records from {table_name}")
|
||||
|
||||
# === Step 3: Identify connectors and credentials to delete ===
|
||||
logger.info("Identifying orphaned connectors and credentials...")
|
||||
|
||||
# Get connectors used only by user-file cc_pairs
|
||||
connector_rows = bind.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT DISTINCT ccp.connector_id
|
||||
FROM connector_credential_pair ccp
|
||||
WHERE ccp.is_user_file IS TRUE
|
||||
AND ccp.connector_id != 0 -- Exclude system default
|
||||
AND NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM connector_credential_pair c2
|
||||
WHERE c2.connector_id = ccp.connector_id
|
||||
AND c2.is_user_file IS NOT TRUE
|
||||
)
|
||||
"""
|
||||
)
|
||||
).fetchall()
|
||||
|
||||
userfile_only_connector_ids = [r[0] for r in connector_rows]
|
||||
|
||||
# Get credentials used only by user-file cc_pairs
|
||||
credential_rows = bind.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT DISTINCT ccp.credential_id
|
||||
FROM connector_credential_pair ccp
|
||||
WHERE ccp.is_user_file IS TRUE
|
||||
AND ccp.credential_id != 0 -- Exclude public/default
|
||||
AND NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM connector_credential_pair c2
|
||||
WHERE c2.credential_id = ccp.credential_id
|
||||
AND c2.is_user_file IS NOT TRUE
|
||||
)
|
||||
"""
|
||||
)
|
||||
).fetchall()
|
||||
|
||||
userfile_only_credential_ids = [r[0] for r in credential_rows]
|
||||
|
||||
# === Step 4: Delete the cc_pairs themselves ===
|
||||
if cc_pair_ids:
|
||||
# Remove FK dependency from user_file first
|
||||
bind.execute(
|
||||
text(
|
||||
"""
|
||||
DO $$
|
||||
DECLARE r RECORD;
|
||||
BEGIN
|
||||
FOR r IN (
|
||||
SELECT conname
|
||||
FROM pg_constraint c
|
||||
JOIN pg_class t ON c.conrelid = t.oid
|
||||
JOIN pg_class ft ON c.confrelid = ft.oid
|
||||
WHERE c.contype = 'f'
|
||||
AND t.relname = 'user_file'
|
||||
AND ft.relname = 'connector_credential_pair'
|
||||
) LOOP
|
||||
EXECUTE format('ALTER TABLE user_file DROP CONSTRAINT IF EXISTS %I', r.conname);
|
||||
END LOOP;
|
||||
END$$;
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# Delete cc_pairs
|
||||
deleted = batch_delete(
|
||||
bind, "connector_credential_pair", "id", cc_pair_ids, id_type="int"
|
||||
)
|
||||
logger.info(f"Deleted {deleted} connector_credential_pair records")
|
||||
|
||||
# === Step 5: Delete orphaned connectors ===
|
||||
if userfile_only_connector_ids:
|
||||
deleted = batch_delete(
|
||||
bind, "connector", "id", userfile_only_connector_ids, id_type="int"
|
||||
)
|
||||
logger.info(f"Deleted {deleted} orphaned connector records")
|
||||
|
||||
# === Step 6: Delete orphaned credentials ===
|
||||
if userfile_only_credential_ids:
|
||||
# Clean up credential__user_group mappings first
|
||||
deleted = batch_delete(
|
||||
bind,
|
||||
"credential__user_group",
|
||||
"credential_id",
|
||||
userfile_only_credential_ids,
|
||||
id_type="int",
|
||||
)
|
||||
logger.info(f"Deleted {deleted} credential__user_group records")
|
||||
|
||||
# Delete credentials
|
||||
deleted = batch_delete(
|
||||
bind, "credential", "id", userfile_only_credential_ids, id_type="int"
|
||||
)
|
||||
logger.info(f"Deleted {deleted} orphaned credential records")
|
||||
|
||||
logger.info("Migration 5 (legacy data cleanup) completed successfully")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Cannot restore deleted data - requires backup restoration."""
|
||||
|
||||
logger.error("CRITICAL: Downgrading data cleanup cannot restore deleted data!")
|
||||
logger.error("Data restoration requires backup files or database backup.")
|
||||
|
||||
raise NotImplementedError(
|
||||
"Downgrade of legacy data cleanup is not supported. "
|
||||
"Deleted data must be restored from backups."
|
||||
)
|
||||
@@ -0,0 +1,380 @@
|
||||
"""merge_default_assistants_into_unified
|
||||
|
||||
Revision ID: 505c488f6662
|
||||
Revises: d09fc20a3c66
|
||||
Create Date: 2025-09-09 19:00:56.816626
|
||||
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
from typing import NamedTuple
|
||||
from uuid import UUID
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "505c488f6662"
|
||||
down_revision = "d09fc20a3c66"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
# Constants for the unified assistant
|
||||
UNIFIED_ASSISTANT_NAME = "Assistant"
|
||||
UNIFIED_ASSISTANT_DESCRIPTION = (
|
||||
"Your AI assistant with search, web browsing, and image generation capabilities."
|
||||
)
|
||||
UNIFIED_ASSISTANT_NUM_CHUNKS = 25
|
||||
UNIFIED_ASSISTANT_DISPLAY_PRIORITY = 0
|
||||
UNIFIED_ASSISTANT_LLM_FILTER_EXTRACTION = True
|
||||
UNIFIED_ASSISTANT_LLM_RELEVANCE_FILTER = False
|
||||
UNIFIED_ASSISTANT_RECENCY_BIAS = "AUTO" # NOTE: needs to be capitalized
|
||||
UNIFIED_ASSISTANT_CHUNKS_ABOVE = 0
|
||||
UNIFIED_ASSISTANT_CHUNKS_BELOW = 0
|
||||
UNIFIED_ASSISTANT_DATETIME_AWARE = True
|
||||
|
||||
# NOTE: tool specific prompts are handled on the fly and automatically injected
|
||||
# into the prompt before passing to the LLM.
|
||||
DEFAULT_SYSTEM_PROMPT = """
|
||||
You are a highly capable, thoughtful, and precise assistant. Your goal is to deeply understand the \
|
||||
user's intent, ask clarifying questions when needed, think step-by-step through complex problems, \
|
||||
provide clear and accurate answers, and proactively anticipate helpful follow-up information. Always \
|
||||
prioritize being truthful, nuanced, insightful, and efficient.
|
||||
The current date is [[CURRENT_DATETIME]]
|
||||
|
||||
You use different text styles, bolding, emojis (sparingly), block quotes, and other formatting to make \
|
||||
your responses more readable and engaging.
|
||||
You use proper Markdown and LaTeX to format your responses for math, scientific, and chemical formulas, \
|
||||
symbols, etc.: '$$\\n[expression]\\n$$' for standalone cases and '\\( [expression] \\)' when inline.
|
||||
For code you prefer to use Markdown and specify the language.
|
||||
You can use Markdown horizontal rules (---) to separate sections of your responses.
|
||||
You can use Markdown tables to format your responses for data, lists, and other structured information.
|
||||
""".strip()
|
||||
|
||||
|
||||
INSERT_DICT: dict[str, Any] = {
|
||||
"name": UNIFIED_ASSISTANT_NAME,
|
||||
"description": UNIFIED_ASSISTANT_DESCRIPTION,
|
||||
"system_prompt": DEFAULT_SYSTEM_PROMPT,
|
||||
"num_chunks": UNIFIED_ASSISTANT_NUM_CHUNKS,
|
||||
"display_priority": UNIFIED_ASSISTANT_DISPLAY_PRIORITY,
|
||||
"llm_filter_extraction": UNIFIED_ASSISTANT_LLM_FILTER_EXTRACTION,
|
||||
"llm_relevance_filter": UNIFIED_ASSISTANT_LLM_RELEVANCE_FILTER,
|
||||
"recency_bias": UNIFIED_ASSISTANT_RECENCY_BIAS,
|
||||
"chunks_above": UNIFIED_ASSISTANT_CHUNKS_ABOVE,
|
||||
"chunks_below": UNIFIED_ASSISTANT_CHUNKS_BELOW,
|
||||
"datetime_aware": UNIFIED_ASSISTANT_DATETIME_AWARE,
|
||||
}
|
||||
|
||||
GENERAL_ASSISTANT_ID = -1
|
||||
ART_ASSISTANT_ID = -3
|
||||
|
||||
|
||||
class UserRow(NamedTuple):
|
||||
"""Typed representation of user row from database query."""
|
||||
|
||||
id: UUID
|
||||
chosen_assistants: list[int] | None
|
||||
visible_assistants: list[int] | None
|
||||
hidden_assistants: list[int] | None
|
||||
pinned_assistants: list[int] | None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# Start transaction
|
||||
conn.execute(sa.text("BEGIN"))
|
||||
|
||||
try:
|
||||
# Step 1: Create or update the unified assistant (ID 0)
|
||||
search_assistant = conn.execute(
|
||||
sa.text("SELECT * FROM persona WHERE id = 0")
|
||||
).fetchone()
|
||||
|
||||
if search_assistant:
|
||||
# Update existing Search assistant to be the unified assistant
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET name = :name,
|
||||
description = :description,
|
||||
system_prompt = :system_prompt,
|
||||
num_chunks = :num_chunks,
|
||||
is_default_persona = true,
|
||||
is_visible = true,
|
||||
deleted = false,
|
||||
display_priority = :display_priority,
|
||||
llm_filter_extraction = :llm_filter_extraction,
|
||||
llm_relevance_filter = :llm_relevance_filter,
|
||||
recency_bias = :recency_bias,
|
||||
chunks_above = :chunks_above,
|
||||
chunks_below = :chunks_below,
|
||||
datetime_aware = :datetime_aware,
|
||||
starter_messages = null
|
||||
WHERE id = 0
|
||||
"""
|
||||
),
|
||||
INSERT_DICT,
|
||||
)
|
||||
else:
|
||||
# Create new unified assistant with ID 0
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO persona (
|
||||
id, name, description, system_prompt, num_chunks,
|
||||
is_default_persona, is_visible, deleted, display_priority,
|
||||
llm_filter_extraction, llm_relevance_filter, recency_bias,
|
||||
chunks_above, chunks_below, datetime_aware, starter_messages,
|
||||
builtin_persona
|
||||
) VALUES (
|
||||
0, :name, :description, :system_prompt, :num_chunks,
|
||||
true, true, false, :display_priority, :llm_filter_extraction,
|
||||
:llm_relevance_filter, :recency_bias, :chunks_above, :chunks_below,
|
||||
:datetime_aware, null, true
|
||||
)
|
||||
"""
|
||||
),
|
||||
INSERT_DICT,
|
||||
)
|
||||
|
||||
# Step 2: Mark ALL builtin assistants as deleted (except the unified assistant ID 0)
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET deleted = true, is_visible = false, is_default_persona = false
|
||||
WHERE builtin_persona = true AND id != 0
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# Step 3: Add all built-in tools to the unified assistant
|
||||
# First, get the tool IDs for SearchTool, ImageGenerationTool, and WebSearchTool
|
||||
search_tool = conn.execute(
|
||||
sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'SearchTool'")
|
||||
).fetchone()
|
||||
|
||||
if not search_tool:
|
||||
raise ValueError(
|
||||
"SearchTool not found in database. Ensure tools migration has run first."
|
||||
)
|
||||
|
||||
image_gen_tool = conn.execute(
|
||||
sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'ImageGenerationTool'")
|
||||
).fetchone()
|
||||
|
||||
if not image_gen_tool:
|
||||
raise ValueError(
|
||||
"ImageGenerationTool not found in database. Ensure tools migration has run first."
|
||||
)
|
||||
|
||||
# WebSearchTool is optional - may not be configured
|
||||
web_search_tool = conn.execute(
|
||||
sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'WebSearchTool'")
|
||||
).fetchone()
|
||||
|
||||
# Clear existing tool associations for persona 0
|
||||
conn.execute(sa.text("DELETE FROM persona__tool WHERE persona_id = 0"))
|
||||
|
||||
# Add tools to the unified assistant
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO persona__tool (persona_id, tool_id)
|
||||
VALUES (0, :tool_id)
|
||||
ON CONFLICT DO NOTHING
|
||||
"""
|
||||
),
|
||||
{"tool_id": search_tool[0]},
|
||||
)
|
||||
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO persona__tool (persona_id, tool_id)
|
||||
VALUES (0, :tool_id)
|
||||
ON CONFLICT DO NOTHING
|
||||
"""
|
||||
),
|
||||
{"tool_id": image_gen_tool[0]},
|
||||
)
|
||||
|
||||
if web_search_tool:
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO persona__tool (persona_id, tool_id)
|
||||
VALUES (0, :tool_id)
|
||||
ON CONFLICT DO NOTHING
|
||||
"""
|
||||
),
|
||||
{"tool_id": web_search_tool[0]},
|
||||
)
|
||||
|
||||
# Step 4: Migrate existing chat sessions from all builtin assistants to unified assistant
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE chat_session
|
||||
SET persona_id = 0
|
||||
WHERE persona_id IN (
|
||||
SELECT id FROM persona WHERE builtin_persona = true AND id != 0
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# Step 5: Migrate user preferences - remove references to all builtin assistants
|
||||
# First, get all builtin assistant IDs (except 0)
|
||||
builtin_assistants_result = conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
SELECT id FROM persona
|
||||
WHERE builtin_persona = true AND id != 0
|
||||
"""
|
||||
)
|
||||
).fetchall()
|
||||
builtin_assistant_ids = [row[0] for row in builtin_assistants_result]
|
||||
|
||||
# Get all users with preferences
|
||||
users_result = conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
SELECT id, chosen_assistants, visible_assistants,
|
||||
hidden_assistants, pinned_assistants
|
||||
FROM "user"
|
||||
"""
|
||||
)
|
||||
).fetchall()
|
||||
|
||||
for user_row in users_result:
|
||||
user = UserRow(*user_row)
|
||||
user_id: UUID = user.id
|
||||
updates: dict[str, Any] = {}
|
||||
|
||||
# Remove all builtin assistants from chosen_assistants
|
||||
if user.chosen_assistants:
|
||||
new_chosen: list[int] = [
|
||||
assistant_id
|
||||
for assistant_id in user.chosen_assistants
|
||||
if assistant_id not in builtin_assistant_ids
|
||||
]
|
||||
if new_chosen != user.chosen_assistants:
|
||||
updates["chosen_assistants"] = json.dumps(new_chosen)
|
||||
|
||||
# Remove all builtin assistants from visible_assistants
|
||||
if user.visible_assistants:
|
||||
new_visible: list[int] = [
|
||||
assistant_id
|
||||
for assistant_id in user.visible_assistants
|
||||
if assistant_id not in builtin_assistant_ids
|
||||
]
|
||||
if new_visible != user.visible_assistants:
|
||||
updates["visible_assistants"] = json.dumps(new_visible)
|
||||
|
||||
# Add all builtin assistants to hidden_assistants
|
||||
if user.hidden_assistants:
|
||||
new_hidden: list[int] = list(user.hidden_assistants)
|
||||
for old_id in builtin_assistant_ids:
|
||||
if old_id not in new_hidden:
|
||||
new_hidden.append(old_id)
|
||||
if new_hidden != user.hidden_assistants:
|
||||
updates["hidden_assistants"] = json.dumps(new_hidden)
|
||||
else:
|
||||
updates["hidden_assistants"] = json.dumps(builtin_assistant_ids)
|
||||
|
||||
# Remove all builtin assistants from pinned_assistants
|
||||
if user.pinned_assistants:
|
||||
new_pinned: list[int] = [
|
||||
assistant_id
|
||||
for assistant_id in user.pinned_assistants
|
||||
if assistant_id not in builtin_assistant_ids
|
||||
]
|
||||
if new_pinned != user.pinned_assistants:
|
||||
updates["pinned_assistants"] = json.dumps(new_pinned)
|
||||
|
||||
# Apply updates if any
|
||||
if updates:
|
||||
set_clause = ", ".join([f"{k} = :{k}" for k in updates.keys()])
|
||||
updates["user_id"] = str(user_id) # Convert UUID to string for SQL
|
||||
conn.execute(
|
||||
sa.text(f'UPDATE "user" SET {set_clause} WHERE id = :user_id'),
|
||||
updates,
|
||||
)
|
||||
|
||||
# Commit transaction
|
||||
conn.execute(sa.text("COMMIT"))
|
||||
|
||||
except Exception as e:
|
||||
# Rollback on error
|
||||
conn.execute(sa.text("ROLLBACK"))
|
||||
raise e
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# Start transaction
|
||||
conn.execute(sa.text("BEGIN"))
|
||||
|
||||
try:
|
||||
# Only restore General (ID -1) and Art (ID -3) assistants
|
||||
# Step 1: Keep Search assistant (ID 0) as default but restore original state
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET is_default_persona = true,
|
||||
is_visible = true,
|
||||
deleted = false
|
||||
WHERE id = 0
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# Step 2: Restore General assistant (ID -1)
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET deleted = false,
|
||||
is_visible = true,
|
||||
is_default_persona = true
|
||||
WHERE id = :general_assistant_id
|
||||
"""
|
||||
),
|
||||
{"general_assistant_id": GENERAL_ASSISTANT_ID},
|
||||
)
|
||||
|
||||
# Step 3: Restore Art assistant (ID -3)
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET deleted = false,
|
||||
is_visible = true,
|
||||
is_default_persona = true
|
||||
WHERE id = :art_assistant_id
|
||||
"""
|
||||
),
|
||||
{"art_assistant_id": ART_ASSISTANT_ID},
|
||||
)
|
||||
|
||||
# Note: We don't restore the original tool associations, names, or descriptions
|
||||
# as those would require more complex logic to determine original state.
|
||||
# We also cannot restore original chat session persona_ids as we don't
|
||||
# have the original mappings.
|
||||
# Other builtin assistants remain deleted as per the requirement.
|
||||
|
||||
# Commit transaction
|
||||
conn.execute(sa.text("COMMIT"))
|
||||
|
||||
except Exception as e:
|
||||
# Rollback on error
|
||||
conn.execute(sa.text("ROLLBACK"))
|
||||
raise e
|
||||
@@ -0,0 +1,115 @@
|
||||
"""add research agent database tables and chat message research fields
|
||||
|
||||
Revision ID: 5ae8240accb3
|
||||
Revises: b558f51620b4
|
||||
Create Date: 2025-08-06 14:29:24.691388
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "5ae8240accb3"
|
||||
down_revision = "b558f51620b4"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add research_type and research_plan columns to chat_message table
|
||||
op.add_column(
|
||||
"chat_message",
|
||||
sa.Column("research_type", sa.String(), nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
"chat_message",
|
||||
sa.Column("research_plan", postgresql.JSONB(), nullable=True),
|
||||
)
|
||||
|
||||
# Create research_agent_iteration table
|
||||
op.create_table(
|
||||
"research_agent_iteration",
|
||||
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
|
||||
sa.Column(
|
||||
"primary_question_id",
|
||||
sa.Integer(),
|
||||
sa.ForeignKey("chat_message.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("iteration_nr", sa.Integer(), nullable=False),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("purpose", sa.String(), nullable=True),
|
||||
sa.Column("reasoning", sa.String(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint(
|
||||
"primary_question_id",
|
||||
"iteration_nr",
|
||||
name="_research_agent_iteration_unique_constraint",
|
||||
),
|
||||
)
|
||||
|
||||
# Create research_agent_iteration_sub_step table
|
||||
op.create_table(
|
||||
"research_agent_iteration_sub_step",
|
||||
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
|
||||
sa.Column(
|
||||
"primary_question_id",
|
||||
sa.Integer(),
|
||||
sa.ForeignKey("chat_message.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"parent_question_id",
|
||||
sa.Integer(),
|
||||
sa.ForeignKey("research_agent_iteration_sub_step.id", ondelete="CASCADE"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("iteration_nr", sa.Integer(), nullable=False),
|
||||
sa.Column("iteration_sub_step_nr", sa.Integer(), nullable=False),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("sub_step_instructions", sa.String(), nullable=True),
|
||||
sa.Column(
|
||||
"sub_step_tool_id",
|
||||
sa.Integer(),
|
||||
sa.ForeignKey("tool.id"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("reasoning", sa.String(), nullable=True),
|
||||
sa.Column("sub_answer", sa.String(), nullable=True),
|
||||
sa.Column("cited_doc_results", postgresql.JSONB(), nullable=True),
|
||||
sa.Column("claims", postgresql.JSONB(), nullable=True),
|
||||
sa.Column("generated_images", postgresql.JSONB(), nullable=True),
|
||||
sa.Column("additional_data", postgresql.JSONB(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(
|
||||
["primary_question_id", "iteration_nr"],
|
||||
[
|
||||
"research_agent_iteration.primary_question_id",
|
||||
"research_agent_iteration.iteration_nr",
|
||||
],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop tables in reverse order
|
||||
op.drop_table("research_agent_iteration_sub_step")
|
||||
op.drop_table("research_agent_iteration")
|
||||
|
||||
# Remove columns from chat_message table
|
||||
op.drop_column("chat_message", "research_plan")
|
||||
op.drop_column("chat_message", "research_type")
|
||||
@@ -0,0 +1,193 @@
|
||||
"""Migration 4: User file UUID primary key swap
|
||||
|
||||
Revision ID: 7cc3fcc116c1
|
||||
Revises: 16c37a30adf2
|
||||
Create Date: 2025-09-22 09:54:38.292952
|
||||
|
||||
This migration performs the critical UUID primary key swap on user_file table.
|
||||
It updates all foreign key references to use UUIDs instead of integers.
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql as psql
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger("alembic.runtime.migration")
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "7cc3fcc116c1"
|
||||
down_revision = "16c37a30adf2"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Swap user_file primary key from integer to UUID."""
|
||||
|
||||
bind = op.get_bind()
|
||||
inspector = sa.inspect(bind)
|
||||
|
||||
# Verify we're in the expected state
|
||||
user_file_columns = [col["name"] for col in inspector.get_columns("user_file")]
|
||||
if "new_id" not in user_file_columns:
|
||||
logger.warning(
|
||||
"user_file.new_id not found - migration may have already been applied"
|
||||
)
|
||||
return
|
||||
|
||||
logger.info("Starting UUID primary key swap...")
|
||||
|
||||
# === Step 1: Update persona__user_file foreign key to UUID ===
|
||||
logger.info("Updating persona__user_file foreign key...")
|
||||
|
||||
# Drop existing foreign key constraints
|
||||
op.execute(
|
||||
"ALTER TABLE persona__user_file DROP CONSTRAINT IF EXISTS persona__user_file_user_file_id_uuid_fkey"
|
||||
)
|
||||
op.execute(
|
||||
"ALTER TABLE persona__user_file DROP CONSTRAINT IF EXISTS persona__user_file_user_file_id_fkey"
|
||||
)
|
||||
|
||||
# Create new foreign key to user_file.new_id
|
||||
op.create_foreign_key(
|
||||
"persona__user_file_user_file_id_fkey",
|
||||
"persona__user_file",
|
||||
"user_file",
|
||||
local_cols=["user_file_id_uuid"],
|
||||
remote_cols=["new_id"],
|
||||
)
|
||||
|
||||
# Drop the old integer column and rename UUID column
|
||||
op.execute("ALTER TABLE persona__user_file DROP COLUMN IF EXISTS user_file_id")
|
||||
op.alter_column(
|
||||
"persona__user_file",
|
||||
"user_file_id_uuid",
|
||||
new_column_name="user_file_id",
|
||||
existing_type=psql.UUID(as_uuid=True),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Recreate composite primary key
|
||||
op.execute(
|
||||
"ALTER TABLE persona__user_file DROP CONSTRAINT IF EXISTS persona__user_file_pkey"
|
||||
)
|
||||
op.execute(
|
||||
"ALTER TABLE persona__user_file ADD PRIMARY KEY (persona_id, user_file_id)"
|
||||
)
|
||||
|
||||
logger.info("Updated persona__user_file to use UUID foreign key")
|
||||
|
||||
# === Step 2: Perform the primary key swap on user_file ===
|
||||
logger.info("Swapping user_file primary key to UUID...")
|
||||
|
||||
# Drop the primary key constraint
|
||||
op.execute("ALTER TABLE user_file DROP CONSTRAINT IF EXISTS user_file_pkey")
|
||||
|
||||
# Drop the old id column and rename new_id to id
|
||||
op.execute("ALTER TABLE user_file DROP COLUMN IF EXISTS id")
|
||||
op.alter_column(
|
||||
"user_file",
|
||||
"new_id",
|
||||
new_column_name="id",
|
||||
existing_type=psql.UUID(as_uuid=True),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Set default for new inserts
|
||||
op.alter_column(
|
||||
"user_file",
|
||||
"id",
|
||||
existing_type=psql.UUID(as_uuid=True),
|
||||
server_default=sa.text("gen_random_uuid()"),
|
||||
)
|
||||
|
||||
# Create new primary key
|
||||
op.execute("ALTER TABLE user_file ADD PRIMARY KEY (id)")
|
||||
|
||||
logger.info("Swapped user_file primary key to UUID")
|
||||
|
||||
# === Step 3: Update foreign key constraints ===
|
||||
logger.info("Updating foreign key constraints...")
|
||||
|
||||
# Recreate persona__user_file foreign key to point to user_file.id
|
||||
# Drop existing FK first to break dependency on the unique constraint
|
||||
op.execute(
|
||||
"ALTER TABLE persona__user_file DROP CONSTRAINT IF EXISTS persona__user_file_user_file_id_fkey"
|
||||
)
|
||||
# Drop the unique constraint on (formerly) new_id BEFORE recreating the FK,
|
||||
# so the FK will bind to the primary key instead of the unique index.
|
||||
op.execute("ALTER TABLE user_file DROP CONSTRAINT IF EXISTS uq_user_file_new_id")
|
||||
# Now recreate FK to the primary key column
|
||||
op.create_foreign_key(
|
||||
"persona__user_file_user_file_id_fkey",
|
||||
"persona__user_file",
|
||||
"user_file",
|
||||
local_cols=["user_file_id"],
|
||||
remote_cols=["id"],
|
||||
)
|
||||
|
||||
# Add foreign keys for project__user_file
|
||||
existing_fks = inspector.get_foreign_keys("project__user_file")
|
||||
|
||||
has_user_file_fk = any(
|
||||
fk.get("referred_table") == "user_file"
|
||||
and fk.get("constrained_columns") == ["user_file_id"]
|
||||
for fk in existing_fks
|
||||
)
|
||||
|
||||
if not has_user_file_fk:
|
||||
op.create_foreign_key(
|
||||
"fk_project__user_file_user_file_id",
|
||||
"project__user_file",
|
||||
"user_file",
|
||||
["user_file_id"],
|
||||
["id"],
|
||||
)
|
||||
logger.info("Added project__user_file -> user_file foreign key")
|
||||
|
||||
has_project_fk = any(
|
||||
fk.get("referred_table") == "user_project"
|
||||
and fk.get("constrained_columns") == ["project_id"]
|
||||
for fk in existing_fks
|
||||
)
|
||||
|
||||
if not has_project_fk:
|
||||
op.create_foreign_key(
|
||||
"fk_project__user_file_project_id",
|
||||
"project__user_file",
|
||||
"user_project",
|
||||
["project_id"],
|
||||
["id"],
|
||||
)
|
||||
logger.info("Added project__user_file -> user_project foreign key")
|
||||
|
||||
# === Step 4: Mark files for document_id migration ===
|
||||
logger.info("Marking files for background document_id migration...")
|
||||
|
||||
logger.info("Migration 4 (UUID primary key swap) completed successfully")
|
||||
logger.info(
|
||||
"NOTE: Background task will update document IDs in Vespa and search_doc"
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Revert UUID primary key back to integer (data destructive!)."""
|
||||
|
||||
logger.error("CRITICAL: Downgrading UUID primary key swap is data destructive!")
|
||||
logger.error(
|
||||
"This will break all UUID-based references created after the migration."
|
||||
)
|
||||
logger.error("Only proceed if absolutely necessary and have backups.")
|
||||
|
||||
# The downgrade would need to:
|
||||
# 1. Add back integer columns
|
||||
# 2. Generate new sequential IDs
|
||||
# 3. Update all foreign key references
|
||||
# 4. Swap primary keys back
|
||||
# This is complex and risky, so we raise an error instead
|
||||
|
||||
raise NotImplementedError(
|
||||
"Downgrade of UUID primary key swap is not supported due to data loss risk. "
|
||||
"Manual intervention with data backup/restore is required."
|
||||
)
|
||||
@@ -0,0 +1,249 @@
|
||||
"""add_mcp_server_and_connection_config_models
|
||||
|
||||
Revision ID: 7ed603b64d5a
|
||||
Revises: b329d00a9ea6
|
||||
Create Date: 2025-07-28 17:35:59.900680
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from onyx.db.enums import MCPAuthenticationType
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "7ed603b64d5a"
|
||||
down_revision = "b329d00a9ea6"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Create tables and columns for MCP Server support"""
|
||||
|
||||
# 1. MCP Server main table (no FK constraints yet to avoid circular refs)
|
||||
op.create_table(
|
||||
"mcp_server",
|
||||
sa.Column("id", sa.Integer(), primary_key=True),
|
||||
sa.Column("owner", sa.String(), nullable=False),
|
||||
sa.Column("name", sa.String(), nullable=False),
|
||||
sa.Column("description", sa.String(), nullable=True),
|
||||
sa.Column("server_url", sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
"auth_type",
|
||||
sa.Enum(
|
||||
MCPAuthenticationType,
|
||||
name="mcp_authentication_type",
|
||||
native_enum=False,
|
||||
),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("admin_connection_config_id", sa.Integer(), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"), # type: ignore
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"), # type: ignore
|
||||
nullable=False,
|
||||
),
|
||||
)
|
||||
|
||||
# 2. MCP Connection Config table (can reference mcp_server now that it exists)
|
||||
op.create_table(
|
||||
"mcp_connection_config",
|
||||
sa.Column("id", sa.Integer(), primary_key=True),
|
||||
sa.Column("mcp_server_id", sa.Integer(), nullable=True),
|
||||
sa.Column("user_email", sa.String(), nullable=False, default=""),
|
||||
sa.Column("config", sa.LargeBinary(), nullable=False),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"), # type: ignore
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"), # type: ignore
|
||||
nullable=False,
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["mcp_server_id"], ["mcp_server.id"], ondelete="CASCADE"
|
||||
),
|
||||
)
|
||||
|
||||
# Helpful indexes
|
||||
op.create_index(
|
||||
"ix_mcp_connection_config_server_user",
|
||||
"mcp_connection_config",
|
||||
["mcp_server_id", "user_email"],
|
||||
)
|
||||
op.create_index(
|
||||
"ix_mcp_connection_config_user_email",
|
||||
"mcp_connection_config",
|
||||
["user_email"],
|
||||
)
|
||||
|
||||
# 3. Add the back-references from mcp_server to connection configs
|
||||
op.create_foreign_key(
|
||||
"mcp_server_admin_config_fk",
|
||||
"mcp_server",
|
||||
"mcp_connection_config",
|
||||
["admin_connection_config_id"],
|
||||
["id"],
|
||||
ondelete="SET NULL",
|
||||
)
|
||||
|
||||
# 4. Association / access-control tables
|
||||
op.create_table(
|
||||
"mcp_server__user",
|
||||
sa.Column("mcp_server_id", sa.Integer(), primary_key=True),
|
||||
sa.Column("user_id", sa.UUID(), primary_key=True),
|
||||
sa.ForeignKeyConstraint(
|
||||
["mcp_server_id"], ["mcp_server.id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
|
||||
)
|
||||
|
||||
op.create_table(
|
||||
"mcp_server__user_group",
|
||||
sa.Column("mcp_server_id", sa.Integer(), primary_key=True),
|
||||
sa.Column("user_group_id", sa.Integer(), primary_key=True),
|
||||
sa.ForeignKeyConstraint(
|
||||
["mcp_server_id"], ["mcp_server.id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.ForeignKeyConstraint(["user_group_id"], ["user_group.id"]),
|
||||
)
|
||||
|
||||
# 5. Update existing `tool` table – allow tools to belong to an MCP server
|
||||
op.add_column(
|
||||
"tool",
|
||||
sa.Column("mcp_server_id", sa.Integer(), nullable=True),
|
||||
)
|
||||
# Add column for MCP tool input schema
|
||||
op.add_column(
|
||||
"tool",
|
||||
sa.Column("mcp_input_schema", postgresql.JSONB(), nullable=True),
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"tool_mcp_server_fk",
|
||||
"tool",
|
||||
"mcp_server",
|
||||
["mcp_server_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
# 6. Update persona__tool foreign keys to cascade delete
|
||||
# This ensures that when a tool is deleted (including via MCP server deletion),
|
||||
# the corresponding persona__tool rows are also deleted
|
||||
op.drop_constraint(
|
||||
"persona__tool_tool_id_fkey", "persona__tool", type_="foreignkey"
|
||||
)
|
||||
op.drop_constraint(
|
||||
"persona__tool_persona_id_fkey", "persona__tool", type_="foreignkey"
|
||||
)
|
||||
|
||||
op.create_foreign_key(
|
||||
"persona__tool_persona_id_fkey",
|
||||
"persona__tool",
|
||||
"persona",
|
||||
["persona_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"persona__tool_tool_id_fkey",
|
||||
"persona__tool",
|
||||
"tool",
|
||||
["tool_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
# 7. Update research_agent_iteration_sub_step foreign key to SET NULL on delete
|
||||
# This ensures that when a tool is deleted, the sub_step_tool_id is set to NULL
|
||||
# instead of causing a foreign key constraint violation
|
||||
op.drop_constraint(
|
||||
"research_agent_iteration_sub_step_sub_step_tool_id_fkey",
|
||||
"research_agent_iteration_sub_step",
|
||||
type_="foreignkey",
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"research_agent_iteration_sub_step_sub_step_tool_id_fkey",
|
||||
"research_agent_iteration_sub_step",
|
||||
"tool",
|
||||
["sub_step_tool_id"],
|
||||
["id"],
|
||||
ondelete="SET NULL",
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Drop all MCP-related tables / columns"""
|
||||
|
||||
# # # 1. Drop FK & columns from tool
|
||||
# op.drop_constraint("tool_mcp_server_fk", "tool", type_="foreignkey")
|
||||
op.execute("DELETE FROM tool WHERE mcp_server_id IS NOT NULL")
|
||||
|
||||
op.drop_constraint(
|
||||
"research_agent_iteration_sub_step_sub_step_tool_id_fkey",
|
||||
"research_agent_iteration_sub_step",
|
||||
type_="foreignkey",
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"research_agent_iteration_sub_step_sub_step_tool_id_fkey",
|
||||
"research_agent_iteration_sub_step",
|
||||
"tool",
|
||||
["sub_step_tool_id"],
|
||||
["id"],
|
||||
)
|
||||
|
||||
# Restore original persona__tool foreign keys (without CASCADE)
|
||||
op.drop_constraint(
|
||||
"persona__tool_persona_id_fkey", "persona__tool", type_="foreignkey"
|
||||
)
|
||||
op.drop_constraint(
|
||||
"persona__tool_tool_id_fkey", "persona__tool", type_="foreignkey"
|
||||
)
|
||||
|
||||
op.create_foreign_key(
|
||||
"persona__tool_persona_id_fkey",
|
||||
"persona__tool",
|
||||
"persona",
|
||||
["persona_id"],
|
||||
["id"],
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"persona__tool_tool_id_fkey",
|
||||
"persona__tool",
|
||||
"tool",
|
||||
["tool_id"],
|
||||
["id"],
|
||||
)
|
||||
op.drop_column("tool", "mcp_input_schema")
|
||||
op.drop_column("tool", "mcp_server_id")
|
||||
|
||||
# 2. Drop association tables
|
||||
op.drop_table("mcp_server__user_group")
|
||||
op.drop_table("mcp_server__user")
|
||||
|
||||
# 3. Drop FK from mcp_server to connection configs
|
||||
op.drop_constraint("mcp_server_admin_config_fk", "mcp_server", type_="foreignkey")
|
||||
|
||||
# 4. Drop connection config indexes & table
|
||||
op.drop_index(
|
||||
"ix_mcp_connection_config_user_email", table_name="mcp_connection_config"
|
||||
)
|
||||
op.drop_index(
|
||||
"ix_mcp_connection_config_server_user", table_name="mcp_connection_config"
|
||||
)
|
||||
op.drop_table("mcp_connection_config")
|
||||
|
||||
# 5. Finally drop mcp_server table
|
||||
op.drop_table("mcp_server")
|
||||
@@ -0,0 +1,38 @@
|
||||
"""drop include citations
|
||||
|
||||
Revision ID: 8818cf73fa1a
|
||||
Revises: 7ed603b64d5a
|
||||
Create Date: 2025-09-02 19:43:50.060680
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "8818cf73fa1a"
|
||||
down_revision = "7ed603b64d5a"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.drop_column("prompt", "include_citations")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.add_column(
|
||||
"prompt",
|
||||
sa.Column(
|
||||
"include_citations",
|
||||
sa.BOOLEAN(),
|
||||
autoincrement=False,
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
# Set include_citations based on prompt name: FALSE for ImageGeneration, TRUE for others
|
||||
op.execute(
|
||||
sa.text(
|
||||
"UPDATE prompt SET include_citations = CASE WHEN name = 'ImageGeneration' THEN FALSE ELSE TRUE END"
|
||||
)
|
||||
)
|
||||
@@ -0,0 +1,257 @@
|
||||
"""Migration 1: User file schema additions
|
||||
|
||||
Revision ID: 9b66d3156fc6
|
||||
Revises: b4ef3ae0bf6e
|
||||
Create Date: 2025-09-22 09:42:06.086732
|
||||
|
||||
This migration adds new columns and tables without modifying existing data.
|
||||
It is safe to run and can be easily rolled back.
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql as psql
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger("alembic.runtime.migration")
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "9b66d3156fc6"
|
||||
down_revision = "b4ef3ae0bf6e"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Add new columns and tables without modifying existing data."""
|
||||
|
||||
# Enable pgcrypto for UUID generation
|
||||
op.execute("CREATE EXTENSION IF NOT EXISTS pgcrypto")
|
||||
|
||||
bind = op.get_bind()
|
||||
inspector = sa.inspect(bind)
|
||||
|
||||
# === USER_FILE: Add new columns ===
|
||||
logger.info("Adding new columns to user_file table...")
|
||||
|
||||
user_file_columns = [col["name"] for col in inspector.get_columns("user_file")]
|
||||
|
||||
# Check if ID is already UUID (in case of re-run after partial migration)
|
||||
id_is_uuid = any(
|
||||
col["name"] == "id" and "uuid" in str(col["type"]).lower()
|
||||
for col in inspector.get_columns("user_file")
|
||||
)
|
||||
|
||||
# Add transitional UUID column only if ID is not already UUID
|
||||
if "new_id" not in user_file_columns and not id_is_uuid:
|
||||
op.add_column(
|
||||
"user_file",
|
||||
sa.Column(
|
||||
"new_id",
|
||||
psql.UUID(as_uuid=True),
|
||||
nullable=True,
|
||||
server_default=sa.text("gen_random_uuid()"),
|
||||
),
|
||||
)
|
||||
op.create_unique_constraint("uq_user_file_new_id", "user_file", ["new_id"])
|
||||
logger.info("Added new_id column to user_file")
|
||||
|
||||
# Add status column
|
||||
if "status" not in user_file_columns:
|
||||
op.add_column(
|
||||
"user_file",
|
||||
sa.Column(
|
||||
"status",
|
||||
sa.Enum(
|
||||
"PROCESSING",
|
||||
"COMPLETED",
|
||||
"FAILED",
|
||||
"CANCELED",
|
||||
name="userfilestatus",
|
||||
native_enum=False,
|
||||
),
|
||||
nullable=False,
|
||||
server_default="PROCESSING",
|
||||
),
|
||||
)
|
||||
logger.info("Added status column to user_file")
|
||||
|
||||
# Add other tracking columns
|
||||
if "chunk_count" not in user_file_columns:
|
||||
op.add_column(
|
||||
"user_file", sa.Column("chunk_count", sa.Integer(), nullable=True)
|
||||
)
|
||||
logger.info("Added chunk_count column to user_file")
|
||||
|
||||
if "last_accessed_at" not in user_file_columns:
|
||||
op.add_column(
|
||||
"user_file",
|
||||
sa.Column("last_accessed_at", sa.DateTime(timezone=True), nullable=True),
|
||||
)
|
||||
logger.info("Added last_accessed_at column to user_file")
|
||||
|
||||
if "needs_project_sync" not in user_file_columns:
|
||||
op.add_column(
|
||||
"user_file",
|
||||
sa.Column(
|
||||
"needs_project_sync",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.text("false"),
|
||||
),
|
||||
)
|
||||
logger.info("Added needs_project_sync column to user_file")
|
||||
|
||||
if "last_project_sync_at" not in user_file_columns:
|
||||
op.add_column(
|
||||
"user_file",
|
||||
sa.Column(
|
||||
"last_project_sync_at", sa.DateTime(timezone=True), nullable=True
|
||||
),
|
||||
)
|
||||
logger.info("Added last_project_sync_at column to user_file")
|
||||
|
||||
if "document_id_migrated" not in user_file_columns:
|
||||
op.add_column(
|
||||
"user_file",
|
||||
sa.Column(
|
||||
"document_id_migrated",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.text("true"),
|
||||
),
|
||||
)
|
||||
logger.info("Added document_id_migrated column to user_file")
|
||||
|
||||
# === USER_FOLDER -> USER_PROJECT rename ===
|
||||
table_names = set(inspector.get_table_names())
|
||||
|
||||
if "user_folder" in table_names:
|
||||
logger.info("Updating user_folder table...")
|
||||
# Make description nullable first
|
||||
op.alter_column("user_folder", "description", nullable=True)
|
||||
|
||||
# Rename table if user_project doesn't exist
|
||||
if "user_project" not in table_names:
|
||||
op.execute("ALTER TABLE user_folder RENAME TO user_project")
|
||||
logger.info("Renamed user_folder to user_project")
|
||||
elif "user_project" in table_names:
|
||||
# If already renamed, ensure column nullability
|
||||
project_cols = [col["name"] for col in inspector.get_columns("user_project")]
|
||||
if "description" in project_cols:
|
||||
op.alter_column("user_project", "description", nullable=True)
|
||||
|
||||
# Add instructions column to user_project
|
||||
inspector = sa.inspect(bind) # Refresh after rename
|
||||
if "user_project" in inspector.get_table_names():
|
||||
project_columns = [col["name"] for col in inspector.get_columns("user_project")]
|
||||
if "instructions" not in project_columns:
|
||||
op.add_column(
|
||||
"user_project",
|
||||
sa.Column("instructions", sa.String(), nullable=True),
|
||||
)
|
||||
logger.info("Added instructions column to user_project")
|
||||
|
||||
# === CHAT_SESSION: Add project_id ===
|
||||
chat_session_columns = [
|
||||
col["name"] for col in inspector.get_columns("chat_session")
|
||||
]
|
||||
if "project_id" not in chat_session_columns:
|
||||
op.add_column(
|
||||
"chat_session",
|
||||
sa.Column("project_id", sa.Integer(), nullable=True),
|
||||
)
|
||||
logger.info("Added project_id column to chat_session")
|
||||
|
||||
# === PERSONA__USER_FILE: Add UUID column ===
|
||||
persona_user_file_columns = [
|
||||
col["name"] for col in inspector.get_columns("persona__user_file")
|
||||
]
|
||||
if "user_file_id_uuid" not in persona_user_file_columns:
|
||||
op.add_column(
|
||||
"persona__user_file",
|
||||
sa.Column("user_file_id_uuid", psql.UUID(as_uuid=True), nullable=True),
|
||||
)
|
||||
logger.info("Added user_file_id_uuid column to persona__user_file")
|
||||
|
||||
# === PROJECT__USER_FILE: Create new table ===
|
||||
if "project__user_file" not in inspector.get_table_names():
|
||||
op.create_table(
|
||||
"project__user_file",
|
||||
sa.Column("project_id", sa.Integer(), nullable=False),
|
||||
sa.Column("user_file_id", psql.UUID(as_uuid=True), nullable=False),
|
||||
sa.PrimaryKeyConstraint("project_id", "user_file_id"),
|
||||
)
|
||||
op.create_index(
|
||||
"idx_project__user_file_user_file_id",
|
||||
"project__user_file",
|
||||
["user_file_id"],
|
||||
)
|
||||
logger.info("Created project__user_file table")
|
||||
|
||||
logger.info("Migration 1 (schema additions) completed successfully")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Remove added columns and tables."""
|
||||
|
||||
bind = op.get_bind()
|
||||
inspector = sa.inspect(bind)
|
||||
|
||||
logger.info("Starting downgrade of schema additions...")
|
||||
|
||||
# Drop project__user_file table
|
||||
if "project__user_file" in inspector.get_table_names():
|
||||
op.drop_index("idx_project__user_file_user_file_id", "project__user_file")
|
||||
op.drop_table("project__user_file")
|
||||
logger.info("Dropped project__user_file table")
|
||||
|
||||
# Remove columns from persona__user_file
|
||||
if "persona__user_file" in inspector.get_table_names():
|
||||
columns = [col["name"] for col in inspector.get_columns("persona__user_file")]
|
||||
if "user_file_id_uuid" in columns:
|
||||
op.drop_column("persona__user_file", "user_file_id_uuid")
|
||||
logger.info("Dropped user_file_id_uuid from persona__user_file")
|
||||
|
||||
# Remove columns from chat_session
|
||||
if "chat_session" in inspector.get_table_names():
|
||||
columns = [col["name"] for col in inspector.get_columns("chat_session")]
|
||||
if "project_id" in columns:
|
||||
op.drop_column("chat_session", "project_id")
|
||||
logger.info("Dropped project_id from chat_session")
|
||||
|
||||
# Rename user_project back to user_folder and remove instructions
|
||||
if "user_project" in inspector.get_table_names():
|
||||
columns = [col["name"] for col in inspector.get_columns("user_project")]
|
||||
if "instructions" in columns:
|
||||
op.drop_column("user_project", "instructions")
|
||||
op.execute("ALTER TABLE user_project RENAME TO user_folder")
|
||||
op.alter_column("user_folder", "description", nullable=False)
|
||||
logger.info("Renamed user_project back to user_folder")
|
||||
|
||||
# Remove columns from user_file
|
||||
if "user_file" in inspector.get_table_names():
|
||||
columns = [col["name"] for col in inspector.get_columns("user_file")]
|
||||
|
||||
columns_to_drop = [
|
||||
"document_id_migrated",
|
||||
"last_project_sync_at",
|
||||
"needs_project_sync",
|
||||
"last_accessed_at",
|
||||
"chunk_count",
|
||||
"status",
|
||||
]
|
||||
|
||||
for col in columns_to_drop:
|
||||
if col in columns:
|
||||
op.drop_column("user_file", col)
|
||||
logger.info(f"Dropped {col} from user_file")
|
||||
|
||||
if "new_id" in columns:
|
||||
op.drop_constraint("uq_user_file_new_id", "user_file", type_="unique")
|
||||
op.drop_column("user_file", "new_id")
|
||||
logger.info("Dropped new_id from user_file")
|
||||
|
||||
# Drop enum type if no columns use it
|
||||
bind.execute(sa.text("DROP TYPE IF EXISTS userfilestatus"))
|
||||
|
||||
logger.info("Downgrade completed successfully")
|
||||
@@ -0,0 +1,225 @@
|
||||
"""merge prompt into persona
|
||||
|
||||
Revision ID: abbfec3a5ac5
|
||||
Revises: 8818cf73fa1a
|
||||
Create Date: 2024-12-19 12:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "abbfec3a5ac5"
|
||||
down_revision = "8818cf73fa1a"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
MAX_PROMPT_LENGTH = 5_000_000
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""NOTE: Prompts without any Personas will just be lost."""
|
||||
# Step 1: Add new columns to persona table (only if they don't exist)
|
||||
|
||||
# Check if columns exist before adding them
|
||||
connection = op.get_bind()
|
||||
inspector = sa.inspect(connection)
|
||||
existing_columns = [col["name"] for col in inspector.get_columns("persona")]
|
||||
|
||||
if "system_prompt" not in existing_columns:
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column(
|
||||
"system_prompt", sa.String(length=MAX_PROMPT_LENGTH), nullable=True
|
||||
),
|
||||
)
|
||||
|
||||
if "task_prompt" not in existing_columns:
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column(
|
||||
"task_prompt", sa.String(length=MAX_PROMPT_LENGTH), nullable=True
|
||||
),
|
||||
)
|
||||
|
||||
if "datetime_aware" not in existing_columns:
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column(
|
||||
"datetime_aware", sa.Boolean(), nullable=False, server_default="true"
|
||||
),
|
||||
)
|
||||
|
||||
# Step 2: Migrate data from prompt table to persona table (only if tables exist)
|
||||
existing_tables = inspector.get_table_names()
|
||||
|
||||
if "prompt" in existing_tables and "persona__prompt" in existing_tables:
|
||||
# For personas that have associated prompts, copy the prompt data
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET
|
||||
system_prompt = p.system_prompt,
|
||||
task_prompt = p.task_prompt,
|
||||
datetime_aware = p.datetime_aware
|
||||
FROM (
|
||||
-- Get the first prompt for each persona (in case there are multiple)
|
||||
SELECT DISTINCT ON (pp.persona_id)
|
||||
pp.persona_id,
|
||||
pr.system_prompt,
|
||||
pr.task_prompt,
|
||||
pr.datetime_aware
|
||||
FROM persona__prompt pp
|
||||
JOIN prompt pr ON pp.prompt_id = pr.id
|
||||
) p
|
||||
WHERE persona.id = p.persona_id
|
||||
"""
|
||||
)
|
||||
|
||||
# Step 3: Update chat_message references
|
||||
# Since chat messages referenced prompt_id, we need to update them to use persona_id
|
||||
# This is complex as we need to map from prompt_id to persona_id
|
||||
|
||||
# Check if chat_message has prompt_id column
|
||||
chat_message_columns = [
|
||||
col["name"] for col in inspector.get_columns("chat_message")
|
||||
]
|
||||
if "prompt_id" in chat_message_columns:
|
||||
op.execute(
|
||||
"""
|
||||
ALTER TABLE chat_message
|
||||
DROP CONSTRAINT IF EXISTS chat_message__prompt_fk
|
||||
"""
|
||||
)
|
||||
op.drop_column("chat_message", "prompt_id")
|
||||
|
||||
# Step 4: Handle personas without prompts - set default values if needed (always run this)
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET
|
||||
system_prompt = COALESCE(system_prompt, ''),
|
||||
task_prompt = COALESCE(task_prompt, '')
|
||||
WHERE system_prompt IS NULL OR task_prompt IS NULL
|
||||
"""
|
||||
)
|
||||
|
||||
# Step 5: Drop the persona__prompt association table (if it exists)
|
||||
if "persona__prompt" in existing_tables:
|
||||
op.drop_table("persona__prompt")
|
||||
|
||||
# Step 6: Drop the prompt table (if it exists)
|
||||
if "prompt" in existing_tables:
|
||||
op.drop_table("prompt")
|
||||
|
||||
# Step 7: Make system_prompt and task_prompt non-nullable after migration (only if they exist)
|
||||
op.alter_column(
|
||||
"persona",
|
||||
"system_prompt",
|
||||
existing_type=sa.String(length=MAX_PROMPT_LENGTH),
|
||||
nullable=False,
|
||||
server_default=None,
|
||||
)
|
||||
|
||||
op.alter_column(
|
||||
"persona",
|
||||
"task_prompt",
|
||||
existing_type=sa.String(length=MAX_PROMPT_LENGTH),
|
||||
nullable=False,
|
||||
server_default=None,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Step 1: Recreate the prompt table
|
||||
op.create_table(
|
||||
"prompt",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=True),
|
||||
sa.Column("name", sa.String(), nullable=False),
|
||||
sa.Column("description", sa.String(), nullable=False),
|
||||
sa.Column("system_prompt", sa.String(length=MAX_PROMPT_LENGTH), nullable=False),
|
||||
sa.Column("task_prompt", sa.String(length=MAX_PROMPT_LENGTH), nullable=False),
|
||||
sa.Column(
|
||||
"datetime_aware", sa.Boolean(), nullable=False, server_default="true"
|
||||
),
|
||||
sa.Column(
|
||||
"default_prompt", sa.Boolean(), nullable=False, server_default="false"
|
||||
),
|
||||
sa.Column("deleted", sa.Boolean(), nullable=False, server_default="false"),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
# Step 2: Recreate the persona__prompt association table
|
||||
op.create_table(
|
||||
"persona__prompt",
|
||||
sa.Column("persona_id", sa.Integer(), nullable=False),
|
||||
sa.Column("prompt_id", sa.Integer(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["persona_id"],
|
||||
["persona.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["prompt_id"],
|
||||
["prompt.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("persona_id", "prompt_id"),
|
||||
)
|
||||
|
||||
# Step 3: Migrate data back from persona to prompt table
|
||||
op.execute(
|
||||
"""
|
||||
INSERT INTO prompt (
|
||||
name,
|
||||
description,
|
||||
system_prompt,
|
||||
task_prompt,
|
||||
datetime_aware,
|
||||
default_prompt,
|
||||
deleted,
|
||||
user_id
|
||||
)
|
||||
SELECT
|
||||
CONCAT('Prompt for ', name),
|
||||
description,
|
||||
system_prompt,
|
||||
task_prompt,
|
||||
datetime_aware,
|
||||
is_default_persona,
|
||||
deleted,
|
||||
user_id
|
||||
FROM persona
|
||||
WHERE system_prompt IS NOT NULL AND system_prompt != ''
|
||||
RETURNING id, name
|
||||
"""
|
||||
)
|
||||
|
||||
# Step 4: Re-establish persona__prompt relationships
|
||||
op.execute(
|
||||
"""
|
||||
INSERT INTO persona__prompt (persona_id, prompt_id)
|
||||
SELECT
|
||||
p.id as persona_id,
|
||||
pr.id as prompt_id
|
||||
FROM persona p
|
||||
JOIN prompt pr ON pr.name = CONCAT('Prompt for ', p.name)
|
||||
WHERE p.system_prompt IS NOT NULL AND p.system_prompt != ''
|
||||
"""
|
||||
)
|
||||
|
||||
# Step 5: Add prompt_id column back to chat_message
|
||||
op.add_column("chat_message", sa.Column("prompt_id", sa.Integer(), nullable=True))
|
||||
|
||||
# Step 6: Re-establish foreign key constraint
|
||||
op.create_foreign_key(
|
||||
"chat_message__prompt_fk", "chat_message", "prompt", ["prompt_id"], ["id"]
|
||||
)
|
||||
|
||||
# Step 7: Remove columns from persona table
|
||||
op.drop_column("persona", "datetime_aware")
|
||||
op.drop_column("persona", "task_prompt")
|
||||
op.drop_column("persona", "system_prompt")
|
||||
123
backend/alembic/versions/b30353be4eec_add_mcp_auth_performer.py
Normal file
123
backend/alembic/versions/b30353be4eec_add_mcp_auth_performer.py
Normal file
@@ -0,0 +1,123 @@
|
||||
"""add_mcp_auth_performer
|
||||
|
||||
Revision ID: b30353be4eec
|
||||
Revises: 2b75d0a8ffcb
|
||||
Create Date: 2025-09-13 14:58:08.413534
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from onyx.db.enums import MCPAuthenticationPerformer, MCPTransport
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "b30353be4eec"
|
||||
down_revision = "2b75d0a8ffcb"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""moving to a better way of handling auth performer and transport"""
|
||||
# Add nullable column first for backward compatibility
|
||||
op.add_column(
|
||||
"mcp_server",
|
||||
sa.Column(
|
||||
"auth_performer",
|
||||
sa.Enum(MCPAuthenticationPerformer, native_enum=False),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
|
||||
op.add_column(
|
||||
"mcp_server",
|
||||
sa.Column(
|
||||
"transport",
|
||||
sa.Enum(MCPTransport, native_enum=False),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
|
||||
# # Backfill values using existing data and inference rules
|
||||
bind = op.get_bind()
|
||||
|
||||
# 1) OAUTH servers are always PER_USER
|
||||
bind.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE mcp_server
|
||||
SET auth_performer = 'PER_USER'
|
||||
WHERE auth_type = 'OAUTH'
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# 2) If there is no admin connection config, mark as ADMIN (and not set yet)
|
||||
bind.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE mcp_server
|
||||
SET auth_performer = 'ADMIN'
|
||||
WHERE admin_connection_config_id IS NULL
|
||||
AND auth_performer IS NULL
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# 3) If there exists any user-specific connection config (user_email != ''), mark as PER_USER
|
||||
bind.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE mcp_server AS ms
|
||||
SET auth_performer = 'PER_USER'
|
||||
FROM mcp_connection_config AS mcc
|
||||
WHERE mcc.mcp_server_id = ms.id
|
||||
AND COALESCE(mcc.user_email, '') <> ''
|
||||
AND ms.auth_performer IS NULL
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# 4) Default any remaining nulls to ADMIN (covers API_TOKEN admin-managed and NONE)
|
||||
bind.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE mcp_server
|
||||
SET auth_performer = 'ADMIN'
|
||||
WHERE auth_performer IS NULL
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# Finally, make the column non-nullable
|
||||
op.alter_column(
|
||||
"mcp_server",
|
||||
"auth_performer",
|
||||
existing_type=sa.Enum(MCPAuthenticationPerformer, native_enum=False),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Backfill transport for existing rows to STREAMABLE_HTTP, then make non-nullable
|
||||
bind.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE mcp_server
|
||||
SET transport = 'STREAMABLE_HTTP'
|
||||
WHERE transport IS NULL
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
op.alter_column(
|
||||
"mcp_server",
|
||||
"transport",
|
||||
existing_type=sa.Enum(MCPTransport, native_enum=False),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""remove cols"""
|
||||
op.drop_column("mcp_server", "transport")
|
||||
op.drop_column("mcp_server", "auth_performer")
|
||||
@@ -0,0 +1,38 @@
|
||||
"""Adding assistant-specific user preferences
|
||||
|
||||
Revision ID: b329d00a9ea6
|
||||
Revises: f9b8c7d6e5a4
|
||||
Create Date: 2025-08-26 23:14:44.592985
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import fastapi_users_db_sqlalchemy
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "b329d00a9ea6"
|
||||
down_revision = "f9b8c7d6e5a4"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"assistant__user_specific_config",
|
||||
sa.Column("assistant_id", sa.Integer(), nullable=False),
|
||||
sa.Column(
|
||||
"user_id",
|
||||
fastapi_users_db_sqlalchemy.generics.GUID(),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("disabled_tool_ids", postgresql.ARRAY(sa.Integer()), nullable=False),
|
||||
sa.ForeignKeyConstraint(["assistant_id"], ["persona.id"], ondelete="CASCADE"),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("assistant_id", "user_id"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("assistant__user_specific_config")
|
||||
@@ -0,0 +1,27 @@
|
||||
"""add_user_oauth_token_to_slack_bot
|
||||
|
||||
Revision ID: b4ef3ae0bf6e
|
||||
Revises: 505c488f6662
|
||||
Create Date: 2025-08-26 17:47:41.788462
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "b4ef3ae0bf6e"
|
||||
down_revision = "505c488f6662"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add user_token column to slack_bot table
|
||||
op.add_column("slack_bot", sa.Column("user_token", sa.LargeBinary(), nullable=True))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove user_token column from slack_bot table
|
||||
op.drop_column("slack_bot", "user_token")
|
||||
@@ -0,0 +1,43 @@
|
||||
"""adjust prompt length
|
||||
|
||||
Revision ID: b7ec9b5b505f
|
||||
Revises: abbfec3a5ac5
|
||||
Create Date: 2025-09-10 18:51:15.629197
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "b7ec9b5b505f"
|
||||
down_revision = "abbfec3a5ac5"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
MAX_PROMPT_LENGTH = 5_000_000
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# NOTE: need to run this since the previous migration PREVIOUSLY set the length to 8000
|
||||
op.alter_column(
|
||||
"persona",
|
||||
"system_prompt",
|
||||
existing_type=sa.String(length=8000),
|
||||
type_=sa.String(length=MAX_PROMPT_LENGTH),
|
||||
existing_nullable=False,
|
||||
)
|
||||
op.alter_column(
|
||||
"persona",
|
||||
"task_prompt",
|
||||
existing_type=sa.String(length=8000),
|
||||
type_=sa.String(length=MAX_PROMPT_LENGTH),
|
||||
existing_nullable=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Downgrade not necessary
|
||||
pass
|
||||
@@ -0,0 +1,147 @@
|
||||
"""migrate_agent_sub_questions_to_research_iterations
|
||||
|
||||
Revision ID: bd7c3bf8beba
|
||||
Revises: f8a9b2c3d4e5
|
||||
Create Date: 2025-08-18 11:33:27.098287
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "bd7c3bf8beba"
|
||||
down_revision = "f8a9b2c3d4e5"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Get connection to execute raw SQL
|
||||
connection = op.get_bind()
|
||||
|
||||
# First, insert data into research_agent_iteration table
|
||||
# This creates one iteration record per primary_question_id using the earliest time_created
|
||||
connection.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO research_agent_iteration (primary_question_id, created_at, iteration_nr, purpose, reasoning)
|
||||
SELECT
|
||||
primary_question_id,
|
||||
MIN(time_created) as created_at,
|
||||
1 as iteration_nr,
|
||||
'Generating and researching subquestions' as purpose,
|
||||
'(No previous reasoning)' as reasoning
|
||||
FROM agent__sub_question
|
||||
JOIN chat_message on agent__sub_question.primary_question_id = chat_message.id
|
||||
WHERE primary_question_id IS NOT NULL
|
||||
AND chat_message.is_agentic = true
|
||||
GROUP BY primary_question_id
|
||||
ON CONFLICT DO NOTHING;
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# Then, insert data into research_agent_iteration_sub_step table
|
||||
# This migrates each sub-question as a sub-step
|
||||
connection.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO research_agent_iteration_sub_step (
|
||||
primary_question_id,
|
||||
iteration_nr,
|
||||
iteration_sub_step_nr,
|
||||
created_at,
|
||||
sub_step_instructions,
|
||||
sub_step_tool_id,
|
||||
sub_answer,
|
||||
cited_doc_results
|
||||
)
|
||||
SELECT
|
||||
primary_question_id,
|
||||
1 as iteration_nr,
|
||||
level_question_num as iteration_sub_step_nr,
|
||||
time_created as created_at,
|
||||
sub_question as sub_step_instructions,
|
||||
1 as sub_step_tool_id,
|
||||
sub_answer,
|
||||
sub_question_doc_results as cited_doc_results
|
||||
FROM agent__sub_question
|
||||
JOIN chat_message on agent__sub_question.primary_question_id = chat_message.id
|
||||
WHERE chat_message.is_agentic = true
|
||||
AND primary_question_id IS NOT NULL
|
||||
ON CONFLICT DO NOTHING;
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# Update chat_message records: set legacy agentic type and answer purpose for existing agentic messages
|
||||
connection.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE chat_message
|
||||
SET research_answer_purpose = 'ANSWER'
|
||||
WHERE is_agentic = true
|
||||
AND research_type IS NULL and
|
||||
message_type = 'ASSISTANT';
|
||||
"""
|
||||
)
|
||||
)
|
||||
connection.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE chat_message
|
||||
SET research_type = 'LEGACY_AGENTIC'
|
||||
WHERE is_agentic = true
|
||||
AND research_type IS NULL;
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Get connection to execute raw SQL
|
||||
connection = op.get_bind()
|
||||
|
||||
# Note: This downgrade removes all research agent iteration data
|
||||
# There's no way to perfectly restore the original agent__sub_question data
|
||||
# if it was deleted after this migration
|
||||
|
||||
# Delete all research_agent_iteration_sub_step records that were migrated
|
||||
connection.execute(
|
||||
sa.text(
|
||||
"""
|
||||
DELETE FROM research_agent_iteration_sub_step
|
||||
USING chat_message
|
||||
WHERE research_agent_iteration_sub_step.primary_question_id = chat_message.id
|
||||
AND chat_message.research_type = 'LEGACY_AGENTIC';
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# Delete all research_agent_iteration records that were migrated
|
||||
connection.execute(
|
||||
sa.text(
|
||||
"""
|
||||
DELETE FROM research_agent_iteration
|
||||
USING chat_message
|
||||
WHERE research_agent_iteration.primary_question_id = chat_message.id
|
||||
AND chat_message.research_type = 'LEGACY_AGENTIC';
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# Revert chat_message updates: clear research fields for legacy agentic messages
|
||||
connection.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE chat_message
|
||||
SET research_type = NULL,
|
||||
research_answer_purpose = NULL
|
||||
WHERE is_agentic = true
|
||||
AND research_type = 'LEGACY_AGENTIC'
|
||||
AND message_type = 'ASSISTANT';
|
||||
"""
|
||||
)
|
||||
)
|
||||
152
backend/alembic/versions/d09fc20a3c66_seed_builtin_tools.py
Normal file
152
backend/alembic/versions/d09fc20a3c66_seed_builtin_tools.py
Normal file
@@ -0,0 +1,152 @@
|
||||
"""seed_builtin_tools
|
||||
|
||||
Revision ID: d09fc20a3c66
|
||||
Revises: b7ec9b5b505f
|
||||
Create Date: 2025-09-09 19:32:16.824373
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "d09fc20a3c66"
|
||||
down_revision = "b7ec9b5b505f"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
# Tool definitions - core tools that should always be seeded
|
||||
# Names/in_code_tool_id are the same as the class names in the tool_implementations package
|
||||
BUILT_IN_TOOLS = [
|
||||
{
|
||||
"name": "SearchTool",
|
||||
"display_name": "Internal Search",
|
||||
"description": "The Search Action allows the Assistant to search through connected knowledge to help build an answer.",
|
||||
"in_code_tool_id": "SearchTool",
|
||||
},
|
||||
{
|
||||
"name": "ImageGenerationTool",
|
||||
"display_name": "Image Generation",
|
||||
"description": (
|
||||
"The Image Generation Action allows the assistant to use DALL-E 3 or GPT-IMAGE-1 to generate images. "
|
||||
"The action will be used when the user asks the assistant to generate an image."
|
||||
),
|
||||
"in_code_tool_id": "ImageGenerationTool",
|
||||
},
|
||||
{
|
||||
"name": "WebSearchTool",
|
||||
"display_name": "Web Search",
|
||||
"description": (
|
||||
"The Web Search Action allows the assistant "
|
||||
"to perform internet searches for up-to-date information."
|
||||
),
|
||||
"in_code_tool_id": "WebSearchTool",
|
||||
},
|
||||
{
|
||||
"name": "KnowledgeGraphTool",
|
||||
"display_name": "Knowledge Graph Search",
|
||||
"description": (
|
||||
"The Knowledge Graph Search Action allows the assistant to search the "
|
||||
"Knowledge Graph for information. This tool can (for now) only be active in the KG Beta Assistant, "
|
||||
"and it requires the Knowledge Graph to be enabled."
|
||||
),
|
||||
"in_code_tool_id": "KnowledgeGraphTool",
|
||||
},
|
||||
{
|
||||
"name": "OktaProfileTool",
|
||||
"display_name": "Okta Profile",
|
||||
"description": (
|
||||
"The Okta Profile Action allows the assistant to fetch the current user's information from Okta. "
|
||||
"This may include the user's name, email, phone number, address, and other details such as their "
|
||||
"manager and direct reports."
|
||||
),
|
||||
"in_code_tool_id": "OktaProfileTool",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# Start transaction
|
||||
conn.execute(sa.text("BEGIN"))
|
||||
|
||||
try:
|
||||
# Get existing tools to check what already exists
|
||||
existing_tools = conn.execute(
|
||||
sa.text(
|
||||
"SELECT in_code_tool_id FROM tool WHERE in_code_tool_id IS NOT NULL"
|
||||
)
|
||||
).fetchall()
|
||||
existing_tool_ids = {row[0] for row in existing_tools}
|
||||
|
||||
# Insert or update built-in tools
|
||||
for tool in BUILT_IN_TOOLS:
|
||||
in_code_id = tool["in_code_tool_id"]
|
||||
|
||||
# Handle historical rename: InternetSearchTool -> WebSearchTool
|
||||
if (
|
||||
in_code_id == "WebSearchTool"
|
||||
and "WebSearchTool" not in existing_tool_ids
|
||||
and "InternetSearchTool" in existing_tool_ids
|
||||
):
|
||||
# Rename the existing InternetSearchTool row in place and update fields
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE tool
|
||||
SET name = :name,
|
||||
display_name = :display_name,
|
||||
description = :description,
|
||||
in_code_tool_id = :in_code_tool_id
|
||||
WHERE in_code_tool_id = 'InternetSearchTool'
|
||||
"""
|
||||
),
|
||||
tool,
|
||||
)
|
||||
# Keep the local view of existing ids in sync to avoid duplicate insert
|
||||
existing_tool_ids.discard("InternetSearchTool")
|
||||
existing_tool_ids.add("WebSearchTool")
|
||||
continue
|
||||
|
||||
if in_code_id in existing_tool_ids:
|
||||
# Update existing tool
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE tool
|
||||
SET name = :name,
|
||||
display_name = :display_name,
|
||||
description = :description
|
||||
WHERE in_code_tool_id = :in_code_tool_id
|
||||
"""
|
||||
),
|
||||
tool,
|
||||
)
|
||||
else:
|
||||
# Insert new tool
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO tool (name, display_name, description, in_code_tool_id)
|
||||
VALUES (:name, :display_name, :description, :in_code_tool_id)
|
||||
"""
|
||||
),
|
||||
tool,
|
||||
)
|
||||
|
||||
# Commit transaction
|
||||
conn.execute(sa.text("COMMIT"))
|
||||
|
||||
except Exception as e:
|
||||
# Rollback on error
|
||||
conn.execute(sa.text("ROLLBACK"))
|
||||
raise e
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# We don't remove the tools on downgrade since it's totally fine to just
|
||||
# have them around. If we upgrade again, it will be a no-op.
|
||||
pass
|
||||
@@ -0,0 +1,30 @@
|
||||
"""add research_answer_purpose to chat_message
|
||||
|
||||
Revision ID: f8a9b2c3d4e5
|
||||
Revises: 5ae8240accb3
|
||||
Create Date: 2025-01-27 12:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "f8a9b2c3d4e5"
|
||||
down_revision = "5ae8240accb3"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add research_answer_purpose column to chat_message table
|
||||
op.add_column(
|
||||
"chat_message",
|
||||
sa.Column("research_answer_purpose", sa.String(), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove research_answer_purpose column from chat_message table
|
||||
op.drop_column("chat_message", "research_answer_purpose")
|
||||
@@ -0,0 +1,69 @@
|
||||
"""remove foreign key constraints from research_agent_iteration_sub_step
|
||||
|
||||
Revision ID: f9b8c7d6e5a4
|
||||
Revises: bd7c3bf8beba
|
||||
Create Date: 2025-01-27 12:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "f9b8c7d6e5a4"
|
||||
down_revision = "bd7c3bf8beba"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Drop the existing foreign key constraint for parent_question_id
|
||||
op.drop_constraint(
|
||||
"research_agent_iteration_sub_step_parent_question_id_fkey",
|
||||
"research_agent_iteration_sub_step",
|
||||
type_="foreignkey",
|
||||
)
|
||||
|
||||
# Drop the parent_question_id column entirely
|
||||
op.drop_column("research_agent_iteration_sub_step", "parent_question_id")
|
||||
|
||||
# Drop the foreign key constraint for primary_question_id to chat_message.id
|
||||
# (keep the column as it's needed for the composite foreign key)
|
||||
op.drop_constraint(
|
||||
"research_agent_iteration_sub_step_primary_question_id_fkey",
|
||||
"research_agent_iteration_sub_step",
|
||||
type_="foreignkey",
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Restore the foreign key constraint for primary_question_id to chat_message.id
|
||||
op.create_foreign_key(
|
||||
"research_agent_iteration_sub_step_primary_question_id_fkey",
|
||||
"research_agent_iteration_sub_step",
|
||||
"chat_message",
|
||||
["primary_question_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
# Add back the parent_question_id column
|
||||
op.add_column(
|
||||
"research_agent_iteration_sub_step",
|
||||
sa.Column(
|
||||
"parent_question_id",
|
||||
sa.Integer(),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
|
||||
# Restore the foreign key constraint pointing to research_agent_iteration_sub_step.id
|
||||
op.create_foreign_key(
|
||||
"research_agent_iteration_sub_step_parent_question_id_fkey",
|
||||
"research_agent_iteration_sub_step",
|
||||
"research_agent_iteration_sub_step",
|
||||
["parent_question_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
@@ -1,133 +1,4 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from uuid import UUID
|
||||
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
|
||||
from ee.onyx.background.celery_utils import should_perform_chat_ttl_check
|
||||
from ee.onyx.background.task_name_builders import name_chat_ttl_task
|
||||
from ee.onyx.server.reporting.usage_export_generation import create_new_usage_report
|
||||
from onyx.background.celery.apps.primary import celery_app
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.db.chat import delete_chat_session
|
||||
from onyx.db.chat import get_chat_sessions_older_than
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import TaskStatus
|
||||
from onyx.db.tasks import mark_task_as_finished_with_id
|
||||
from onyx.db.tasks import register_task
|
||||
from onyx.server.settings.store import load_settings
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# mark as EE for all tasks in this file
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.PERFORM_TTL_MANAGEMENT_TASK,
|
||||
ignore_result=True,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
trail=False,
|
||||
)
|
||||
def perform_ttl_management_task(
|
||||
self: Task, retention_limit_days: int, *, tenant_id: str
|
||||
) -> None:
|
||||
task_id = self.request.id
|
||||
if not task_id:
|
||||
raise RuntimeError("No task id defined for this task; cannot identify it")
|
||||
|
||||
start_time = datetime.now(tz=timezone.utc)
|
||||
|
||||
user_id: UUID | None = None
|
||||
session_id: UUID | None = None
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
# we generally want to move off this, but keeping for now
|
||||
register_task(
|
||||
db_session=db_session,
|
||||
task_name=name_chat_ttl_task(retention_limit_days, tenant_id),
|
||||
task_id=task_id,
|
||||
status=TaskStatus.STARTED,
|
||||
start_time=start_time,
|
||||
)
|
||||
|
||||
old_chat_sessions = get_chat_sessions_older_than(
|
||||
retention_limit_days, db_session
|
||||
)
|
||||
|
||||
for user_id, session_id in old_chat_sessions:
|
||||
# one session per delete so that we don't blow up if a deletion fails.
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
delete_chat_session(
|
||||
user_id,
|
||||
session_id,
|
||||
db_session,
|
||||
include_deleted=True,
|
||||
hard_delete=True,
|
||||
)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
mark_task_as_finished_with_id(
|
||||
db_session=db_session,
|
||||
task_id=task_id,
|
||||
success=True,
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"delete_chat_session exceptioned. "
|
||||
f"user_id={user_id} session_id={session_id}"
|
||||
)
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
mark_task_as_finished_with_id(
|
||||
db_session=db_session,
|
||||
task_id=task_id,
|
||||
success=False,
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
#####
|
||||
# Periodic Tasks
|
||||
#####
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
name=OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
|
||||
ignore_result=True,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
)
|
||||
def check_ttl_management_task(*, tenant_id: str) -> None:
|
||||
"""Runs periodically to check if any ttl tasks should be run and adds them
|
||||
to the queue"""
|
||||
|
||||
settings = load_settings()
|
||||
retention_limit_days = settings.maximum_chat_retention_days
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
if should_perform_chat_ttl_check(retention_limit_days, db_session):
|
||||
perform_ttl_management_task.apply_async(
|
||||
kwargs=dict(
|
||||
retention_limit_days=retention_limit_days, tenant_id=tenant_id
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
name=OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
|
||||
ignore_result=True,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
)
|
||||
def autogenerate_usage_report_task(*, tenant_id: str) -> None:
|
||||
"""This generates usage report under the /admin/generate-usage/report endpoint"""
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
create_new_usage_report(
|
||||
db_session=db_session,
|
||||
user_id=None,
|
||||
period=None,
|
||||
)
|
||||
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
@@ -135,5 +6,7 @@ celery_app.autodiscover_tasks(
|
||||
"ee.onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"ee.onyx.background.celery.tasks.external_group_syncing",
|
||||
"ee.onyx.background.celery.tasks.cloud",
|
||||
"ee.onyx.background.celery.tasks.ttl_management",
|
||||
"ee.onyx.background.celery.tasks.usage_reporting",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -23,7 +23,7 @@ ee_beat_system_tasks: list[dict] = []
|
||||
ee_beat_task_templates: list[dict] = [
|
||||
{
|
||||
"name": "autogenerate-usage-report",
|
||||
"task": OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
|
||||
"task": OnyxCeleryTask.GENERATE_USAGE_REPORT_TASK,
|
||||
"schedule": timedelta(days=30),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
@@ -57,7 +57,7 @@ if not MULTI_TENANT:
|
||||
ee_tasks_to_schedule = [
|
||||
{
|
||||
"name": "autogenerate-usage-report",
|
||||
"task": OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
|
||||
"task": OnyxCeleryTask.GENERATE_USAGE_REPORT_TASK,
|
||||
"schedule": timedelta(days=30), # TODO: change this to config flag
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
|
||||
@@ -93,7 +93,7 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
|
||||
if cc_pair.access_type != AccessType.SYNC:
|
||||
task_logger.error(
|
||||
f"Recieved non-sync CC Pair {cc_pair.id} for external "
|
||||
f"Received non-sync CC Pair {cc_pair.id} for external "
|
||||
f"group sync. Actual access type: {cc_pair.access_type}"
|
||||
)
|
||||
return False
|
||||
|
||||
106
backend/ee/onyx/background/celery/tasks/ttl_management/tasks.py
Normal file
106
backend/ee/onyx/background/celery/tasks/ttl_management/tasks.py
Normal file
@@ -0,0 +1,106 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from uuid import UUID
|
||||
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
|
||||
from ee.onyx.background.celery_utils import should_perform_chat_ttl_check
|
||||
from ee.onyx.background.task_name_builders import name_chat_ttl_task
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.db.chat import delete_chat_session
|
||||
from onyx.db.chat import get_chat_sessions_older_than
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import TaskStatus
|
||||
from onyx.db.tasks import mark_task_as_finished_with_id
|
||||
from onyx.db.tasks import register_task
|
||||
from onyx.server.settings.store import load_settings
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.PERFORM_TTL_MANAGEMENT_TASK,
|
||||
ignore_result=True,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
trail=False,
|
||||
)
|
||||
def perform_ttl_management_task(
|
||||
self: Task, retention_limit_days: int, *, tenant_id: str
|
||||
) -> None:
|
||||
task_id = self.request.id
|
||||
if not task_id:
|
||||
raise RuntimeError("No task id defined for this task; cannot identify it")
|
||||
|
||||
start_time = datetime.now(tz=timezone.utc)
|
||||
|
||||
user_id: UUID | None = None
|
||||
session_id: UUID | None = None
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
# we generally want to move off this, but keeping for now
|
||||
register_task(
|
||||
db_session=db_session,
|
||||
task_name=name_chat_ttl_task(retention_limit_days, tenant_id),
|
||||
task_id=task_id,
|
||||
status=TaskStatus.STARTED,
|
||||
start_time=start_time,
|
||||
)
|
||||
|
||||
old_chat_sessions = get_chat_sessions_older_than(
|
||||
retention_limit_days, db_session
|
||||
)
|
||||
|
||||
for user_id, session_id in old_chat_sessions:
|
||||
# one session per delete so that we don't blow up if a deletion fails.
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
delete_chat_session(
|
||||
user_id,
|
||||
session_id,
|
||||
db_session,
|
||||
include_deleted=True,
|
||||
hard_delete=True,
|
||||
)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
mark_task_as_finished_with_id(
|
||||
db_session=db_session,
|
||||
task_id=task_id,
|
||||
success=True,
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"delete_chat_session exceptioned. "
|
||||
f"user_id={user_id} session_id={session_id}"
|
||||
)
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
mark_task_as_finished_with_id(
|
||||
db_session=db_session,
|
||||
task_id=task_id,
|
||||
success=False,
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
|
||||
ignore_result=True,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
)
|
||||
def check_ttl_management_task(*, tenant_id: str) -> None:
|
||||
"""Runs periodically to check if any ttl tasks should be run and adds them
|
||||
to the queue"""
|
||||
|
||||
settings = load_settings()
|
||||
retention_limit_days = settings.maximum_chat_retention_days
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
if should_perform_chat_ttl_check(retention_limit_days, db_session):
|
||||
perform_ttl_management_task.apply_async(
|
||||
kwargs=dict(
|
||||
retention_limit_days=retention_limit_days, tenant_id=tenant_id
|
||||
),
|
||||
)
|
||||
@@ -0,0 +1,46 @@
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
|
||||
from ee.onyx.server.reporting.usage_export_generation import create_new_usage_report
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.GENERATE_USAGE_REPORT_TASK,
|
||||
ignore_result=True,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
trail=False,
|
||||
)
|
||||
def generate_usage_report_task(
|
||||
self: Task,
|
||||
*,
|
||||
tenant_id: str,
|
||||
user_id: str | None = None,
|
||||
period_from: str | None = None,
|
||||
period_to: str | None = None,
|
||||
) -> None:
|
||||
"""User-initiated usage report generation task"""
|
||||
# Parse period if provided
|
||||
period = None
|
||||
if period_from and period_to:
|
||||
period = (
|
||||
datetime.fromisoformat(period_from),
|
||||
datetime.fromisoformat(period_to),
|
||||
)
|
||||
|
||||
# Generate the report
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
create_new_usage_report(
|
||||
db_session=db_session,
|
||||
user_id=UUID(user_id) if user_id else None,
|
||||
period=period,
|
||||
)
|
||||
@@ -1,38 +0,0 @@
|
||||
from ee.onyx.server.query_and_chat.models import OneShotQAResponse
|
||||
from onyx.chat.models import AllCitations
|
||||
from onyx.chat.models import LLMRelevanceFilterResponse
|
||||
from onyx.chat.models import OnyxAnswerPiece
|
||||
from onyx.chat.models import QADocsResponse
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.chat.process_message import ChatPacketStream
|
||||
from onyx.server.query_and_chat.models import ChatMessageDetail
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
|
||||
@log_function_time()
|
||||
def gather_stream_for_answer_api(
|
||||
packets: ChatPacketStream,
|
||||
) -> OneShotQAResponse:
|
||||
response = OneShotQAResponse()
|
||||
|
||||
answer = ""
|
||||
for packet in packets:
|
||||
if isinstance(packet, OnyxAnswerPiece) and packet.answer_piece:
|
||||
answer += packet.answer_piece
|
||||
elif isinstance(packet, QADocsResponse):
|
||||
response.docs = packet
|
||||
# Extraneous, provided for backwards compatibility
|
||||
response.rephrase = packet.rephrased_query
|
||||
elif isinstance(packet, StreamingError):
|
||||
response.error_msg = packet.error
|
||||
elif isinstance(packet, ChatMessageDetail):
|
||||
response.chat_message_id = packet.message_id
|
||||
elif isinstance(packet, LLMRelevanceFilterResponse):
|
||||
response.llm_selected_doc_indices = packet.llm_selected_doc_indices
|
||||
elif isinstance(packet, AllCitations):
|
||||
response.citations = packet.citations
|
||||
|
||||
if answer:
|
||||
response.answer = answer
|
||||
|
||||
return response
|
||||
@@ -2,18 +2,14 @@ from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from typing import Optional
|
||||
from typing import Protocol
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ee.onyx.db.external_perm import ExternalUserGroup # noqa
|
||||
from onyx.access.models import DocExternalAccess # noqa
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.db.models import ConnectorCredentialPair # noqa
|
||||
from onyx.db.utils import DocumentRow
|
||||
from onyx.db.utils import SortOrder
|
||||
|
||||
# Avoid circular imports
|
||||
if TYPE_CHECKING:
|
||||
from ee.onyx.db.external_perm import ExternalUserGroup # noqa
|
||||
from onyx.access.models import DocExternalAccess # noqa
|
||||
from onyx.db.models import ConnectorCredentialPair # noqa
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface # noqa
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface # noqa
|
||||
|
||||
|
||||
class FetchAllDocumentsFunction(Protocol):
|
||||
@@ -52,20 +48,20 @@ class FetchAllDocumentsIdsFunction(Protocol):
|
||||
# Defining the input/output types for the sync functions
|
||||
DocSyncFuncType = Callable[
|
||||
[
|
||||
"ConnectorCredentialPair",
|
||||
ConnectorCredentialPair,
|
||||
FetchAllDocumentsFunction,
|
||||
FetchAllDocumentsIdsFunction,
|
||||
Optional["IndexingHeartbeatInterface"],
|
||||
Optional[IndexingHeartbeatInterface],
|
||||
],
|
||||
Generator["DocExternalAccess", None, None],
|
||||
Generator[DocExternalAccess, None, None],
|
||||
]
|
||||
|
||||
GroupSyncFuncType = Callable[
|
||||
[
|
||||
str, # tenant_id
|
||||
"ConnectorCredentialPair", # cc_pair
|
||||
ConnectorCredentialPair, # cc_pair
|
||||
],
|
||||
Generator["ExternalUserGroup", None, None],
|
||||
Generator[ExternalUserGroup, None, None],
|
||||
]
|
||||
|
||||
# list of chunks to be censored and the user email. returns censored chunks
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
import re
|
||||
from collections import deque
|
||||
from typing import Any
|
||||
from urllib.parse import unquote
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from office365.graph_client import GraphClient # type: ignore[import-untyped]
|
||||
from office365.onedrive.driveitems.driveItem import DriveItem # type: ignore[import-untyped]
|
||||
from office365.runtime.client_request import ClientRequestException # type: ignore
|
||||
from office365.sharepoint.client_context import ClientContext # type: ignore[import-untyped]
|
||||
from office365.sharepoint.permissions.securable_object import RoleAssignmentCollection # type: ignore[import-untyped]
|
||||
from pydantic import BaseModel
|
||||
@@ -231,6 +234,7 @@ def _get_sharepoint_groups(
|
||||
nonlocal groups, user_emails
|
||||
|
||||
for user in users:
|
||||
logger.debug(f"User: {user.to_json()}")
|
||||
if user.principal_type == USER_PRINCIPAL_TYPE and hasattr(
|
||||
user, "user_principal_name"
|
||||
):
|
||||
@@ -285,7 +289,7 @@ def _get_azuread_groups(
|
||||
|
||||
for member in members:
|
||||
member_data = member.to_json()
|
||||
|
||||
logger.debug(f"Member: {member_data}")
|
||||
# Check for user-specific attributes
|
||||
user_principal_name = member_data.get("userPrincipalName")
|
||||
mail = member_data.get("mail")
|
||||
@@ -366,6 +370,7 @@ def _get_groups_and_members_recursively(
|
||||
client_context: ClientContext,
|
||||
graph_client: GraphClient,
|
||||
groups: set[SharepointGroup],
|
||||
is_group_sync: bool = False,
|
||||
) -> GroupsResult:
|
||||
"""
|
||||
Get all groups and their members recursively.
|
||||
@@ -373,6 +378,7 @@ def _get_groups_and_members_recursively(
|
||||
group_queue: deque[SharepointGroup] = deque(groups)
|
||||
visited_groups: set[str] = set()
|
||||
visited_group_name_to_emails: dict[str, set[str]] = {}
|
||||
found_public_group = False
|
||||
while group_queue:
|
||||
group = group_queue.popleft()
|
||||
if group.login_name in visited_groups:
|
||||
@@ -390,19 +396,35 @@ def _get_groups_and_members_recursively(
|
||||
if group_info:
|
||||
group_queue.extend(group_info)
|
||||
if group.principal_type == AZURE_AD_GROUP_PRINCIPAL_TYPE:
|
||||
# if the site is public, we have default groups assigned to it, so we return early
|
||||
if _is_public_login_name(group.login_name):
|
||||
return GroupsResult(groups_to_emails={}, found_public_group=True)
|
||||
|
||||
group_info, user_emails = _get_azuread_groups(
|
||||
graph_client, group.login_name
|
||||
)
|
||||
visited_group_name_to_emails[group.name].update(user_emails)
|
||||
if group_info:
|
||||
group_queue.extend(group_info)
|
||||
try:
|
||||
# if the site is public, we have default groups assigned to it, so we return early
|
||||
if _is_public_login_name(group.login_name):
|
||||
found_public_group = True
|
||||
if not is_group_sync:
|
||||
return GroupsResult(
|
||||
groups_to_emails={}, found_public_group=True
|
||||
)
|
||||
else:
|
||||
# we don't want to sync public groups, so we skip them
|
||||
continue
|
||||
group_info, user_emails = _get_azuread_groups(
|
||||
graph_client, group.login_name
|
||||
)
|
||||
visited_group_name_to_emails[group.name].update(user_emails)
|
||||
if group_info:
|
||||
group_queue.extend(group_info)
|
||||
except ClientRequestException as e:
|
||||
# If the group is not found, we skip it. There is a chance that group is still referenced
|
||||
# in sharepoint but it is removed from Azure AD. There is no actual documentation on this, but based on
|
||||
# our testing we have seen this happen.
|
||||
if e.response is not None and e.response.status_code == 404:
|
||||
logger.warning(f"Group {group.login_name} not found")
|
||||
continue
|
||||
raise e
|
||||
|
||||
return GroupsResult(
|
||||
groups_to_emails=visited_group_name_to_emails, found_public_group=False
|
||||
groups_to_emails=visited_group_name_to_emails,
|
||||
found_public_group=found_public_group,
|
||||
)
|
||||
|
||||
|
||||
@@ -427,6 +449,7 @@ def get_external_access_from_sharepoint(
|
||||
) -> None:
|
||||
nonlocal user_emails, groups
|
||||
for assignment in role_assignments:
|
||||
logger.debug(f"Assignment: {assignment.to_json()}")
|
||||
if assignment.role_definition_bindings:
|
||||
is_limited_access = True
|
||||
for role_definition_binding in assignment.role_definition_bindings:
|
||||
@@ -503,12 +526,19 @@ def get_external_access_from_sharepoint(
|
||||
)
|
||||
elif site_page:
|
||||
site_url = site_page.get("webUrl")
|
||||
site_pages = client_context.web.lists.get_by_title("Site Pages")
|
||||
client_context.load(site_pages)
|
||||
client_context.execute_query()
|
||||
site_pages.items.get_by_url(site_url).role_assignments.expand(
|
||||
["Member", "RoleDefinitionBindings"]
|
||||
).get_all(page_loaded=add_user_and_group_to_sets).execute_query()
|
||||
# Prefer server-relative URL to avoid OData filters that break on apostrophes
|
||||
server_relative_url = unquote(urlparse(site_url).path)
|
||||
file_obj = client_context.web.get_file_by_server_relative_url(
|
||||
server_relative_url
|
||||
)
|
||||
item = file_obj.listItemAllFields
|
||||
|
||||
sleep_and_retry(
|
||||
item.role_assignments.expand(["Member", "RoleDefinitionBindings"]).get_all(
|
||||
page_loaded=add_user_and_group_to_sets,
|
||||
),
|
||||
"get_external_access_from_sharepoint",
|
||||
)
|
||||
else:
|
||||
raise RuntimeError("No drive item or site page provided")
|
||||
|
||||
@@ -595,13 +625,9 @@ def get_sharepoint_external_groups(
|
||||
"get_sharepoint_external_groups",
|
||||
)
|
||||
groups_and_members: GroupsResult = _get_groups_and_members_recursively(
|
||||
client_context, graph_client, groups
|
||||
client_context, graph_client, groups, is_group_sync=True
|
||||
)
|
||||
|
||||
# We don't have any direct way to check if the site is public, so we check if any public group is present
|
||||
if groups_and_members.found_public_group:
|
||||
return []
|
||||
|
||||
# get all Azure AD groups because if any group is assigned to the drive item, we don't want to miss them
|
||||
# We can't assign sharepoint groups to drive items or drives, so we don't need to get all sharepoint groups
|
||||
azure_ad_groups = sleep_and_retry(
|
||||
|
||||
@@ -17,6 +17,7 @@ from ee.onyx.server.enterprise_settings.api import (
|
||||
from ee.onyx.server.enterprise_settings.api import (
|
||||
basic_router as enterprise_settings_router,
|
||||
)
|
||||
from ee.onyx.server.evals.api import router as evals_router
|
||||
from ee.onyx.server.manage.standard_answer import router as standard_answer_router
|
||||
from ee.onyx.server.middleware.tenant_tracking import (
|
||||
add_api_server_tenant_id_middleware,
|
||||
@@ -170,6 +171,7 @@ def get_application() -> FastAPI:
|
||||
include_router_with_global_prefix_prepended(application, standard_answer_router)
|
||||
include_router_with_global_prefix_prepended(application, ee_oauth_router)
|
||||
include_router_with_global_prefix_prepended(application, ee_document_cc_pair_router)
|
||||
include_router_with_global_prefix_prepended(application, evals_router)
|
||||
|
||||
# Enterprise-only global settings
|
||||
include_router_with_global_prefix_prepended(
|
||||
|
||||
@@ -8,13 +8,12 @@ from sqlalchemy.orm import Session
|
||||
from ee.onyx.db.standard_answer import fetch_standard_answer_categories_by_names
|
||||
from ee.onyx.db.standard_answer import find_matching_standard_answers
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.configs.onyxbot_configs import DANSWER_REACT_EMOJI
|
||||
from onyx.configs.onyxbot_configs import ONYX_BOT_REACT_EMOJI
|
||||
from onyx.db.chat import create_chat_session
|
||||
from onyx.db.chat import create_new_chat_message
|
||||
from onyx.db.chat import get_chat_messages_by_sessions
|
||||
from onyx.db.chat import get_chat_sessions_by_slack_thread_id
|
||||
from onyx.db.chat import get_or_create_root_message
|
||||
from onyx.db.models import Prompt
|
||||
from onyx.db.models import SlackChannelConfig
|
||||
from onyx.db.models import StandardAnswer as StandardAnswerModel
|
||||
from onyx.onyxbot.slack.blocks import get_restate_blocks
|
||||
@@ -81,7 +80,6 @@ def _handle_standard_answers(
|
||||
message_info: SlackMessageInfo,
|
||||
receiver_ids: list[str] | None,
|
||||
slack_channel_config: SlackChannelConfig,
|
||||
prompt: Prompt | None,
|
||||
logger: OnyxLoggingAdapter,
|
||||
client: WebClient,
|
||||
db_session: Session,
|
||||
@@ -161,7 +159,6 @@ def _handle_standard_answers(
|
||||
new_user_message = create_new_chat_message(
|
||||
chat_session_id=chat_session.id,
|
||||
parent_message=root_message,
|
||||
prompt_id=prompt.id if prompt else None,
|
||||
message=query_msg.message,
|
||||
token_count=0,
|
||||
message_type=MessageType.USER,
|
||||
@@ -182,7 +179,6 @@ def _handle_standard_answers(
|
||||
chat_message = create_new_chat_message(
|
||||
chat_session_id=chat_session.id,
|
||||
parent_message=new_user_message,
|
||||
prompt_id=prompt.id if prompt else None,
|
||||
message=answer_message,
|
||||
token_count=0,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
@@ -197,7 +193,7 @@ def _handle_standard_answers(
|
||||
db_session.commit()
|
||||
|
||||
update_emote_react(
|
||||
emoji=DANSWER_REACT_EMOJI,
|
||||
emoji=ONYX_BOT_REACT_EMOJI,
|
||||
channel=message_info.channel_to_respond,
|
||||
message_ts=message_info.msg_to_respond,
|
||||
remove=True,
|
||||
|
||||
@@ -16,6 +16,7 @@ EE_PUBLIC_ENDPOINT_SPECS = PUBLIC_ENDPOINT_SPECS + [
|
||||
# saml
|
||||
("/auth/saml/authorize", {"GET"}),
|
||||
("/auth/saml/callback", {"POST"}),
|
||||
("/auth/saml/callback", {"GET"}),
|
||||
("/auth/saml/logout", {"POST"}),
|
||||
]
|
||||
|
||||
|
||||
32
backend/ee/onyx/server/evals/api.py
Normal file
32
backend/ee/onyx/server/evals/api.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
|
||||
from ee.onyx.auth.users import current_cloud_superuser
|
||||
from onyx.background.celery.apps.client import celery_app as client_app
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.db.models import User
|
||||
from onyx.evals.models import EvalConfigurationOptions
|
||||
from onyx.server.evals.models import EvalRunAck
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/evals")
|
||||
|
||||
|
||||
@router.post("/eval_run", response_model=EvalRunAck)
|
||||
def eval_run(
|
||||
request: EvalConfigurationOptions,
|
||||
user: User = Depends(current_cloud_superuser),
|
||||
) -> EvalRunAck:
|
||||
"""
|
||||
Run an evaluation with the given message and optional dataset.
|
||||
This endpoint requires a valid API key for authentication.
|
||||
"""
|
||||
client_app.send_task(
|
||||
OnyxCeleryTask.EVAL_RUN_TASK,
|
||||
kwargs={
|
||||
"configuration_dict": request.model_dump(),
|
||||
},
|
||||
)
|
||||
return EvalRunAck(success=True)
|
||||
@@ -1,43 +1,22 @@
|
||||
import re
|
||||
from typing import cast
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.server.query_and_chat.models import AgentAnswer
|
||||
from ee.onyx.server.query_and_chat.models import AgentSubQuery
|
||||
from ee.onyx.server.query_and_chat.models import AgentSubQuestion
|
||||
from ee.onyx.server.query_and_chat.models import BasicCreateChatMessageRequest
|
||||
from ee.onyx.server.query_and_chat.models import (
|
||||
BasicCreateChatMessageWithHistoryRequest,
|
||||
)
|
||||
from ee.onyx.server.query_and_chat.models import ChatBasicResponse
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.chat.chat_utils import combine_message_thread
|
||||
from onyx.chat.chat_utils import create_chat_chain
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
from onyx.chat.models import AllCitations
|
||||
from onyx.chat.models import ExtendedToolResponse
|
||||
from onyx.chat.models import FinalUsedContextDocsResponse
|
||||
from onyx.chat.models import LlmDoc
|
||||
from onyx.chat.models import LLMRelevanceFilterResponse
|
||||
from onyx.chat.models import OnyxAnswerPiece
|
||||
from onyx.chat.models import QADocsResponse
|
||||
from onyx.chat.models import RefinedAnswerImprovement
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.chat.models import SubQueryPiece
|
||||
from onyx.chat.models import SubQuestionIdentifier
|
||||
from onyx.chat.models import SubQuestionPiece
|
||||
from onyx.chat.process_message import ChatPacketStream
|
||||
from onyx.chat.models import ChatBasicResponse
|
||||
from onyx.chat.process_message import gather_stream
|
||||
from onyx.chat.process_message import stream_chat_message_objects
|
||||
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.context.search.models import OptionalSearchSetting
|
||||
from onyx.context.search.models import RetrievalDetails
|
||||
from onyx.context.search.models import SavedSearchDoc
|
||||
from onyx.db.chat import create_chat_session
|
||||
from onyx.db.chat import create_new_chat_message
|
||||
from onyx.db.chat import get_or_create_root_message
|
||||
@@ -46,7 +25,6 @@ from onyx.db.models import User
|
||||
from onyx.llm.factory import get_llms_for_persona
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.secondary_llm_flows.query_expansion import thread_based_query_rephrase
|
||||
from onyx.server.query_and_chat.models import ChatMessageDetail
|
||||
from onyx.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -55,180 +33,6 @@ logger = setup_logger()
|
||||
router = APIRouter(prefix="/chat")
|
||||
|
||||
|
||||
def _get_final_context_doc_indices(
|
||||
final_context_docs: list[LlmDoc] | None,
|
||||
top_docs: list[SavedSearchDoc] | None,
|
||||
) -> list[int] | None:
|
||||
"""
|
||||
this function returns a list of indices of the simple search docs
|
||||
that were actually fed to the LLM.
|
||||
"""
|
||||
if final_context_docs is None or top_docs is None:
|
||||
return None
|
||||
|
||||
final_context_doc_ids = {doc.document_id for doc in final_context_docs}
|
||||
return [
|
||||
i for i, doc in enumerate(top_docs) if doc.document_id in final_context_doc_ids
|
||||
]
|
||||
|
||||
|
||||
def _convert_packet_stream_to_response(
|
||||
packets: ChatPacketStream,
|
||||
chat_session_id: UUID,
|
||||
) -> ChatBasicResponse:
|
||||
response = ChatBasicResponse()
|
||||
final_context_docs: list[LlmDoc] = []
|
||||
|
||||
answer = ""
|
||||
|
||||
# accumulate stream data with these dicts
|
||||
agent_sub_questions: dict[tuple[int, int], AgentSubQuestion] = {}
|
||||
agent_answers: dict[tuple[int, int], AgentAnswer] = {}
|
||||
agent_sub_queries: dict[tuple[int, int, int], AgentSubQuery] = {}
|
||||
|
||||
for packet in packets:
|
||||
if isinstance(packet, OnyxAnswerPiece) and packet.answer_piece:
|
||||
answer += packet.answer_piece
|
||||
elif isinstance(packet, QADocsResponse):
|
||||
response.top_documents = packet.top_documents
|
||||
|
||||
# This is a no-op if agent_sub_questions hasn't already been filled
|
||||
if packet.level is not None and packet.level_question_num is not None:
|
||||
id = (packet.level, packet.level_question_num)
|
||||
if id in agent_sub_questions:
|
||||
agent_sub_questions[id].document_ids = [
|
||||
saved_search_doc.document_id
|
||||
for saved_search_doc in packet.top_documents
|
||||
]
|
||||
elif isinstance(packet, StreamingError):
|
||||
response.error_msg = packet.error
|
||||
elif isinstance(packet, ChatMessageDetail):
|
||||
response.message_id = packet.message_id
|
||||
elif isinstance(packet, LLMRelevanceFilterResponse):
|
||||
response.llm_selected_doc_indices = packet.llm_selected_doc_indices
|
||||
|
||||
# TODO: deprecate `llm_chunks_indices`
|
||||
response.llm_chunks_indices = packet.llm_selected_doc_indices
|
||||
elif isinstance(packet, FinalUsedContextDocsResponse):
|
||||
final_context_docs = packet.final_context_docs
|
||||
elif isinstance(packet, AllCitations):
|
||||
response.cited_documents = {
|
||||
citation.citation_num: citation.document_id
|
||||
for citation in packet.citations
|
||||
}
|
||||
# agentic packets
|
||||
elif isinstance(packet, SubQuestionPiece):
|
||||
if packet.level is not None and packet.level_question_num is not None:
|
||||
id = (packet.level, packet.level_question_num)
|
||||
if agent_sub_questions.get(id) is None:
|
||||
agent_sub_questions[id] = AgentSubQuestion(
|
||||
level=packet.level,
|
||||
level_question_num=packet.level_question_num,
|
||||
sub_question=packet.sub_question,
|
||||
document_ids=[],
|
||||
)
|
||||
else:
|
||||
agent_sub_questions[id].sub_question += packet.sub_question
|
||||
|
||||
elif isinstance(packet, AgentAnswerPiece):
|
||||
if packet.level is not None and packet.level_question_num is not None:
|
||||
id = (packet.level, packet.level_question_num)
|
||||
if agent_answers.get(id) is None:
|
||||
agent_answers[id] = AgentAnswer(
|
||||
level=packet.level,
|
||||
level_question_num=packet.level_question_num,
|
||||
answer=packet.answer_piece,
|
||||
answer_type=packet.answer_type,
|
||||
)
|
||||
else:
|
||||
agent_answers[id].answer += packet.answer_piece
|
||||
elif isinstance(packet, SubQueryPiece):
|
||||
if packet.level is not None and packet.level_question_num is not None:
|
||||
sub_query_id = (
|
||||
packet.level,
|
||||
packet.level_question_num,
|
||||
packet.query_id,
|
||||
)
|
||||
if agent_sub_queries.get(sub_query_id) is None:
|
||||
agent_sub_queries[sub_query_id] = AgentSubQuery(
|
||||
level=packet.level,
|
||||
level_question_num=packet.level_question_num,
|
||||
sub_query=packet.sub_query,
|
||||
query_id=packet.query_id,
|
||||
)
|
||||
else:
|
||||
agent_sub_queries[sub_query_id].sub_query += packet.sub_query
|
||||
elif isinstance(packet, ExtendedToolResponse):
|
||||
# we shouldn't get this ... it gets intercepted and translated to QADocsResponse
|
||||
logger.warning(
|
||||
"_convert_packet_stream_to_response: Unexpected chat packet type ExtendedToolResponse!"
|
||||
)
|
||||
elif isinstance(packet, RefinedAnswerImprovement):
|
||||
response.agent_refined_answer_improvement = (
|
||||
packet.refined_answer_improvement
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"_convert_packet_stream_to_response - Unrecognized chat packet: type={type(packet)}"
|
||||
)
|
||||
|
||||
response.final_context_doc_indices = _get_final_context_doc_indices(
|
||||
final_context_docs, response.top_documents
|
||||
)
|
||||
|
||||
# organize / sort agent metadata for output
|
||||
if len(agent_sub_questions) > 0:
|
||||
response.agent_sub_questions = cast(
|
||||
dict[int, list[AgentSubQuestion]],
|
||||
SubQuestionIdentifier.make_dict_by_level(agent_sub_questions),
|
||||
)
|
||||
|
||||
if len(agent_answers) > 0:
|
||||
# return the agent_level_answer from the first level or the last one depending
|
||||
# on agent_refined_answer_improvement
|
||||
response.agent_answers = cast(
|
||||
dict[int, list[AgentAnswer]],
|
||||
SubQuestionIdentifier.make_dict_by_level(agent_answers),
|
||||
)
|
||||
if response.agent_answers:
|
||||
selected_answer_level = (
|
||||
0
|
||||
if not response.agent_refined_answer_improvement
|
||||
else len(response.agent_answers) - 1
|
||||
)
|
||||
level_answers = response.agent_answers[selected_answer_level]
|
||||
for level_answer in level_answers:
|
||||
if level_answer.answer_type != "agent_level_answer":
|
||||
continue
|
||||
|
||||
answer = level_answer.answer
|
||||
break
|
||||
|
||||
if len(agent_sub_queries) > 0:
|
||||
# subqueries are often emitted with trailing whitespace ... clean it up here
|
||||
# perhaps fix at the source?
|
||||
for v in agent_sub_queries.values():
|
||||
v.sub_query = v.sub_query.strip()
|
||||
|
||||
response.agent_sub_queries = (
|
||||
AgentSubQuery.make_dict_by_level_and_question_index(agent_sub_queries)
|
||||
)
|
||||
|
||||
response.answer = answer
|
||||
if answer:
|
||||
response.answer_citationless = remove_answer_citations(answer)
|
||||
|
||||
response.chat_session_id = chat_session_id
|
||||
|
||||
return response
|
||||
|
||||
|
||||
def remove_answer_citations(answer: str) -> str:
|
||||
pattern = r"\s*\[\[\d+\]\]\(http[s]?://[^\s]+\)"
|
||||
|
||||
return re.sub(pattern, "", answer)
|
||||
|
||||
|
||||
@router.post("/send-message-simple-api")
|
||||
def handle_simplified_chat_message(
|
||||
chat_message_req: BasicCreateChatMessageRequest,
|
||||
@@ -289,7 +93,6 @@ def handle_simplified_chat_message(
|
||||
parent_message_id=parent_message.id,
|
||||
message=chat_message_req.message,
|
||||
file_descriptors=[],
|
||||
prompt_id=None,
|
||||
search_doc_ids=chat_message_req.search_doc_ids,
|
||||
retrieval_options=retrieval_options,
|
||||
# Simple API does not support reranking, hide complexity from user
|
||||
@@ -310,7 +113,7 @@ def handle_simplified_chat_message(
|
||||
enforce_chat_session_id_for_search_docs=False,
|
||||
)
|
||||
|
||||
return _convert_packet_stream_to_response(packets, chat_session_id)
|
||||
return gather_stream(packets)
|
||||
|
||||
|
||||
@router.post("/send-message-simple-with-history")
|
||||
@@ -377,7 +180,6 @@ def handle_send_message_simple_with_history(
|
||||
chat_message = create_new_chat_message(
|
||||
chat_session_id=chat_session.id,
|
||||
parent_message=chat_message,
|
||||
prompt_id=req.prompt_id,
|
||||
message=msg.message,
|
||||
token_count=len(llm_tokenizer.encode(msg.message)),
|
||||
message_type=msg.role,
|
||||
@@ -410,7 +212,6 @@ def handle_send_message_simple_with_history(
|
||||
parent_message_id=chat_message.id,
|
||||
message=query,
|
||||
file_descriptors=[],
|
||||
prompt_id=req.prompt_id,
|
||||
search_doc_ids=req.search_doc_ids,
|
||||
retrieval_options=retrieval_options,
|
||||
# Simple API does not support reranking, hide complexity from user
|
||||
@@ -430,4 +231,4 @@ def handle_send_message_simple_with_history(
|
||||
enforce_chat_session_id_for_search_docs=False,
|
||||
)
|
||||
|
||||
return _convert_packet_stream_to_response(packets, chat_session.id)
|
||||
return gather_stream(packets)
|
||||
|
||||
@@ -6,10 +6,8 @@ from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from pydantic import model_validator
|
||||
|
||||
from onyx.chat.models import CitationInfo
|
||||
from onyx.chat.models import PersonaOverrideConfig
|
||||
from onyx.chat.models import QADocsResponse
|
||||
from onyx.chat.models import SubQuestionIdentifier
|
||||
from onyx.chat.models import ThreadMessage
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.context.search.enums import LLMEvaluationType
|
||||
@@ -17,8 +15,9 @@ from onyx.context.search.enums import SearchType
|
||||
from onyx.context.search.models import ChunkContext
|
||||
from onyx.context.search.models import RerankingDetails
|
||||
from onyx.context.search.models import RetrievalDetails
|
||||
from onyx.context.search.models import SavedSearchDoc
|
||||
from onyx.server.manage.models import StandardAnswer
|
||||
from onyx.server.query_and_chat.streaming_models import CitationInfo
|
||||
from onyx.server.query_and_chat.streaming_models import SubQuestionIdentifier
|
||||
|
||||
|
||||
class StandardAnswerRequest(BaseModel):
|
||||
@@ -74,7 +73,6 @@ class BasicCreateChatMessageRequest(ChunkContext):
|
||||
class BasicCreateChatMessageWithHistoryRequest(ChunkContext):
|
||||
# Last element is the new query. All previous elements are historical context
|
||||
messages: list[ThreadMessage]
|
||||
prompt_id: int | None
|
||||
persona_id: int
|
||||
retrieval_options: RetrievalDetails | None = None
|
||||
query_override: str | None = None
|
||||
@@ -156,33 +154,6 @@ class AgentSubQuery(SubQuestionIdentifier):
|
||||
return sorted_dict
|
||||
|
||||
|
||||
class ChatBasicResponse(BaseModel):
|
||||
# This is built piece by piece, any of these can be None as the flow could break
|
||||
answer: str | None = None
|
||||
answer_citationless: str | None = None
|
||||
|
||||
top_documents: list[SavedSearchDoc] | None = None
|
||||
|
||||
error_msg: str | None = None
|
||||
message_id: int | None = None
|
||||
llm_selected_doc_indices: list[int] | None = None
|
||||
final_context_doc_indices: list[int] | None = None
|
||||
# this is a map of the citation number to the document id
|
||||
cited_documents: dict[int, str] | None = None
|
||||
|
||||
# FOR BACKWARDS COMPATIBILITY
|
||||
llm_chunks_indices: list[int] | None = None
|
||||
|
||||
# agentic fields
|
||||
agent_sub_questions: dict[int, list[AgentSubQuestion]] | None = None
|
||||
agent_answers: dict[int, list[AgentAnswer]] | None = None
|
||||
agent_sub_queries: dict[int, dict[int, list[AgentSubQuery]]] | None = None
|
||||
agent_refined_answer_improvement: bool | None = None
|
||||
|
||||
# Chat session ID for tracking conversation continuity
|
||||
chat_session_id: UUID | None = None
|
||||
|
||||
|
||||
class OneShotQARequest(ChunkContext):
|
||||
# Supports simplier APIs that don't deal with chat histories or message edits
|
||||
# Easier APIs to work with for developers
|
||||
@@ -190,10 +161,8 @@ class OneShotQARequest(ChunkContext):
|
||||
persona_id: int | None = None
|
||||
|
||||
messages: list[ThreadMessage]
|
||||
prompt_id: int | None = None
|
||||
retrieval_options: RetrievalDetails = Field(default_factory=RetrievalDetails)
|
||||
rerank_settings: RerankingDetails | None = None
|
||||
return_contexts: bool = False
|
||||
|
||||
# allows the caller to specify the exact search query they want to use
|
||||
# can be used if the message sent to the LLM / query should not be the same
|
||||
@@ -210,11 +179,9 @@ class OneShotQARequest(ChunkContext):
|
||||
def check_persona_fields(self) -> "OneShotQARequest":
|
||||
if self.persona_override_config is None and self.persona_id is None:
|
||||
raise ValueError("Exactly one of persona_config or persona_id must be set")
|
||||
elif self.persona_override_config is not None and (
|
||||
self.persona_id is not None or self.prompt_id is not None
|
||||
):
|
||||
elif self.persona_override_config is not None and (self.persona_id is not None):
|
||||
raise ValueError(
|
||||
"If persona_override_config is set, persona_id and prompt_id cannot be set"
|
||||
"If persona_override_config is set, persona_id cannot be set"
|
||||
)
|
||||
return self
|
||||
|
||||
@@ -225,6 +192,5 @@ class OneShotQAResponse(BaseModel):
|
||||
rephrase: str | None = None
|
||||
citations: list[CitationInfo] | None = None
|
||||
docs: QADocsResponse | None = None
|
||||
llm_selected_doc_indices: list[int] | None = None
|
||||
error_msg: str | None = None
|
||||
chat_message_id: int | None = None
|
||||
|
||||
@@ -8,7 +8,6 @@ from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.chat.process_message import gather_stream_for_answer_api
|
||||
from ee.onyx.onyxbot.slack.handlers.handle_standard_answers import (
|
||||
oneoff_standard_answers,
|
||||
)
|
||||
@@ -20,8 +19,10 @@ from ee.onyx.server.query_and_chat.models import StandardAnswerResponse
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.chat.chat_utils import combine_message_thread
|
||||
from onyx.chat.chat_utils import prepare_chat_message_request
|
||||
from onyx.chat.models import AnswerStream
|
||||
from onyx.chat.models import PersonaOverrideConfig
|
||||
from onyx.chat.process_message import ChatPacketStream
|
||||
from onyx.chat.models import QADocsResponse
|
||||
from onyx.chat.process_message import gather_stream
|
||||
from onyx.chat.process_message import stream_chat_message_objects
|
||||
from onyx.configs.onyxbot_configs import MAX_THREAD_CONTEXT_PERCENTAGE
|
||||
from onyx.context.search.models import SavedSearchDocWithContent
|
||||
@@ -30,7 +31,6 @@ from onyx.context.search.pipeline import SearchPipeline
|
||||
from onyx.context.search.utils import dedupe_documents
|
||||
from onyx.context.search.utils import drop_llm_indices
|
||||
from onyx.context.search.utils import relevant_sections_to_indices
|
||||
from onyx.db.chat import get_prompt_by_id
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import User
|
||||
@@ -39,6 +39,7 @@ from onyx.llm.factory import get_default_llms
|
||||
from onyx.llm.factory import get_llms_for_persona
|
||||
from onyx.llm.factory import get_main_llm_from_tuple
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.server.query_and_chat.streaming_models import CitationInfo
|
||||
from onyx.server.utils import get_json_line
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -140,7 +141,7 @@ def get_answer_stream(
|
||||
query_request: OneShotQARequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ChatPacketStream:
|
||||
) -> AnswerStream:
|
||||
query = query_request.messages[0].message
|
||||
logger.notice(f"Received query for Answer API: {query}")
|
||||
|
||||
@@ -150,14 +151,6 @@ def get_answer_stream(
|
||||
):
|
||||
raise KeyError("Must provide persona ID or Persona Config")
|
||||
|
||||
prompt = None
|
||||
if query_request.prompt_id is not None:
|
||||
prompt = get_prompt_by_id(
|
||||
prompt_id=query_request.prompt_id,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
persona_info: Persona | PersonaOverrideConfig | None = None
|
||||
if query_request.persona_override_config is not None:
|
||||
persona_info = query_request.persona_override_config
|
||||
@@ -192,7 +185,6 @@ def get_answer_stream(
|
||||
user=user,
|
||||
persona_id=query_request.persona_id,
|
||||
persona_override_config=query_request.persona_override_config,
|
||||
prompt=prompt,
|
||||
message_ts_to_respond_to=None,
|
||||
retrieval_details=query_request.retrieval_options,
|
||||
rerank_settings=query_request.rerank_settings,
|
||||
@@ -205,7 +197,6 @@ def get_answer_stream(
|
||||
new_msg_req=request,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
include_contexts=query_request.return_contexts,
|
||||
)
|
||||
|
||||
return packets
|
||||
@@ -219,12 +210,28 @@ def get_answer_with_citation(
|
||||
) -> OneShotQAResponse:
|
||||
try:
|
||||
packets = get_answer_stream(request, user, db_session)
|
||||
answer = gather_stream_for_answer_api(packets)
|
||||
answer = gather_stream(packets)
|
||||
|
||||
if answer.error_msg:
|
||||
raise RuntimeError(answer.error_msg)
|
||||
|
||||
return answer
|
||||
return OneShotQAResponse(
|
||||
answer=answer.answer,
|
||||
chat_message_id=answer.message_id,
|
||||
error_msg=answer.error_msg,
|
||||
citations=[
|
||||
CitationInfo(citation_num=i, document_id=doc_id)
|
||||
for i, doc_id in answer.cited_documents.items()
|
||||
],
|
||||
docs=QADocsResponse(
|
||||
top_documents=answer.top_documents,
|
||||
predicted_flow=None,
|
||||
predicted_search=None,
|
||||
applied_source_filters=None,
|
||||
applied_time_cutoff=None,
|
||||
recency_bias_multiplier=0.0,
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in get_answer_with_citation: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="An internal server error occurred")
|
||||
|
||||
@@ -182,7 +182,6 @@ def admin_get_chat_sessions(
|
||||
time_created=chat.time_created.isoformat(),
|
||||
time_updated=chat.time_updated.isoformat(),
|
||||
shared_status=chat.shared_status,
|
||||
folder_id=chat.folder_id,
|
||||
current_alternate_model=chat.current_alternate_model,
|
||||
)
|
||||
for chat in chat_sessions
|
||||
|
||||
@@ -12,11 +12,13 @@ from sqlalchemy.orm import Session
|
||||
from ee.onyx.db.usage_export import get_all_usage_reports
|
||||
from ee.onyx.db.usage_export import get_usage_report_data
|
||||
from ee.onyx.db.usage_export import UsageReportMetadata
|
||||
from ee.onyx.server.reporting.usage_export_generation import create_new_usage_report
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.background.celery.versioned_apps.client import app as client_app
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.file_store.constants import STANDARD_CHUNK_SIZE
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -26,24 +28,31 @@ class GenerateUsageReportParams(BaseModel):
|
||||
period_to: str | None = None
|
||||
|
||||
|
||||
@router.post("/admin/generate-usage-report")
|
||||
@router.post("/admin/usage-report", status_code=204)
|
||||
def generate_report(
|
||||
params: GenerateUsageReportParams,
|
||||
user: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> UsageReportMetadata:
|
||||
period = None
|
||||
) -> None:
|
||||
# Validate period parameters
|
||||
if params.period_from and params.period_to:
|
||||
try:
|
||||
period = (
|
||||
datetime.fromisoformat(params.period_from),
|
||||
datetime.fromisoformat(params.period_to),
|
||||
)
|
||||
datetime.fromisoformat(params.period_from)
|
||||
datetime.fromisoformat(params.period_to)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
new_report = create_new_usage_report(db_session, user.id if user else None, period)
|
||||
return new_report
|
||||
tenant_id = get_current_tenant_id()
|
||||
client_app.send_task(
|
||||
OnyxCeleryTask.GENERATE_USAGE_REPORT_TASK,
|
||||
kwargs={
|
||||
"tenant_id": tenant_id,
|
||||
"user_id": str(user.id) if user else None,
|
||||
"period_from": params.period_from,
|
||||
"period_to": params.period_to,
|
||||
},
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@router.get("/admin/usage-report/{report_name}")
|
||||
@@ -54,7 +63,7 @@ def read_usage_report(
|
||||
) -> Response:
|
||||
try:
|
||||
file = get_usage_report_data(report_name)
|
||||
except ValueError as e:
|
||||
except (ValueError, RuntimeError) as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
def iterfile() -> Generator[bytes, None, None]:
|
||||
|
||||
@@ -110,7 +110,6 @@ async def upsert_saml_user(email: str) -> User:
|
||||
|
||||
|
||||
async def prepare_from_fastapi_request(request: Request) -> dict[str, Any]:
|
||||
form_data = await request.form()
|
||||
if request.client is None:
|
||||
raise ValueError("Invalid request for SAML")
|
||||
|
||||
@@ -125,14 +124,27 @@ async def prepare_from_fastapi_request(request: Request) -> dict[str, Any]:
|
||||
"post_data": {},
|
||||
"get_data": {},
|
||||
}
|
||||
|
||||
# Handle query parameters (for GET requests)
|
||||
if request.query_params:
|
||||
rv["get_data"] = (request.query_params,)
|
||||
if "SAMLResponse" in form_data:
|
||||
SAMLResponse = form_data["SAMLResponse"]
|
||||
rv["post_data"]["SAMLResponse"] = SAMLResponse
|
||||
if "RelayState" in form_data:
|
||||
RelayState = form_data["RelayState"]
|
||||
rv["post_data"]["RelayState"] = RelayState
|
||||
rv["get_data"] = dict(request.query_params)
|
||||
|
||||
# Handle form data (for POST requests)
|
||||
if request.method == "POST":
|
||||
form_data = await request.form()
|
||||
if "SAMLResponse" in form_data:
|
||||
SAMLResponse = form_data["SAMLResponse"]
|
||||
rv["post_data"]["SAMLResponse"] = SAMLResponse
|
||||
if "RelayState" in form_data:
|
||||
RelayState = form_data["RelayState"]
|
||||
rv["post_data"]["RelayState"] = RelayState
|
||||
else:
|
||||
# For GET requests, check if SAMLResponse is in query params
|
||||
if "SAMLResponse" in request.query_params:
|
||||
rv["get_data"]["SAMLResponse"] = request.query_params["SAMLResponse"]
|
||||
if "RelayState" in request.query_params:
|
||||
rv["get_data"]["RelayState"] = request.query_params["RelayState"]
|
||||
|
||||
return rv
|
||||
|
||||
|
||||
@@ -148,10 +160,27 @@ async def saml_login(request: Request) -> SAMLAuthorizeResponse:
|
||||
return SAMLAuthorizeResponse(authorization_url=callback_url)
|
||||
|
||||
|
||||
@router.get("/callback")
|
||||
async def saml_login_callback_get(
|
||||
request: Request,
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> Response:
|
||||
"""Handle SAML callback via HTTP-Redirect binding (GET request)"""
|
||||
return await _process_saml_callback(request, db_session)
|
||||
|
||||
|
||||
@router.post("/callback")
|
||||
async def saml_login_callback(
|
||||
request: Request,
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> Response:
|
||||
"""Handle SAML callback via HTTP-POST binding (POST request)"""
|
||||
return await _process_saml_callback(request, db_session)
|
||||
|
||||
|
||||
async def _process_saml_callback(
|
||||
request: Request,
|
||||
db_session: Session,
|
||||
) -> Response:
|
||||
req = await prepare_from_fastapi_request(request)
|
||||
auth = OneLogin_Saml2_Auth(req, custom_base_path=SAML_CONF_DIR)
|
||||
|
||||
@@ -131,32 +131,35 @@ def _seed_llms(
|
||||
def _seed_personas(db_session: Session, personas: list[PersonaUpsertRequest]) -> None:
|
||||
if personas:
|
||||
logger.notice("Seeding Personas")
|
||||
for persona in personas:
|
||||
if not persona.prompt_ids:
|
||||
raise ValueError(
|
||||
f"Invalid Persona with name {persona.name}; no prompts exist"
|
||||
try:
|
||||
for persona in personas:
|
||||
upsert_persona(
|
||||
user=None, # Seeding is done as admin
|
||||
name=persona.name,
|
||||
description=persona.description,
|
||||
num_chunks=(
|
||||
persona.num_chunks if persona.num_chunks is not None else 0.0
|
||||
),
|
||||
llm_relevance_filter=persona.llm_relevance_filter,
|
||||
llm_filter_extraction=persona.llm_filter_extraction,
|
||||
recency_bias=RecencyBiasSetting.AUTO,
|
||||
document_set_ids=persona.document_set_ids,
|
||||
llm_model_provider_override=persona.llm_model_provider_override,
|
||||
llm_model_version_override=persona.llm_model_version_override,
|
||||
starter_messages=persona.starter_messages,
|
||||
is_public=persona.is_public,
|
||||
db_session=db_session,
|
||||
tool_ids=persona.tool_ids,
|
||||
display_priority=persona.display_priority,
|
||||
system_prompt=persona.system_prompt,
|
||||
task_prompt=persona.task_prompt,
|
||||
datetime_aware=persona.datetime_aware,
|
||||
commit=False,
|
||||
)
|
||||
|
||||
upsert_persona(
|
||||
user=None, # Seeding is done as admin
|
||||
name=persona.name,
|
||||
description=persona.description,
|
||||
num_chunks=(
|
||||
persona.num_chunks if persona.num_chunks is not None else 0.0
|
||||
),
|
||||
llm_relevance_filter=persona.llm_relevance_filter,
|
||||
llm_filter_extraction=persona.llm_filter_extraction,
|
||||
recency_bias=RecencyBiasSetting.AUTO,
|
||||
prompt_ids=persona.prompt_ids,
|
||||
document_set_ids=persona.document_set_ids,
|
||||
llm_model_provider_override=persona.llm_model_provider_override,
|
||||
llm_model_version_override=persona.llm_model_version_override,
|
||||
starter_messages=persona.starter_messages,
|
||||
is_public=persona.is_public,
|
||||
db_session=db_session,
|
||||
tool_ids=persona.tool_ids,
|
||||
display_priority=persona.display_priority,
|
||||
)
|
||||
db_session.commit()
|
||||
except Exception:
|
||||
logger.exception("Failed to seed personas.")
|
||||
raise
|
||||
|
||||
|
||||
def _seed_settings(settings: Settings) -> None:
|
||||
|
||||
@@ -1,34 +1,5 @@
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
from shared_configs.enums import EmbedTextType
|
||||
|
||||
|
||||
MODEL_WARM_UP_STRING = "hi " * 512
|
||||
INFORMATION_CONTENT_MODEL_WARM_UP_STRING = "hi " * 16
|
||||
DEFAULT_OPENAI_MODEL = "text-embedding-3-small"
|
||||
DEFAULT_COHERE_MODEL = "embed-english-light-v3.0"
|
||||
DEFAULT_VOYAGE_MODEL = "voyage-large-2-instruct"
|
||||
DEFAULT_VERTEX_MODEL = "text-embedding-005"
|
||||
|
||||
|
||||
class EmbeddingModelTextType:
|
||||
PROVIDER_TEXT_TYPE_MAP = {
|
||||
EmbeddingProvider.COHERE: {
|
||||
EmbedTextType.QUERY: "search_query",
|
||||
EmbedTextType.PASSAGE: "search_document",
|
||||
},
|
||||
EmbeddingProvider.VOYAGE: {
|
||||
EmbedTextType.QUERY: "query",
|
||||
EmbedTextType.PASSAGE: "document",
|
||||
},
|
||||
EmbeddingProvider.GOOGLE: {
|
||||
EmbedTextType.QUERY: "RETRIEVAL_QUERY",
|
||||
EmbedTextType.PASSAGE: "RETRIEVAL_DOCUMENT",
|
||||
},
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_type(provider: EmbeddingProvider, text_type: EmbedTextType) -> str:
|
||||
return EmbeddingModelTextType.PROVIDER_TEXT_TYPE_MAP[provider][text_type]
|
||||
|
||||
|
||||
class GPUStatus:
|
||||
|
||||
@@ -1,55 +1,30 @@
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from types import TracebackType
|
||||
from typing import cast
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
|
||||
import aioboto3 # type: ignore
|
||||
import httpx
|
||||
import openai
|
||||
import vertexai # type: ignore
|
||||
import voyageai # type: ignore
|
||||
from cohere import AsyncClient as CohereAsyncClient
|
||||
from fastapi import APIRouter
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
from google.oauth2 import service_account # type: ignore
|
||||
from litellm import aembedding
|
||||
from litellm.exceptions import RateLimitError
|
||||
from retry import retry
|
||||
from sentence_transformers import CrossEncoder # type: ignore
|
||||
from sentence_transformers import SentenceTransformer # type: ignore
|
||||
from vertexai.language_models import TextEmbeddingInput # type: ignore
|
||||
from vertexai.language_models import TextEmbeddingModel # type: ignore
|
||||
|
||||
from model_server.constants import DEFAULT_COHERE_MODEL
|
||||
from model_server.constants import DEFAULT_OPENAI_MODEL
|
||||
from model_server.constants import DEFAULT_VERTEX_MODEL
|
||||
from model_server.constants import DEFAULT_VOYAGE_MODEL
|
||||
from model_server.constants import EmbeddingModelTextType
|
||||
from model_server.constants import EmbeddingProvider
|
||||
from model_server.utils import pass_aws_key
|
||||
from model_server.utils import simple_log_function_time
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import API_BASED_EMBEDDING_TIMEOUT
|
||||
from shared_configs.configs import INDEXING_ONLY
|
||||
from shared_configs.configs import OPENAI_EMBEDDING_TIMEOUT
|
||||
from shared_configs.configs import VERTEXAI_EMBEDDING_LOCAL_BATCH_SIZE
|
||||
from shared_configs.enums import EmbedTextType
|
||||
from shared_configs.enums import RerankerProvider
|
||||
from shared_configs.model_server_models import Embedding
|
||||
from shared_configs.model_server_models import EmbedRequest
|
||||
from shared_configs.model_server_models import EmbedResponse
|
||||
from shared_configs.model_server_models import RerankRequest
|
||||
from shared_configs.model_server_models import RerankResponse
|
||||
from shared_configs.utils import batch_list
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/encoder")
|
||||
|
||||
|
||||
_GLOBAL_MODELS_DICT: dict[str, "SentenceTransformer"] = {}
|
||||
_RERANK_MODEL: Optional["CrossEncoder"] = None
|
||||
|
||||
@@ -57,338 +32,56 @@ _RERANK_MODEL: Optional["CrossEncoder"] = None
|
||||
_RETRY_DELAY = 10 if INDEXING_ONLY else 0.1
|
||||
_RETRY_TRIES = 10 if INDEXING_ONLY else 2
|
||||
|
||||
# OpenAI only allows 2048 embeddings to be computed at once
|
||||
_OPENAI_MAX_INPUT_LEN = 2048
|
||||
# Cohere allows up to 96 embeddings in a single embedding calling
|
||||
_COHERE_MAX_INPUT_LEN = 96
|
||||
|
||||
# Authentication error string constants
|
||||
_AUTH_ERROR_401 = "401"
|
||||
_AUTH_ERROR_UNAUTHORIZED = "unauthorized"
|
||||
_AUTH_ERROR_INVALID_API_KEY = "invalid api key"
|
||||
_AUTH_ERROR_PERMISSION = "permission"
|
||||
|
||||
|
||||
def is_authentication_error(error: Exception) -> bool:
|
||||
"""Check if an exception is related to authentication issues.
|
||||
|
||||
Args:
|
||||
error: The exception to check
|
||||
|
||||
Returns:
|
||||
bool: True if the error appears to be authentication-related
|
||||
"""
|
||||
error_str = str(error).lower()
|
||||
return (
|
||||
_AUTH_ERROR_401 in error_str
|
||||
or _AUTH_ERROR_UNAUTHORIZED in error_str
|
||||
or _AUTH_ERROR_INVALID_API_KEY in error_str
|
||||
or _AUTH_ERROR_PERMISSION in error_str
|
||||
)
|
||||
|
||||
|
||||
def format_embedding_error(
|
||||
error: Exception,
|
||||
service_name: str,
|
||||
model: str | None,
|
||||
provider: EmbeddingProvider,
|
||||
sanitized_api_key: str | None = None,
|
||||
status_code: int | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Format a standardized error string for embedding errors.
|
||||
"""
|
||||
detail = f"Status {status_code}" if status_code else f"{type(error)}"
|
||||
|
||||
return (
|
||||
f"{'HTTP error' if status_code else 'Exception'} embedding text with {service_name} - {detail}: "
|
||||
f"Model: {model} "
|
||||
f"Provider: {provider} "
|
||||
f"API Key: {sanitized_api_key} "
|
||||
f"Exception: {error}"
|
||||
)
|
||||
|
||||
|
||||
# Custom exception for authentication errors
|
||||
class AuthenticationError(Exception):
|
||||
"""Raised when authentication fails with a provider."""
|
||||
|
||||
def __init__(self, provider: str, message: str = "API key is invalid or expired"):
|
||||
self.provider = provider
|
||||
self.message = message
|
||||
super().__init__(f"{provider} authentication failed: {message}")
|
||||
|
||||
|
||||
class CloudEmbedding:
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
provider: EmbeddingProvider,
|
||||
api_url: str | None = None,
|
||||
api_version: str | None = None,
|
||||
timeout: int = API_BASED_EMBEDDING_TIMEOUT,
|
||||
) -> None:
|
||||
self.provider = provider
|
||||
self.api_key = api_key
|
||||
self.api_url = api_url
|
||||
self.api_version = api_version
|
||||
self.timeout = timeout
|
||||
self.http_client = httpx.AsyncClient(timeout=timeout)
|
||||
self._closed = False
|
||||
self.sanitized_api_key = api_key[:4] + "********" + api_key[-4:]
|
||||
|
||||
async def _embed_openai(
|
||||
self, texts: list[str], model: str | None, reduced_dimension: int | None
|
||||
) -> list[Embedding]:
|
||||
if not model:
|
||||
model = DEFAULT_OPENAI_MODEL
|
||||
|
||||
# Use the OpenAI specific timeout for this one
|
||||
client = openai.AsyncOpenAI(
|
||||
api_key=self.api_key, timeout=OPENAI_EMBEDDING_TIMEOUT
|
||||
)
|
||||
|
||||
final_embeddings: list[Embedding] = []
|
||||
|
||||
for text_batch in batch_list(texts, _OPENAI_MAX_INPUT_LEN):
|
||||
response = await client.embeddings.create(
|
||||
input=text_batch,
|
||||
model=model,
|
||||
dimensions=reduced_dimension or openai.NOT_GIVEN,
|
||||
)
|
||||
final_embeddings.extend(
|
||||
[embedding.embedding for embedding in response.data]
|
||||
)
|
||||
return final_embeddings
|
||||
|
||||
async def _embed_cohere(
|
||||
self, texts: list[str], model: str | None, embedding_type: str
|
||||
) -> list[Embedding]:
|
||||
if not model:
|
||||
model = DEFAULT_COHERE_MODEL
|
||||
|
||||
client = CohereAsyncClient(api_key=self.api_key)
|
||||
|
||||
final_embeddings: list[Embedding] = []
|
||||
for text_batch in batch_list(texts, _COHERE_MAX_INPUT_LEN):
|
||||
# Does not use the same tokenizer as the Onyx API server but it's approximately the same
|
||||
# empirically it's only off by a very few tokens so it's not a big deal
|
||||
response = await client.embed(
|
||||
texts=text_batch,
|
||||
model=model,
|
||||
input_type=embedding_type,
|
||||
truncate="END",
|
||||
)
|
||||
final_embeddings.extend(cast(list[Embedding], response.embeddings))
|
||||
return final_embeddings
|
||||
|
||||
async def _embed_voyage(
|
||||
self, texts: list[str], model: str | None, embedding_type: str
|
||||
) -> list[Embedding]:
|
||||
if not model:
|
||||
model = DEFAULT_VOYAGE_MODEL
|
||||
|
||||
client = voyageai.AsyncClient(
|
||||
api_key=self.api_key, timeout=API_BASED_EMBEDDING_TIMEOUT
|
||||
)
|
||||
|
||||
response = await client.embed(
|
||||
texts=texts,
|
||||
model=model,
|
||||
input_type=embedding_type,
|
||||
truncation=True,
|
||||
)
|
||||
return response.embeddings
|
||||
|
||||
async def _embed_azure(
|
||||
self, texts: list[str], model: str | None
|
||||
) -> list[Embedding]:
|
||||
response = await aembedding(
|
||||
model=model,
|
||||
input=texts,
|
||||
timeout=API_BASED_EMBEDDING_TIMEOUT,
|
||||
api_key=self.api_key,
|
||||
api_base=self.api_url,
|
||||
api_version=self.api_version,
|
||||
)
|
||||
embeddings = [embedding["embedding"] for embedding in response.data]
|
||||
return embeddings
|
||||
|
||||
async def _embed_vertex(
|
||||
self, texts: list[str], model: str | None, embedding_type: str
|
||||
) -> list[Embedding]:
|
||||
if not model:
|
||||
model = DEFAULT_VERTEX_MODEL
|
||||
|
||||
credentials = service_account.Credentials.from_service_account_info(
|
||||
json.loads(self.api_key)
|
||||
)
|
||||
project_id = json.loads(self.api_key)["project_id"]
|
||||
vertexai.init(project=project_id, credentials=credentials)
|
||||
client = TextEmbeddingModel.from_pretrained(model)
|
||||
|
||||
inputs = [TextEmbeddingInput(text, embedding_type) for text in texts]
|
||||
|
||||
# Split into batches of 25 texts
|
||||
max_texts_per_batch = VERTEXAI_EMBEDDING_LOCAL_BATCH_SIZE
|
||||
batches = [
|
||||
inputs[i : i + max_texts_per_batch]
|
||||
for i in range(0, len(inputs), max_texts_per_batch)
|
||||
]
|
||||
|
||||
# Dispatch all embedding calls asynchronously at once
|
||||
tasks = [
|
||||
client.get_embeddings_async(batch, auto_truncate=True) for batch in batches
|
||||
]
|
||||
|
||||
# Wait for all tasks to complete in parallel
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
return [embedding.values for batch in results for embedding in batch]
|
||||
|
||||
async def _embed_litellm_proxy(
|
||||
self, texts: list[str], model_name: str | None
|
||||
) -> list[Embedding]:
|
||||
if not model_name:
|
||||
raise ValueError("Model name is required for LiteLLM proxy embedding.")
|
||||
|
||||
if not self.api_url:
|
||||
raise ValueError("API URL is required for LiteLLM proxy embedding.")
|
||||
|
||||
headers = (
|
||||
{} if not self.api_key else {"Authorization": f"Bearer {self.api_key}"}
|
||||
)
|
||||
|
||||
response = await self.http_client.post(
|
||||
self.api_url,
|
||||
json={
|
||||
"model": model_name,
|
||||
"input": texts,
|
||||
},
|
||||
headers=headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
return [embedding["embedding"] for embedding in result["data"]]
|
||||
|
||||
@retry(tries=_RETRY_TRIES, delay=_RETRY_DELAY)
|
||||
async def embed(
|
||||
self,
|
||||
*,
|
||||
texts: list[str],
|
||||
text_type: EmbedTextType,
|
||||
model_name: str | None = None,
|
||||
deployment_name: str | None = None,
|
||||
reduced_dimension: int | None = None,
|
||||
) -> list[Embedding]:
|
||||
try:
|
||||
if self.provider == EmbeddingProvider.OPENAI:
|
||||
return await self._embed_openai(texts, model_name, reduced_dimension)
|
||||
elif self.provider == EmbeddingProvider.AZURE:
|
||||
return await self._embed_azure(texts, f"azure/{deployment_name}")
|
||||
elif self.provider == EmbeddingProvider.LITELLM:
|
||||
return await self._embed_litellm_proxy(texts, model_name)
|
||||
|
||||
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
|
||||
if self.provider == EmbeddingProvider.COHERE:
|
||||
return await self._embed_cohere(texts, model_name, embedding_type)
|
||||
elif self.provider == EmbeddingProvider.VOYAGE:
|
||||
return await self._embed_voyage(texts, model_name, embedding_type)
|
||||
elif self.provider == EmbeddingProvider.GOOGLE:
|
||||
return await self._embed_vertex(texts, model_name, embedding_type)
|
||||
else:
|
||||
raise ValueError(f"Unsupported provider: {self.provider}")
|
||||
except openai.AuthenticationError:
|
||||
raise AuthenticationError(provider="OpenAI")
|
||||
except httpx.HTTPStatusError as e:
|
||||
if e.response.status_code == 401:
|
||||
raise AuthenticationError(provider=str(self.provider))
|
||||
|
||||
error_string = format_embedding_error(
|
||||
e,
|
||||
str(self.provider),
|
||||
model_name or deployment_name,
|
||||
self.provider,
|
||||
sanitized_api_key=self.sanitized_api_key,
|
||||
status_code=e.response.status_code,
|
||||
)
|
||||
logger.error(error_string)
|
||||
logger.debug(f"Exception texts: {texts}")
|
||||
|
||||
raise RuntimeError(error_string)
|
||||
except Exception as e:
|
||||
if is_authentication_error(e):
|
||||
raise AuthenticationError(provider=str(self.provider))
|
||||
|
||||
error_string = format_embedding_error(
|
||||
e,
|
||||
str(self.provider),
|
||||
model_name or deployment_name,
|
||||
self.provider,
|
||||
sanitized_api_key=self.sanitized_api_key,
|
||||
)
|
||||
logger.error(error_string)
|
||||
logger.debug(f"Exception texts: {texts}")
|
||||
|
||||
raise RuntimeError(error_string)
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
api_key: str,
|
||||
provider: EmbeddingProvider,
|
||||
api_url: str | None = None,
|
||||
api_version: str | None = None,
|
||||
) -> "CloudEmbedding":
|
||||
logger.debug(f"Creating Embedding instance for provider: {provider}")
|
||||
return CloudEmbedding(api_key, provider, api_url, api_version)
|
||||
|
||||
async def aclose(self) -> None:
|
||||
"""Explicitly close the client."""
|
||||
if not self._closed:
|
||||
await self.http_client.aclose()
|
||||
self._closed = True
|
||||
|
||||
async def __aenter__(self) -> "CloudEmbedding":
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_val: BaseException | None,
|
||||
exc_tb: TracebackType | None,
|
||||
) -> None:
|
||||
await self.aclose()
|
||||
|
||||
def __del__(self) -> None:
|
||||
"""Finalizer to warn about unclosed clients."""
|
||||
if not self._closed:
|
||||
logger.warning(
|
||||
"CloudEmbedding was not properly closed. Use 'async with' or call aclose()"
|
||||
)
|
||||
|
||||
|
||||
def get_embedding_model(
|
||||
model_name: str,
|
||||
max_context_length: int,
|
||||
) -> "SentenceTransformer":
|
||||
"""
|
||||
Loads or returns a cached SentenceTransformer, sets max_seq_length, pins device,
|
||||
pre-warms rotary caches once, and wraps encode() with a lock to avoid cache races.
|
||||
"""
|
||||
from sentence_transformers import SentenceTransformer # type: ignore
|
||||
|
||||
global _GLOBAL_MODELS_DICT # A dictionary to store models
|
||||
def _prewarm_rope(st_model: "SentenceTransformer", target_len: int) -> None:
|
||||
"""
|
||||
Build RoPE cos/sin caches once on the final device/dtype so later forwards only read.
|
||||
Works by calling the underlying HF model directly with dummy IDs/attention.
|
||||
"""
|
||||
try:
|
||||
# ensure > max seq after tokenization
|
||||
# Ideally we would use the saved tokenizer, but whatever it's ok
|
||||
# we'll make an assumption about tokenization here
|
||||
long_text = "x " * (target_len * 2)
|
||||
_ = st_model.encode(
|
||||
[long_text],
|
||||
batch_size=1,
|
||||
convert_to_tensor=True,
|
||||
show_progress_bar=False,
|
||||
normalize_embeddings=False,
|
||||
)
|
||||
logger.info("RoPE pre-warm successful")
|
||||
except Exception as e:
|
||||
logger.warning(f"RoPE pre-warm skipped/failed: {e}")
|
||||
|
||||
global _GLOBAL_MODELS_DICT
|
||||
|
||||
if model_name not in _GLOBAL_MODELS_DICT:
|
||||
logger.notice(f"Loading {model_name}")
|
||||
# Some model architectures that aren't built into the Transformers or Sentence
|
||||
# Transformer need to be downloaded to be loaded locally. This does not mean
|
||||
# data is sent to remote servers for inference, however the remote code can
|
||||
# be fairly arbitrary so only use trusted models
|
||||
model = SentenceTransformer(
|
||||
model_name_or_path=model_name,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
model.max_seq_length = max_context_length
|
||||
_prewarm_rope(model, max_context_length)
|
||||
_GLOBAL_MODELS_DICT[model_name] = model
|
||||
elif max_context_length != _GLOBAL_MODELS_DICT[model_name].max_seq_length:
|
||||
_GLOBAL_MODELS_DICT[model_name].max_seq_length = max_context_length
|
||||
else:
|
||||
model = _GLOBAL_MODELS_DICT[model_name]
|
||||
if max_context_length != model.max_seq_length:
|
||||
model.max_seq_length = max_context_length
|
||||
prev = getattr(model, "_rope_prewarmed_to", 0)
|
||||
if max_context_length > int(prev or 0):
|
||||
_prewarm_rope(model, max_context_length)
|
||||
|
||||
return _GLOBAL_MODELS_DICT[model_name]
|
||||
|
||||
@@ -404,20 +97,34 @@ def get_local_reranking_model(
|
||||
return _RERANK_MODEL
|
||||
|
||||
|
||||
ENCODING_RETRIES = 3
|
||||
ENCODING_RETRY_DELAY = 0.1
|
||||
|
||||
|
||||
def _concurrent_embedding(
|
||||
texts: list[str], model: "SentenceTransformer", normalize_embeddings: bool
|
||||
) -> Any:
|
||||
"""Synchronous wrapper for concurrent_embedding to use with run_in_executor."""
|
||||
for _ in range(ENCODING_RETRIES):
|
||||
try:
|
||||
return model.encode(texts, normalize_embeddings=normalize_embeddings)
|
||||
except RuntimeError as e:
|
||||
# There is a concurrency bug in the SentenceTransformer library that causes
|
||||
# the model to fail to encode texts. It's pretty rare and we want to allow
|
||||
# concurrent embedding, hence we retry (the specific error is
|
||||
# "RuntimeError: Already borrowed" and occurs in the transformers library)
|
||||
logger.error(f"Error encoding texts, retrying: {e}")
|
||||
time.sleep(ENCODING_RETRY_DELAY)
|
||||
return model.encode(texts, normalize_embeddings=normalize_embeddings)
|
||||
|
||||
|
||||
@simple_log_function_time()
|
||||
async def embed_text(
|
||||
texts: list[str],
|
||||
text_type: EmbedTextType,
|
||||
model_name: str | None,
|
||||
deployment_name: str | None,
|
||||
max_context_length: int,
|
||||
normalize_embeddings: bool,
|
||||
api_key: str | None,
|
||||
provider_type: EmbeddingProvider | None,
|
||||
prefix: str | None,
|
||||
api_url: str | None,
|
||||
api_version: str | None,
|
||||
reduced_dimension: int | None,
|
||||
gpu_type: str = "UNKNOWN",
|
||||
) -> list[Embedding]:
|
||||
if not all(texts):
|
||||
@@ -434,52 +141,10 @@ async def embed_text(
|
||||
for text in texts:
|
||||
total_chars += len(text)
|
||||
|
||||
if provider_type is not None:
|
||||
logger.info(
|
||||
f"Embedding {len(texts)} texts with {total_chars} total characters with provider: {provider_type}"
|
||||
)
|
||||
# Only local models should call this function now
|
||||
# API providers should go directly to API server
|
||||
|
||||
if api_key is None:
|
||||
logger.error("API key not provided for cloud model")
|
||||
raise RuntimeError("API key not provided for cloud model")
|
||||
|
||||
if prefix:
|
||||
logger.warning("Prefix provided for cloud model, which is not supported")
|
||||
raise ValueError(
|
||||
"Prefix string is not valid for cloud models. "
|
||||
"Cloud models take an explicit text type instead."
|
||||
)
|
||||
|
||||
async with CloudEmbedding(
|
||||
api_key=api_key,
|
||||
provider=provider_type,
|
||||
api_url=api_url,
|
||||
api_version=api_version,
|
||||
) as cloud_model:
|
||||
embeddings = await cloud_model.embed(
|
||||
texts=texts,
|
||||
model_name=model_name,
|
||||
deployment_name=deployment_name,
|
||||
text_type=text_type,
|
||||
reduced_dimension=reduced_dimension,
|
||||
)
|
||||
|
||||
if any(embedding is None for embedding in embeddings):
|
||||
error_message = "Embeddings contain None values\n"
|
||||
error_message += "Corresponding texts:\n"
|
||||
error_message += "\n".join(texts)
|
||||
logger.error(error_message)
|
||||
raise ValueError(error_message)
|
||||
|
||||
elapsed = time.monotonic() - start
|
||||
logger.info(
|
||||
f"event=embedding_provider "
|
||||
f"texts={len(texts)} "
|
||||
f"chars={total_chars} "
|
||||
f"provider={provider_type} "
|
||||
f"elapsed={elapsed:.2f}"
|
||||
)
|
||||
elif model_name is not None:
|
||||
if model_name is not None:
|
||||
logger.info(
|
||||
f"Embedding {len(texts)} texts with {total_chars} total characters with local model: {model_name}"
|
||||
)
|
||||
@@ -492,8 +157,8 @@ async def embed_text(
|
||||
# Run CPU-bound embedding in a thread pool
|
||||
embeddings_vectors = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: local_model.encode(
|
||||
prefixed_texts, normalize_embeddings=normalize_embeddings
|
||||
lambda: _concurrent_embedding(
|
||||
prefixed_texts, local_model, normalize_embeddings
|
||||
),
|
||||
)
|
||||
embeddings = [
|
||||
@@ -515,10 +180,8 @@ async def embed_text(
|
||||
f"elapsed={elapsed:.2f}"
|
||||
)
|
||||
else:
|
||||
logger.error("Neither model name nor provider specified for embedding")
|
||||
raise ValueError(
|
||||
"Either model name or provider must be provided to run embeddings."
|
||||
)
|
||||
logger.error("Model name not specified for embedding")
|
||||
raise ValueError("Model name must be provided to run embeddings.")
|
||||
|
||||
return embeddings
|
||||
|
||||
@@ -533,77 +196,6 @@ async def local_rerank(query: str, docs: list[str], model_name: str) -> list[flo
|
||||
)
|
||||
|
||||
|
||||
async def cohere_rerank_api(
|
||||
query: str, docs: list[str], model_name: str, api_key: str
|
||||
) -> list[float]:
|
||||
cohere_client = CohereAsyncClient(api_key=api_key)
|
||||
response = await cohere_client.rerank(query=query, documents=docs, model=model_name)
|
||||
results = response.results
|
||||
sorted_results = sorted(results, key=lambda item: item.index)
|
||||
return [result.relevance_score for result in sorted_results]
|
||||
|
||||
|
||||
async def cohere_rerank_aws(
|
||||
query: str,
|
||||
docs: list[str],
|
||||
model_name: str,
|
||||
region_name: str,
|
||||
aws_access_key_id: str,
|
||||
aws_secret_access_key: str,
|
||||
) -> list[float]:
|
||||
session = aioboto3.Session(
|
||||
aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key
|
||||
)
|
||||
async with session.client(
|
||||
"bedrock-runtime", region_name=region_name
|
||||
) as bedrock_client:
|
||||
body = json.dumps(
|
||||
{
|
||||
"query": query,
|
||||
"documents": docs,
|
||||
"api_version": 2,
|
||||
}
|
||||
)
|
||||
# Invoke the Bedrock model asynchronously
|
||||
response = await bedrock_client.invoke_model(
|
||||
modelId=model_name,
|
||||
accept="application/json",
|
||||
contentType="application/json",
|
||||
body=body,
|
||||
)
|
||||
|
||||
# Read the response asynchronously
|
||||
response_body = json.loads(await response["body"].read())
|
||||
|
||||
# Extract and sort the results
|
||||
results = response_body.get("results", [])
|
||||
sorted_results = sorted(results, key=lambda item: item["index"])
|
||||
|
||||
return [result["relevance_score"] for result in sorted_results]
|
||||
|
||||
|
||||
async def litellm_rerank(
|
||||
query: str, docs: list[str], api_url: str, model_name: str, api_key: str | None
|
||||
) -> list[float]:
|
||||
headers = {} if not api_key else {"Authorization": f"Bearer {api_key}"}
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
api_url,
|
||||
json={
|
||||
"model": model_name,
|
||||
"query": query,
|
||||
"documents": docs,
|
||||
},
|
||||
headers=headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
return [
|
||||
item["relevance_score"]
|
||||
for item in sorted(result["results"], key=lambda x: x["index"])
|
||||
]
|
||||
|
||||
|
||||
@router.post("/bi-encoder-embed")
|
||||
async def route_bi_encoder_embed(
|
||||
request: Request,
|
||||
@@ -615,6 +207,13 @@ async def route_bi_encoder_embed(
|
||||
async def process_embed_request(
|
||||
embed_request: EmbedRequest, gpu_type: str = "UNKNOWN"
|
||||
) -> EmbedResponse:
|
||||
# Only local models should use this endpoint - API providers should make direct API calls
|
||||
if embed_request.provider_type is not None:
|
||||
raise ValueError(
|
||||
f"Model server embedding endpoint should only be used for local models. "
|
||||
f"API provider '{embed_request.provider_type}' should make direct API calls instead."
|
||||
)
|
||||
|
||||
if not embed_request.texts:
|
||||
raise HTTPException(status_code=400, detail="No texts to be embedded")
|
||||
|
||||
@@ -632,26 +231,12 @@ async def process_embed_request(
|
||||
embeddings = await embed_text(
|
||||
texts=embed_request.texts,
|
||||
model_name=embed_request.model_name,
|
||||
deployment_name=embed_request.deployment_name,
|
||||
max_context_length=embed_request.max_context_length,
|
||||
normalize_embeddings=embed_request.normalize_embeddings,
|
||||
api_key=embed_request.api_key,
|
||||
provider_type=embed_request.provider_type,
|
||||
text_type=embed_request.text_type,
|
||||
api_url=embed_request.api_url,
|
||||
api_version=embed_request.api_version,
|
||||
reduced_dimension=embed_request.reduced_dimension,
|
||||
prefix=prefix,
|
||||
gpu_type=gpu_type,
|
||||
)
|
||||
return EmbedResponse(embeddings=embeddings)
|
||||
except AuthenticationError as e:
|
||||
# Handle authentication errors consistently
|
||||
logger.error(f"Authentication error: {e.provider}")
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail=f"Authentication failed: {e.message}",
|
||||
)
|
||||
except RateLimitError as e:
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
@@ -669,6 +254,13 @@ async def process_embed_request(
|
||||
@router.post("/cross-encoder-scores")
|
||||
async def process_rerank_request(rerank_request: RerankRequest) -> RerankResponse:
|
||||
"""Cross encoders can be purely black box from the app perspective"""
|
||||
# Only local models should use this endpoint - API providers should make direct API calls
|
||||
if rerank_request.provider_type is not None:
|
||||
raise ValueError(
|
||||
f"Model server reranking endpoint should only be used for local models. "
|
||||
f"API provider '{rerank_request.provider_type}' should make direct API calls instead."
|
||||
)
|
||||
|
||||
if INDEXING_ONLY:
|
||||
raise RuntimeError("Indexing model server should not call intent endpoint")
|
||||
|
||||
@@ -680,55 +272,13 @@ async def process_rerank_request(rerank_request: RerankRequest) -> RerankRespons
|
||||
raise ValueError("Empty documents cannot be reranked.")
|
||||
|
||||
try:
|
||||
if rerank_request.provider_type is None:
|
||||
sim_scores = await local_rerank(
|
||||
query=rerank_request.query,
|
||||
docs=rerank_request.documents,
|
||||
model_name=rerank_request.model_name,
|
||||
)
|
||||
return RerankResponse(scores=sim_scores)
|
||||
elif rerank_request.provider_type == RerankerProvider.LITELLM:
|
||||
if rerank_request.api_url is None:
|
||||
raise ValueError("API URL is required for LiteLLM reranking.")
|
||||
|
||||
sim_scores = await litellm_rerank(
|
||||
query=rerank_request.query,
|
||||
docs=rerank_request.documents,
|
||||
api_url=rerank_request.api_url,
|
||||
model_name=rerank_request.model_name,
|
||||
api_key=rerank_request.api_key,
|
||||
)
|
||||
|
||||
return RerankResponse(scores=sim_scores)
|
||||
|
||||
elif rerank_request.provider_type == RerankerProvider.COHERE:
|
||||
if rerank_request.api_key is None:
|
||||
raise RuntimeError("Cohere Rerank Requires an API Key")
|
||||
sim_scores = await cohere_rerank_api(
|
||||
query=rerank_request.query,
|
||||
docs=rerank_request.documents,
|
||||
model_name=rerank_request.model_name,
|
||||
api_key=rerank_request.api_key,
|
||||
)
|
||||
return RerankResponse(scores=sim_scores)
|
||||
|
||||
elif rerank_request.provider_type == RerankerProvider.BEDROCK:
|
||||
if rerank_request.api_key is None:
|
||||
raise RuntimeError("Bedrock Rerank Requires an API Key")
|
||||
aws_access_key_id, aws_secret_access_key, aws_region = pass_aws_key(
|
||||
rerank_request.api_key
|
||||
)
|
||||
sim_scores = await cohere_rerank_aws(
|
||||
query=rerank_request.query,
|
||||
docs=rerank_request.documents,
|
||||
model_name=rerank_request.model_name,
|
||||
region_name=aws_region,
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key,
|
||||
)
|
||||
return RerankResponse(scores=sim_scores)
|
||||
else:
|
||||
raise ValueError(f"Unsupported provider: {rerank_request.provider_type}")
|
||||
# At this point, provider_type is None, so handle local reranking
|
||||
sim_scores = await local_rerank(
|
||||
query=rerank_request.query,
|
||||
docs=rerank_request.documents,
|
||||
model_name=rerank_request.model_name,
|
||||
)
|
||||
return RerankResponse(scores=sim_scores)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Error during reranking process:\n{str(e)}")
|
||||
|
||||
@@ -34,8 +34,8 @@ from shared_configs.configs import SENTRY_DSN
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
|
||||
|
||||
HF_CACHE_PATH = Path(os.path.expanduser("~")) / ".cache/huggingface"
|
||||
TEMP_HF_CACHE_PATH = Path(os.path.expanduser("~")) / ".cache/temp_huggingface"
|
||||
HF_CACHE_PATH = Path(".cache/huggingface")
|
||||
TEMP_HF_CACHE_PATH = Path(".cache/temp_huggingface")
|
||||
|
||||
transformer_logging.set_verbosity_error()
|
||||
|
||||
|
||||
@@ -70,32 +70,3 @@ def get_gpu_type() -> str:
|
||||
return GPUStatus.MAC_MPS
|
||||
|
||||
return GPUStatus.NONE
|
||||
|
||||
|
||||
def pass_aws_key(api_key: str) -> tuple[str, str, str]:
|
||||
"""Parse AWS API key string into components.
|
||||
|
||||
Args:
|
||||
api_key: String in format 'aws_ACCESSKEY_SECRETKEY_REGION'
|
||||
|
||||
Returns:
|
||||
Tuple of (access_key, secret_key, region)
|
||||
|
||||
Raises:
|
||||
ValueError: If key format is invalid
|
||||
"""
|
||||
if not api_key.startswith("aws"):
|
||||
raise ValueError("API key must start with 'aws' prefix")
|
||||
|
||||
parts = api_key.split("_")
|
||||
if len(parts) != 4:
|
||||
raise ValueError(
|
||||
f"API key must be in format 'aws_ACCESSKEY_SECRETKEY_REGION', got {len(parts) - 1} parts"
|
||||
"this is an onyx specific format for formatting the aws secrets for bedrock"
|
||||
)
|
||||
|
||||
try:
|
||||
_, aws_access_key_id, aws_secret_access_key, aws_region = parts
|
||||
return aws_access_key_id, aws_secret_access_key, aws_region
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to parse AWS key components: {str(e)}")
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from collections.abc import Callable
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.access.models import DocumentAccess
|
||||
@@ -10,6 +11,7 @@ from onyx.configs.constants import PUBLIC_DOC_PAT
|
||||
from onyx.db.document import get_access_info_for_document
|
||||
from onyx.db.document import get_access_info_for_documents
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
|
||||
@@ -124,3 +126,25 @@ def source_should_fetch_permissions_during_indexing(source: DocumentSource) -> b
|
||||
),
|
||||
)
|
||||
return _source_should_fetch_permissions_during_indexing_func(source)
|
||||
|
||||
|
||||
def get_access_for_user_files(
|
||||
user_file_ids: list[str],
|
||||
db_session: Session,
|
||||
) -> dict[str, DocumentAccess]:
|
||||
user_files = (
|
||||
db_session.query(UserFile)
|
||||
.options(joinedload(UserFile.user)) # Eager load the user relationship
|
||||
.filter(UserFile.id.in_(user_file_ids))
|
||||
.all()
|
||||
)
|
||||
return {
|
||||
str(user_file.id): DocumentAccess.build(
|
||||
user_emails=[user_file.user.email] if user_file.user else [],
|
||||
user_groups=[],
|
||||
is_public=True if user_file.user is None else False,
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
)
|
||||
for user_file in user_files
|
||||
}
|
||||
|
||||
@@ -1,97 +0,0 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from onyx.agents.agent_search.basic.states import BasicInput
|
||||
from onyx.agents.agent_search.basic.states import BasicOutput
|
||||
from onyx.agents.agent_search.basic.states import BasicState
|
||||
from onyx.agents.agent_search.orchestration.nodes.call_tool import call_tool
|
||||
from onyx.agents.agent_search.orchestration.nodes.choose_tool import choose_tool
|
||||
from onyx.agents.agent_search.orchestration.nodes.prepare_tool_input import (
|
||||
prepare_tool_input,
|
||||
)
|
||||
from onyx.agents.agent_search.orchestration.nodes.use_tool_response import (
|
||||
basic_use_tool_response,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def basic_graph_builder() -> StateGraph:
|
||||
graph = StateGraph(
|
||||
state_schema=BasicState,
|
||||
input=BasicInput,
|
||||
output=BasicOutput,
|
||||
)
|
||||
|
||||
### Add nodes ###
|
||||
|
||||
graph.add_node(
|
||||
node="prepare_tool_input",
|
||||
action=prepare_tool_input,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
node="choose_tool",
|
||||
action=choose_tool,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
node="call_tool",
|
||||
action=call_tool,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
node="basic_use_tool_response",
|
||||
action=basic_use_tool_response,
|
||||
)
|
||||
|
||||
### Add edges ###
|
||||
|
||||
graph.add_edge(start_key=START, end_key="prepare_tool_input")
|
||||
|
||||
graph.add_edge(start_key="prepare_tool_input", end_key="choose_tool")
|
||||
|
||||
graph.add_conditional_edges("choose_tool", should_continue, ["call_tool", END])
|
||||
|
||||
graph.add_edge(
|
||||
start_key="call_tool",
|
||||
end_key="basic_use_tool_response",
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key="basic_use_tool_response",
|
||||
end_key=END,
|
||||
)
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
def should_continue(state: BasicState) -> str:
|
||||
return (
|
||||
# If there are no tool calls, basic graph already streamed the answer
|
||||
END
|
||||
if state.tool_choice is None
|
||||
else "call_tool"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.context.search.models import SearchRequest
|
||||
from onyx.llm.factory import get_default_llms
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
|
||||
|
||||
graph = basic_graph_builder()
|
||||
compiled_graph = graph.compile()
|
||||
input = BasicInput(unused=True)
|
||||
primary_llm, fast_llm = get_default_llms()
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
config, _ = get_test_config(
|
||||
db_session=db_session,
|
||||
primary_llm=primary_llm,
|
||||
fast_llm=fast_llm,
|
||||
search_request=SearchRequest(query="How does onyx use FastAPI?"),
|
||||
)
|
||||
compiled_graph.invoke(input, config={"metadata": {"config": config}})
|
||||
@@ -1,35 +0,0 @@
|
||||
from typing import TypedDict
|
||||
|
||||
from langchain_core.messages import AIMessageChunk
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
|
||||
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
|
||||
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
|
||||
|
||||
# States contain values that change over the course of graph execution,
|
||||
# Config is for values that are set at the start and never change.
|
||||
# If you are using a value from the config and realize it needs to change,
|
||||
# you should add it to the state and use/update the version in the state.
|
||||
|
||||
|
||||
## Graph Input State
|
||||
class BasicInput(BaseModel):
|
||||
# Langgraph needs a nonempty input, but we pass in all static
|
||||
# data through a RunnableConfig.
|
||||
unused: bool = True
|
||||
|
||||
|
||||
## Graph Output State
|
||||
class BasicOutput(TypedDict):
|
||||
tool_call_chunk: AIMessageChunk
|
||||
|
||||
|
||||
## Graph State
|
||||
class BasicState(
|
||||
BasicInput,
|
||||
ToolChoiceInput,
|
||||
ToolCallUpdate,
|
||||
ToolChoiceUpdate,
|
||||
):
|
||||
pass
|
||||
@@ -1,64 +0,0 @@
|
||||
from collections.abc import Iterator
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import AIMessageChunk
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.chat.models import LlmDoc
|
||||
from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
|
||||
from onyx.chat.stream_processing.answer_response_handler import CitationResponseHandler
|
||||
from onyx.chat.stream_processing.answer_response_handler import (
|
||||
PassThroughAnswerResponseHandler,
|
||||
)
|
||||
from onyx.chat.stream_processing.utils import map_document_id_order
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def process_llm_stream(
|
||||
messages: Iterator[BaseMessage],
|
||||
should_stream_answer: bool,
|
||||
writer: StreamWriter,
|
||||
final_search_results: list[LlmDoc] | None = None,
|
||||
displayed_search_results: list[LlmDoc] | None = None,
|
||||
) -> AIMessageChunk:
|
||||
tool_call_chunk = AIMessageChunk(content="")
|
||||
|
||||
if final_search_results and displayed_search_results:
|
||||
answer_handler: AnswerResponseHandler = CitationResponseHandler(
|
||||
context_docs=final_search_results,
|
||||
final_doc_id_to_rank_map=map_document_id_order(final_search_results),
|
||||
display_doc_id_to_rank_map=map_document_id_order(displayed_search_results),
|
||||
)
|
||||
else:
|
||||
answer_handler = PassThroughAnswerResponseHandler()
|
||||
|
||||
full_answer = ""
|
||||
# This stream will be the llm answer if no tool is chosen. When a tool is chosen,
|
||||
# the stream will contain AIMessageChunks with tool call information.
|
||||
for message in messages:
|
||||
|
||||
answer_piece = message.content
|
||||
if not isinstance(answer_piece, str):
|
||||
# this is only used for logging, so fine to
|
||||
# just add the string representation
|
||||
answer_piece = str(answer_piece)
|
||||
full_answer += answer_piece
|
||||
|
||||
if isinstance(message, AIMessageChunk) and (
|
||||
message.tool_call_chunks or message.tool_calls
|
||||
):
|
||||
tool_call_chunk += message # type: ignore
|
||||
elif should_stream_answer:
|
||||
for response_part in answer_handler.handle_response_part(message, []):
|
||||
write_custom_event(
|
||||
"basic_response",
|
||||
response_part,
|
||||
writer,
|
||||
)
|
||||
|
||||
logger.debug(f"Full answer: {full_answer}")
|
||||
return cast(AIMessageChunk, tool_call_chunk)
|
||||
@@ -10,6 +10,7 @@ class CoreState(BaseModel):
|
||||
"""
|
||||
|
||||
log_messages: Annotated[list[str], add] = []
|
||||
current_step_nr: int = 1
|
||||
|
||||
|
||||
class SubgraphCoreState(BaseModel):
|
||||
|
||||
@@ -14,8 +14,6 @@ from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
trim_prompt_piece,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.prompts.agents.dc_prompts import DC_OBJECT_NO_BASE_DATA_EXTRACTION_PROMPT
|
||||
from onyx.prompts.agents.dc_prompts import DC_OBJECT_SEPARATOR
|
||||
@@ -41,7 +39,7 @@ def search_objects(
|
||||
raise ValueError("Search tool and persona must be provided for DivCon search")
|
||||
|
||||
try:
|
||||
instructions = graph_config.inputs.persona.prompts[0].system_prompt
|
||||
instructions = graph_config.inputs.persona.system_prompt or ""
|
||||
|
||||
agent_1_instructions = extract_section(
|
||||
instructions, "Agent Step 1:", "Agent Step 2:"
|
||||
@@ -139,17 +137,6 @@ def search_objects(
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error in search_objects: {e}")
|
||||
|
||||
write_custom_event(
|
||||
"initial_agent_answer",
|
||||
AgentAnswerPiece(
|
||||
answer_piece=" Researching the individual objects for each source type... ",
|
||||
level=0,
|
||||
level_question_num=0,
|
||||
answer_type="agent_level_answer",
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
return SearchSourcesObjectsUpdate(
|
||||
analysis_objects=object_list,
|
||||
analysis_sources=document_sources,
|
||||
|
||||
@@ -43,7 +43,7 @@ def research_object_source(
|
||||
raise ValueError("Search tool and persona must be provided for DivCon search")
|
||||
|
||||
try:
|
||||
instructions = graph_config.inputs.persona.prompts[0].system_prompt
|
||||
instructions = graph_config.inputs.persona.system_prompt or ""
|
||||
|
||||
agent_2_instructions = extract_section(
|
||||
instructions, "Agent Step 2:", "Agent Step 3:"
|
||||
|
||||
@@ -9,8 +9,6 @@ from onyx.agents.agent_search.dc_search_analysis.states import MainState
|
||||
from onyx.agents.agent_search.dc_search_analysis.states import (
|
||||
ObjectResearchInformationUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -23,17 +21,6 @@ def structure_research_by_object(
|
||||
LangGraph node to start the agentic search process.
|
||||
"""
|
||||
|
||||
write_custom_event(
|
||||
"initial_agent_answer",
|
||||
AgentAnswerPiece(
|
||||
answer_piece=" consolidating the information across source types for each object...",
|
||||
level=0,
|
||||
level_question_num=0,
|
||||
answer_type="agent_level_answer",
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
object_source_research_results = state.object_source_research_results
|
||||
|
||||
object_research_information_results: List[Dict[str, str]] = []
|
||||
|
||||
@@ -33,7 +33,7 @@ def consolidate_object_research(
|
||||
if search_tool is None or graph_config.inputs.persona is None:
|
||||
raise ValueError("Search tool and persona must be provided for DivCon search")
|
||||
|
||||
instructions = graph_config.inputs.persona.prompts[0].system_prompt
|
||||
instructions = graph_config.inputs.persona.system_prompt or ""
|
||||
|
||||
agent_4_instructions = extract_section(
|
||||
instructions, "Agent Step 4:", "Agent Step 5:"
|
||||
|
||||
@@ -12,8 +12,6 @@ from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
trim_prompt_piece,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
from onyx.prompts.agents.dc_prompts import DC_FORMATTING_NO_BASE_DATA_PROMPT
|
||||
from onyx.prompts.agents.dc_prompts import DC_FORMATTING_WITH_BASE_DATA_PROMPT
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -33,22 +31,11 @@ def consolidate_research(
|
||||
|
||||
search_tool = graph_config.tooling.search_tool
|
||||
|
||||
write_custom_event(
|
||||
"initial_agent_answer",
|
||||
AgentAnswerPiece(
|
||||
answer_piece=" generating the answer\n\n\n",
|
||||
level=0,
|
||||
level_question_num=0,
|
||||
answer_type="agent_level_answer",
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
if search_tool is None or graph_config.inputs.persona is None:
|
||||
raise ValueError("Search tool and persona must be provided for DivCon search")
|
||||
|
||||
# Populate prompt
|
||||
instructions = graph_config.inputs.persona.prompts[0].system_prompt
|
||||
instructions = graph_config.inputs.persona.system_prompt or ""
|
||||
|
||||
try:
|
||||
agent_5_instructions = extract_section(
|
||||
|
||||
@@ -1,31 +0,0 @@
|
||||
from collections.abc import Hashable
|
||||
from datetime import datetime
|
||||
|
||||
from langgraph.types import Send
|
||||
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
|
||||
SubQuestionAnsweringInput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
|
||||
ExpandedRetrievalInput,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def send_to_expanded_retrieval(state: SubQuestionAnsweringInput) -> Send | Hashable:
|
||||
"""
|
||||
LangGraph edge to send a sub-question to the expanded retrieval.
|
||||
"""
|
||||
edge_start_time = datetime.now()
|
||||
|
||||
return Send(
|
||||
"initial_sub_question_expanded_retrieval",
|
||||
ExpandedRetrievalInput(
|
||||
question=state.question,
|
||||
base_search=False,
|
||||
sub_question_id=state.question_id,
|
||||
log_messages=[f"{edge_start_time} -- Sending to expanded retrieval"],
|
||||
),
|
||||
)
|
||||
@@ -1,137 +0,0 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.edges import (
|
||||
send_to_expanded_retrieval,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.nodes.check_sub_answer import (
|
||||
check_sub_answer,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.nodes.format_sub_answer import (
|
||||
format_sub_answer,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.nodes.generate_sub_answer import (
|
||||
generate_sub_answer,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.nodes.ingest_retrieved_documents import (
|
||||
ingest_retrieved_documents,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
|
||||
AnswerQuestionOutput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
|
||||
AnswerQuestionState,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
|
||||
SubQuestionAnsweringInput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.graph_builder import (
|
||||
expanded_retrieval_graph_builder,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def answer_query_graph_builder() -> StateGraph:
|
||||
"""
|
||||
LangGraph sub-graph builder for the initial individual sub-answer generation.
|
||||
"""
|
||||
graph = StateGraph(
|
||||
state_schema=AnswerQuestionState,
|
||||
input=SubQuestionAnsweringInput,
|
||||
output=AnswerQuestionOutput,
|
||||
)
|
||||
|
||||
### Add nodes ###
|
||||
|
||||
# The sub-graph that executes the expanded retrieval process for a sub-question
|
||||
expanded_retrieval = expanded_retrieval_graph_builder().compile()
|
||||
graph.add_node(
|
||||
node="initial_sub_question_expanded_retrieval",
|
||||
action=expanded_retrieval,
|
||||
)
|
||||
|
||||
# The node that ingests the retrieved documents and puts them into the proper
|
||||
# state keys.
|
||||
graph.add_node(
|
||||
node="ingest_retrieval",
|
||||
action=ingest_retrieved_documents,
|
||||
)
|
||||
|
||||
# The node that generates the sub-answer
|
||||
graph.add_node(
|
||||
node="generate_sub_answer",
|
||||
action=generate_sub_answer,
|
||||
)
|
||||
|
||||
# The node that checks the sub-answer
|
||||
graph.add_node(
|
||||
node="answer_check",
|
||||
action=check_sub_answer,
|
||||
)
|
||||
|
||||
# The node that formats the sub-answer for the following initial answer generation
|
||||
graph.add_node(
|
||||
node="format_answer",
|
||||
action=format_sub_answer,
|
||||
)
|
||||
|
||||
### Add edges ###
|
||||
|
||||
graph.add_conditional_edges(
|
||||
source=START,
|
||||
path=send_to_expanded_retrieval,
|
||||
path_map=["initial_sub_question_expanded_retrieval"],
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="initial_sub_question_expanded_retrieval",
|
||||
end_key="ingest_retrieval",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="ingest_retrieval",
|
||||
end_key="generate_sub_answer",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="generate_sub_answer",
|
||||
end_key="answer_check",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="answer_check",
|
||||
end_key="format_answer",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="format_answer",
|
||||
end_key=END,
|
||||
)
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.llm.factory import get_default_llms
|
||||
from onyx.context.search.models import SearchRequest
|
||||
|
||||
graph = answer_query_graph_builder()
|
||||
compiled_graph = graph.compile()
|
||||
primary_llm, fast_llm = get_default_llms()
|
||||
search_request = SearchRequest(
|
||||
query="what can you do with onyx or danswer?",
|
||||
)
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
graph_config, search_tool = get_test_config(
|
||||
db_session, primary_llm, fast_llm, search_request
|
||||
)
|
||||
inputs = SubQuestionAnsweringInput(
|
||||
question="what can you do with onyx?",
|
||||
question_id="0_0",
|
||||
log_messages=[],
|
||||
)
|
||||
for thing in compiled_graph.stream(
|
||||
input=inputs,
|
||||
config={"configurable": {"config": graph_config}},
|
||||
):
|
||||
logger.debug(thing)
|
||||
@@ -1,136 +0,0 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
|
||||
AnswerQuestionState,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
|
||||
SubQuestionAnswerCheckUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
binary_string_test,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.constants import (
|
||||
AGENT_LLM_RATELIMIT_MESSAGE,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.constants import (
|
||||
AGENT_LLM_TIMEOUT_MESSAGE,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.constants import (
|
||||
AGENT_POSITIVE_VALUE_STR,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.constants import AgentLLMErrorType
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLog
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
|
||||
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_VALIDATION
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_SUBANSWER_CHECK
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
from onyx.llm.chat_llm import LLMTimeoutError
|
||||
from onyx.prompts.agent_search import SUB_ANSWER_CHECK_PROMPT
|
||||
from onyx.prompts.agent_search import UNKNOWN_ANSWER
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_llm_node_error_strings = LLMNodeErrorStrings(
|
||||
timeout="LLM Timeout Error. The sub-answer will be treated as 'relevant'",
|
||||
rate_limit="LLM Rate Limit Error. The sub-answer will be treated as 'relevant'",
|
||||
general_error="General LLM Error. The sub-answer will be treated as 'relevant'",
|
||||
)
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def check_sub_answer(
|
||||
state: AnswerQuestionState, config: RunnableConfig
|
||||
) -> SubQuestionAnswerCheckUpdate:
|
||||
"""
|
||||
LangGraph node to check the quality of the sub-answer. The answer
|
||||
is represented as a boolean value.
|
||||
"""
|
||||
node_start_time = datetime.now()
|
||||
|
||||
level, question_num = parse_question_id(state.question_id)
|
||||
if state.answer == UNKNOWN_ANSWER:
|
||||
return SubQuestionAnswerCheckUpdate(
|
||||
answer_quality=False,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="initial - generate individual sub answer",
|
||||
node_name="check sub answer",
|
||||
node_start_time=node_start_time,
|
||||
result="unknown answer",
|
||||
)
|
||||
],
|
||||
)
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=SUB_ANSWER_CHECK_PROMPT.format(
|
||||
question=state.question,
|
||||
base_answer=state.answer,
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
fast_llm = graph_config.tooling.fast_llm
|
||||
agent_error: AgentErrorLog | None = None
|
||||
response: BaseMessage | None = None
|
||||
try:
|
||||
response = run_with_timeout(
|
||||
AGENT_TIMEOUT_LLM_SUBANSWER_CHECK,
|
||||
fast_llm.invoke,
|
||||
prompt=msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK,
|
||||
max_tokens=AGENT_MAX_TOKENS_VALIDATION,
|
||||
)
|
||||
|
||||
quality_str: str = cast(str, response.content)
|
||||
answer_quality = binary_string_test(
|
||||
text=quality_str, positive_value=AGENT_POSITIVE_VALUE_STR
|
||||
)
|
||||
log_result = f"Answer quality: {quality_str}"
|
||||
|
||||
except (LLMTimeoutError, TimeoutError):
|
||||
agent_error = AgentErrorLog(
|
||||
error_type=AgentLLMErrorType.TIMEOUT,
|
||||
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
|
||||
error_result=_llm_node_error_strings.timeout,
|
||||
)
|
||||
answer_quality = True
|
||||
log_result = agent_error.error_result
|
||||
logger.error("LLM Timeout Error - check sub answer")
|
||||
|
||||
except LLMRateLimitError:
|
||||
agent_error = AgentErrorLog(
|
||||
error_type=AgentLLMErrorType.RATE_LIMIT,
|
||||
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
|
||||
error_result=_llm_node_error_strings.rate_limit,
|
||||
)
|
||||
|
||||
answer_quality = True
|
||||
log_result = agent_error.error_result
|
||||
logger.error("LLM Rate Limit Error - check sub answer")
|
||||
|
||||
return SubQuestionAnswerCheckUpdate(
|
||||
answer_quality=answer_quality,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="initial - generate individual sub answer",
|
||||
node_name="check sub answer",
|
||||
node_start_time=node_start_time,
|
||||
result=log_result,
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -1,30 +0,0 @@
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
|
||||
AnswerQuestionOutput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
|
||||
AnswerQuestionState,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import (
|
||||
SubQuestionAnswerResults,
|
||||
)
|
||||
|
||||
|
||||
def format_sub_answer(state: AnswerQuestionState) -> AnswerQuestionOutput:
|
||||
"""
|
||||
LangGraph node to generate the sub-answer format.
|
||||
"""
|
||||
return AnswerQuestionOutput(
|
||||
answer_results=[
|
||||
SubQuestionAnswerResults(
|
||||
question=state.question,
|
||||
question_id=state.question_id,
|
||||
verified_high_quality=state.answer_quality,
|
||||
answer=state.answer,
|
||||
sub_query_retrieval_results=state.expanded_retrieval_results,
|
||||
verified_reranked_documents=state.verified_reranked_documents,
|
||||
context_documents=state.context_documents,
|
||||
cited_documents=state.cited_documents,
|
||||
sub_question_retrieval_stats=state.sub_question_retrieval_stats,
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -1,185 +0,0 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import merge_message_runs
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
|
||||
AnswerQuestionState,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
|
||||
SubQuestionAnswerGenerationUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
build_sub_question_answer_prompt,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.calculations import (
|
||||
dedup_sort_inference_section_list,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.constants import (
|
||||
AGENT_LLM_RATELIMIT_MESSAGE,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.constants import (
|
||||
AGENT_LLM_TIMEOUT_MESSAGE,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.constants import (
|
||||
AgentLLMErrorType,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.constants import (
|
||||
LLM_ANSWER_ERROR_MESSAGE,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLog
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import get_answer_citation_ids
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_persona_agent_prompt_expressions,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
from onyx.chat.models import StreamStopInfo
|
||||
from onyx.chat.models import StreamStopReason
|
||||
from onyx.chat.models import StreamType
|
||||
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
|
||||
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_SUBANSWER_GENERATION
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
from onyx.llm.chat_llm import LLMTimeoutError
|
||||
from onyx.prompts.agent_search import NO_RECOVERED_DOCS
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_llm_node_error_strings = LLMNodeErrorStrings(
|
||||
timeout="LLM Timeout Error. A sub-answer could not be constructed and the sub-question will be ignored.",
|
||||
rate_limit="LLM Rate Limit Error. A sub-answer could not be constructed and the sub-question will be ignored.",
|
||||
general_error="General LLM Error. A sub-answer could not be constructed and the sub-question will be ignored.",
|
||||
)
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def generate_sub_answer(
|
||||
state: AnswerQuestionState,
|
||||
config: RunnableConfig,
|
||||
writer: StreamWriter = lambda _: None,
|
||||
) -> SubQuestionAnswerGenerationUpdate:
|
||||
"""
|
||||
LangGraph node to generate a sub-answer.
|
||||
"""
|
||||
node_start_time = datetime.now()
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
question = state.question
|
||||
state.verified_reranked_documents
|
||||
level, question_num = parse_question_id(state.question_id)
|
||||
context_docs = state.context_documents[:AGENT_MAX_ANSWER_CONTEXT_DOCS]
|
||||
|
||||
context_docs = dedup_sort_inference_section_list(context_docs)
|
||||
|
||||
persona_contextualized_prompt = get_persona_agent_prompt_expressions(
|
||||
graph_config.inputs.persona
|
||||
).contextualized_prompt
|
||||
|
||||
if len(context_docs) == 0:
|
||||
answer_str = NO_RECOVERED_DOCS
|
||||
cited_documents: list = []
|
||||
log_results = "No documents retrieved"
|
||||
write_custom_event(
|
||||
"sub_answers",
|
||||
AgentAnswerPiece(
|
||||
answer_piece=answer_str,
|
||||
level=level,
|
||||
level_question_num=question_num,
|
||||
answer_type="agent_sub_answer",
|
||||
),
|
||||
writer,
|
||||
)
|
||||
else:
|
||||
fast_llm = graph_config.tooling.fast_llm
|
||||
msg = build_sub_question_answer_prompt(
|
||||
question=question,
|
||||
original_question=graph_config.inputs.prompt_builder.raw_user_query,
|
||||
docs=context_docs,
|
||||
persona_specification=persona_contextualized_prompt,
|
||||
config=fast_llm.config,
|
||||
)
|
||||
|
||||
agent_error: AgentErrorLog | None = None
|
||||
response: list[str] = []
|
||||
|
||||
try:
|
||||
response, _ = run_with_timeout(
|
||||
AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION,
|
||||
lambda: stream_llm_answer(
|
||||
llm=fast_llm,
|
||||
prompt=msg,
|
||||
event_name="sub_answers",
|
||||
writer=writer,
|
||||
agent_answer_level=level,
|
||||
agent_answer_question_num=question_num,
|
||||
agent_answer_type="agent_sub_answer",
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION,
|
||||
max_tokens=AGENT_MAX_TOKENS_SUBANSWER_GENERATION,
|
||||
),
|
||||
)
|
||||
|
||||
except (LLMTimeoutError, TimeoutError):
|
||||
agent_error = AgentErrorLog(
|
||||
error_type=AgentLLMErrorType.TIMEOUT,
|
||||
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
|
||||
error_result=_llm_node_error_strings.timeout,
|
||||
)
|
||||
logger.error("LLM Timeout Error - generate sub answer")
|
||||
except LLMRateLimitError:
|
||||
agent_error = AgentErrorLog(
|
||||
error_type=AgentLLMErrorType.RATE_LIMIT,
|
||||
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
|
||||
error_result=_llm_node_error_strings.rate_limit,
|
||||
)
|
||||
logger.error("LLM Rate Limit Error - generate sub answer")
|
||||
|
||||
if agent_error:
|
||||
answer_str = LLM_ANSWER_ERROR_MESSAGE
|
||||
cited_documents = []
|
||||
log_results = (
|
||||
agent_error.error_result
|
||||
or "Sub-answer generation failed due to LLM error"
|
||||
)
|
||||
|
||||
else:
|
||||
answer_str = merge_message_runs(response, chunk_separator="")[0].content
|
||||
answer_citation_ids = get_answer_citation_ids(answer_str)
|
||||
cited_documents = [
|
||||
context_docs[id] for id in answer_citation_ids if id < len(context_docs)
|
||||
]
|
||||
log_results = None
|
||||
|
||||
stop_event = StreamStopInfo(
|
||||
stop_reason=StreamStopReason.FINISHED,
|
||||
stream_type=StreamType.SUB_ANSWER,
|
||||
level=level,
|
||||
level_question_num=question_num,
|
||||
)
|
||||
write_custom_event("stream_finished", stop_event, writer)
|
||||
|
||||
return SubQuestionAnswerGenerationUpdate(
|
||||
answer=answer_str,
|
||||
cited_documents=cited_documents,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="initial - generate individual sub answer",
|
||||
node_name="generate sub answer",
|
||||
node_start_time=node_start_time,
|
||||
result=log_results or "",
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -1,25 +0,0 @@
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
|
||||
SubQuestionRetrievalIngestionUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
|
||||
ExpandedRetrievalOutput,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkRetrievalStats
|
||||
|
||||
|
||||
def ingest_retrieved_documents(
|
||||
state: ExpandedRetrievalOutput,
|
||||
) -> SubQuestionRetrievalIngestionUpdate:
|
||||
"""
|
||||
LangGraph node to ingest the retrieved documents to format it for the sub-answer.
|
||||
"""
|
||||
sub_question_retrieval_stats = state.expanded_retrieval_result.retrieval_stats
|
||||
if sub_question_retrieval_stats is None:
|
||||
sub_question_retrieval_stats = [AgentChunkRetrievalStats()]
|
||||
|
||||
return SubQuestionRetrievalIngestionUpdate(
|
||||
expanded_retrieval_results=state.expanded_retrieval_result.expanded_query_results,
|
||||
verified_reranked_documents=state.expanded_retrieval_result.verified_reranked_documents,
|
||||
context_documents=state.expanded_retrieval_result.context_documents,
|
||||
sub_question_retrieval_stats=sub_question_retrieval_stats,
|
||||
)
|
||||
@@ -1,73 +0,0 @@
|
||||
from operator import add
|
||||
from typing import Annotated
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.agents.agent_search.core_state import SubgraphCoreState
|
||||
from onyx.agents.agent_search.deep_search.main.states import LoggerUpdate
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkRetrievalStats
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import QueryRetrievalResult
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import (
|
||||
SubQuestionAnswerResults,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.operators import (
|
||||
dedup_inference_sections,
|
||||
)
|
||||
from onyx.context.search.models import InferenceSection
|
||||
|
||||
|
||||
## Update States
|
||||
class SubQuestionAnswerCheckUpdate(LoggerUpdate, BaseModel):
|
||||
answer_quality: bool = False
|
||||
log_messages: list[str] = []
|
||||
|
||||
|
||||
class SubQuestionAnswerGenerationUpdate(LoggerUpdate, BaseModel):
|
||||
answer: str = ""
|
||||
log_messages: list[str] = []
|
||||
cited_documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
|
||||
# answer_stat: AnswerStats
|
||||
|
||||
|
||||
class SubQuestionRetrievalIngestionUpdate(LoggerUpdate, BaseModel):
|
||||
expanded_retrieval_results: list[QueryRetrievalResult] = []
|
||||
verified_reranked_documents: Annotated[
|
||||
list[InferenceSection], dedup_inference_sections
|
||||
] = []
|
||||
context_documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
|
||||
sub_question_retrieval_stats: AgentChunkRetrievalStats = AgentChunkRetrievalStats()
|
||||
|
||||
|
||||
## Graph Input State
|
||||
|
||||
|
||||
class SubQuestionAnsweringInput(SubgraphCoreState):
|
||||
question: str
|
||||
question_id: str
|
||||
# level 0 is original question and first decomposition, level 1 is follow up, etc
|
||||
# question_num is a unique number per original question per level.
|
||||
|
||||
|
||||
## Graph State
|
||||
|
||||
|
||||
class AnswerQuestionState(
|
||||
SubQuestionAnsweringInput,
|
||||
SubQuestionAnswerGenerationUpdate,
|
||||
SubQuestionAnswerCheckUpdate,
|
||||
SubQuestionRetrievalIngestionUpdate,
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
## Graph Output State
|
||||
|
||||
|
||||
class AnswerQuestionOutput(LoggerUpdate, BaseModel):
|
||||
"""
|
||||
This is a list of results even though each call of this subgraph only returns one result.
|
||||
This is because if we parallelize the answer query subgraph, there will be multiple
|
||||
results in a list so the add operator is used to add them together.
|
||||
"""
|
||||
|
||||
answer_results: Annotated[list[SubQuestionAnswerResults], add] = []
|
||||
@@ -1,50 +0,0 @@
|
||||
from collections.abc import Hashable
|
||||
from datetime import datetime
|
||||
|
||||
from langgraph.types import Send
|
||||
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
|
||||
AnswerQuestionOutput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
|
||||
SubQuestionAnsweringInput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.states import (
|
||||
SubQuestionRetrievalState,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
|
||||
|
||||
|
||||
def parallelize_initial_sub_question_answering(
|
||||
state: SubQuestionRetrievalState,
|
||||
) -> list[Send | Hashable]:
|
||||
"""
|
||||
LangGraph edge to parallelize the initial sub-question answering. If there are no sub-questions,
|
||||
we send empty answers to the initial answer generation, and that answer would be generated
|
||||
solely based on the documents retrieved for the original question.
|
||||
"""
|
||||
edge_start_time = datetime.now()
|
||||
if len(state.initial_sub_questions) > 0:
|
||||
return [
|
||||
Send(
|
||||
"answer_query_subgraph",
|
||||
SubQuestionAnsweringInput(
|
||||
question=question,
|
||||
question_id=make_question_id(0, question_num + 1),
|
||||
log_messages=[
|
||||
f"{edge_start_time} -- Main Edge - Parallelize Initial Sub-question Answering"
|
||||
],
|
||||
),
|
||||
)
|
||||
for question_num, question in enumerate(state.initial_sub_questions)
|
||||
]
|
||||
|
||||
else:
|
||||
return [
|
||||
Send(
|
||||
"ingest_answers",
|
||||
AnswerQuestionOutput(
|
||||
answer_results=[],
|
||||
),
|
||||
)
|
||||
]
|
||||
@@ -1,96 +0,0 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.nodes.generate_initial_answer import (
|
||||
generate_initial_answer,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.nodes.validate_initial_answer import (
|
||||
validate_initial_answer,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.states import (
|
||||
SubQuestionRetrievalInput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.states import (
|
||||
SubQuestionRetrievalState,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_sub_answers.graph_builder import (
|
||||
generate_sub_answers_graph_builder,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.initial.retrieve_orig_question_docs.graph_builder import (
|
||||
retrieve_orig_question_docs_graph_builder,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def generate_initial_answer_graph_builder(test_mode: bool = False) -> StateGraph:
|
||||
"""
|
||||
LangGraph graph builder for the initial answer generation.
|
||||
"""
|
||||
graph = StateGraph(
|
||||
state_schema=SubQuestionRetrievalState,
|
||||
input=SubQuestionRetrievalInput,
|
||||
)
|
||||
|
||||
# The sub-graph that generates the initial sub-answers
|
||||
generate_sub_answers = generate_sub_answers_graph_builder().compile()
|
||||
graph.add_node(
|
||||
node="generate_sub_answers_subgraph",
|
||||
action=generate_sub_answers,
|
||||
)
|
||||
|
||||
# The sub-graph that retrieves the original question documents. This is run
|
||||
# in parallel with the sub-answer generation process
|
||||
retrieve_orig_question_docs = retrieve_orig_question_docs_graph_builder().compile()
|
||||
graph.add_node(
|
||||
node="retrieve_orig_question_docs_subgraph_wrapper",
|
||||
action=retrieve_orig_question_docs,
|
||||
)
|
||||
|
||||
# Node that generates the initial answer using the results of the previous
|
||||
# two sub-graphs
|
||||
graph.add_node(
|
||||
node="generate_initial_answer",
|
||||
action=generate_initial_answer,
|
||||
)
|
||||
|
||||
# Node that validates the initial answer
|
||||
graph.add_node(
|
||||
node="validate_initial_answer",
|
||||
action=validate_initial_answer,
|
||||
)
|
||||
|
||||
### Add edges ###
|
||||
|
||||
graph.add_edge(
|
||||
start_key=START,
|
||||
end_key="retrieve_orig_question_docs_subgraph_wrapper",
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key=START,
|
||||
end_key="generate_sub_answers_subgraph",
|
||||
)
|
||||
|
||||
# Wait for both, the original question docs and the sub-answers to be generated before proceeding
|
||||
graph.add_edge(
|
||||
start_key=[
|
||||
"retrieve_orig_question_docs_subgraph_wrapper",
|
||||
"generate_sub_answers_subgraph",
|
||||
],
|
||||
end_key="generate_initial_answer",
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key="generate_initial_answer",
|
||||
end_key="validate_initial_answer",
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key="validate_initial_answer",
|
||||
end_key=END,
|
||||
)
|
||||
|
||||
return graph
|
||||
@@ -1,405 +0,0 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import merge_content
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.states import (
|
||||
SubQuestionRetrievalState,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.models import AgentBaseMetrics
|
||||
from onyx.agents.agent_search.deep_search.main.operations import (
|
||||
calculate_initial_agent_stats,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.operations import get_query_info
|
||||
from onyx.agents.agent_search.deep_search.main.operations import logger
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
InitialAnswerUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
get_prompt_enrichment_components,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
trim_prompt_piece,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.calculations import (
|
||||
get_answer_generation_documents,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.constants import (
|
||||
AGENT_LLM_RATELIMIT_MESSAGE,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.constants import (
|
||||
AGENT_LLM_TIMEOUT_MESSAGE,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.constants import (
|
||||
AgentLLMErrorType,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLog
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import InitialAgentResultStats
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
|
||||
from onyx.agents.agent_search.shared_graph_utils.operators import (
|
||||
dedup_inference_section_list,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import _should_restrict_tokens
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
dispatch_main_answer_stop_info,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_deduplicated_structured_subquestion_documents,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import relevance_from_docs
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import remove_document_citations
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
from onyx.chat.models import ExtendedToolResponse
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.configs.agent_configs import AGENT_ANSWER_GENERATION_BY_FAST_LLM
|
||||
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
|
||||
from onyx.configs.agent_configs import AGENT_MAX_STREAMED_DOCS_FOR_INITIAL_ANSWER
|
||||
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_ANSWER_GENERATION
|
||||
from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
|
||||
)
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION,
|
||||
)
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
from onyx.llm.chat_llm import LLMTimeoutError
|
||||
from onyx.prompts.agent_search import INITIAL_ANSWER_PROMPT_W_SUB_QUESTIONS
|
||||
from onyx.prompts.agent_search import (
|
||||
INITIAL_ANSWER_PROMPT_WO_SUB_QUESTIONS,
|
||||
)
|
||||
from onyx.prompts.agent_search import (
|
||||
SUB_QUESTION_ANSWER_TEMPLATE,
|
||||
)
|
||||
from onyx.prompts.agent_search import UNKNOWN_ANSWER
|
||||
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
_llm_node_error_strings = LLMNodeErrorStrings(
|
||||
timeout="LLM Timeout Error. The initial answer could not be generated.",
|
||||
rate_limit="LLM Rate Limit Error. The initial answer could not be generated.",
|
||||
general_error="General LLM Error. The initial answer could not be generated.",
|
||||
)
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def generate_initial_answer(
|
||||
state: SubQuestionRetrievalState,
|
||||
config: RunnableConfig,
|
||||
writer: StreamWriter = lambda _: None,
|
||||
) -> InitialAnswerUpdate:
|
||||
"""
|
||||
LangGraph node to generate the initial answer, using the initial sub-questions/sub-answers and the
|
||||
documents retrieved for the original question.
|
||||
"""
|
||||
node_start_time = datetime.now()
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
question = graph_config.inputs.prompt_builder.raw_user_query
|
||||
prompt_enrichment_components = get_prompt_enrichment_components(graph_config)
|
||||
|
||||
# get all documents cited in sub-questions
|
||||
structured_subquestion_docs = get_deduplicated_structured_subquestion_documents(
|
||||
state.sub_question_results
|
||||
)
|
||||
|
||||
orig_question_retrieval_documents = state.orig_question_retrieved_documents
|
||||
|
||||
consolidated_context_docs = structured_subquestion_docs.cited_documents
|
||||
counter = 0
|
||||
for original_doc in orig_question_retrieval_documents:
|
||||
if original_doc in structured_subquestion_docs.cited_documents:
|
||||
continue
|
||||
|
||||
if (
|
||||
counter <= AGENT_MIN_ORIG_QUESTION_DOCS
|
||||
or len(consolidated_context_docs) < AGENT_MAX_ANSWER_CONTEXT_DOCS
|
||||
):
|
||||
consolidated_context_docs.append(original_doc)
|
||||
counter += 1
|
||||
|
||||
# sort docs by their scores - though the scores refer to different questions
|
||||
relevant_docs = dedup_inference_section_list(consolidated_context_docs)
|
||||
|
||||
sub_questions: list[str] = []
|
||||
|
||||
# Create the list of documents to stream out. Start with the
|
||||
# ones that wil be in the context (or, if len == 0, use docs
|
||||
# that were retrieved for the original question)
|
||||
answer_generation_documents = get_answer_generation_documents(
|
||||
relevant_docs=relevant_docs,
|
||||
context_documents=structured_subquestion_docs.context_documents,
|
||||
original_question_docs=orig_question_retrieval_documents,
|
||||
max_docs=AGENT_MAX_STREAMED_DOCS_FOR_INITIAL_ANSWER,
|
||||
)
|
||||
|
||||
# Use the query info from the base document retrieval
|
||||
query_info = get_query_info(state.orig_question_sub_query_retrieval_results)
|
||||
|
||||
assert (
|
||||
graph_config.tooling.search_tool
|
||||
), "search_tool must be provided for agentic search"
|
||||
|
||||
relevance_list = relevance_from_docs(
|
||||
answer_generation_documents.streaming_documents
|
||||
)
|
||||
for tool_response in yield_search_responses(
|
||||
query=question,
|
||||
get_retrieved_sections=lambda: answer_generation_documents.context_documents,
|
||||
get_final_context_sections=lambda: answer_generation_documents.context_documents,
|
||||
search_query_info=query_info,
|
||||
get_section_relevance=lambda: relevance_list,
|
||||
search_tool=graph_config.tooling.search_tool,
|
||||
):
|
||||
write_custom_event(
|
||||
"tool_response",
|
||||
ExtendedToolResponse(
|
||||
id=tool_response.id,
|
||||
response=tool_response.response,
|
||||
level=0,
|
||||
level_question_num=0, # 0, 0 is the base question
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
if len(answer_generation_documents.context_documents) == 0:
|
||||
write_custom_event(
|
||||
"initial_agent_answer",
|
||||
AgentAnswerPiece(
|
||||
answer_piece=UNKNOWN_ANSWER,
|
||||
level=0,
|
||||
level_question_num=0,
|
||||
answer_type="agent_level_answer",
|
||||
),
|
||||
writer,
|
||||
)
|
||||
dispatch_main_answer_stop_info(0, writer)
|
||||
|
||||
answer = UNKNOWN_ANSWER
|
||||
initial_agent_stats = InitialAgentResultStats(
|
||||
sub_questions={},
|
||||
original_question={},
|
||||
agent_effectiveness={},
|
||||
)
|
||||
|
||||
else:
|
||||
sub_question_answer_results = state.sub_question_results
|
||||
|
||||
# Collect the sub-questions and sub-answers and construct an appropriate
|
||||
# prompt string.
|
||||
# Consider replacing by a function.
|
||||
answered_sub_questions: list[str] = []
|
||||
all_sub_questions: list[str] = [] # Separate list for tracking all questions
|
||||
|
||||
for idx, sub_question_answer_result in enumerate(
|
||||
sub_question_answer_results, start=1
|
||||
):
|
||||
all_sub_questions.append(sub_question_answer_result.question)
|
||||
|
||||
is_valid_answer = (
|
||||
sub_question_answer_result.verified_high_quality
|
||||
and sub_question_answer_result.answer
|
||||
and sub_question_answer_result.answer != UNKNOWN_ANSWER
|
||||
)
|
||||
|
||||
if is_valid_answer:
|
||||
answered_sub_questions.append(
|
||||
SUB_QUESTION_ANSWER_TEMPLATE.format(
|
||||
sub_question=sub_question_answer_result.question,
|
||||
sub_answer=sub_question_answer_result.answer,
|
||||
sub_question_num=idx,
|
||||
)
|
||||
)
|
||||
|
||||
sub_question_answer_str = (
|
||||
"\n\n------\n\n".join(answered_sub_questions)
|
||||
if answered_sub_questions
|
||||
else ""
|
||||
)
|
||||
|
||||
# Use the appropriate prompt based on whether there are sub-questions.
|
||||
base_prompt = (
|
||||
INITIAL_ANSWER_PROMPT_W_SUB_QUESTIONS
|
||||
if answered_sub_questions
|
||||
else INITIAL_ANSWER_PROMPT_WO_SUB_QUESTIONS
|
||||
)
|
||||
|
||||
sub_questions = all_sub_questions # Replace the original assignment
|
||||
|
||||
model = (
|
||||
graph_config.tooling.fast_llm
|
||||
if AGENT_ANSWER_GENERATION_BY_FAST_LLM
|
||||
else graph_config.tooling.primary_llm
|
||||
)
|
||||
|
||||
doc_context = format_docs(answer_generation_documents.context_documents)
|
||||
doc_context = trim_prompt_piece(
|
||||
config=model.config,
|
||||
prompt_piece=doc_context,
|
||||
reserved_str=(
|
||||
base_prompt
|
||||
+ sub_question_answer_str
|
||||
+ prompt_enrichment_components.persona_prompts.contextualized_prompt
|
||||
+ prompt_enrichment_components.history
|
||||
+ prompt_enrichment_components.date_str
|
||||
),
|
||||
)
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=base_prompt.format(
|
||||
question=question,
|
||||
answered_sub_questions=remove_document_citations(
|
||||
sub_question_answer_str
|
||||
),
|
||||
relevant_docs=doc_context,
|
||||
persona_specification=prompt_enrichment_components.persona_prompts.contextualized_prompt,
|
||||
history=prompt_enrichment_components.history,
|
||||
date_prompt=prompt_enrichment_components.date_str,
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
streamed_tokens: list[str] = [""]
|
||||
dispatch_timings: list[float] = []
|
||||
|
||||
agent_error: AgentErrorLog | None = None
|
||||
|
||||
try:
|
||||
streamed_tokens, dispatch_timings = run_with_timeout(
|
||||
AGENT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION,
|
||||
lambda: stream_llm_answer(
|
||||
llm=model,
|
||||
prompt=msg,
|
||||
event_name="initial_agent_answer",
|
||||
writer=writer,
|
||||
agent_answer_level=0,
|
||||
agent_answer_question_num=0,
|
||||
agent_answer_type="agent_level_answer",
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
|
||||
max_tokens=(
|
||||
AGENT_MAX_TOKENS_ANSWER_GENERATION
|
||||
if _should_restrict_tokens(model.config)
|
||||
else None
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
except (LLMTimeoutError, TimeoutError):
|
||||
agent_error = AgentErrorLog(
|
||||
error_type=AgentLLMErrorType.TIMEOUT,
|
||||
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
|
||||
error_result=_llm_node_error_strings.timeout,
|
||||
)
|
||||
logger.error("LLM Timeout Error - generate initial answer")
|
||||
|
||||
except LLMRateLimitError:
|
||||
agent_error = AgentErrorLog(
|
||||
error_type=AgentLLMErrorType.RATE_LIMIT,
|
||||
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
|
||||
error_result=_llm_node_error_strings.rate_limit,
|
||||
)
|
||||
logger.error("LLM Rate Limit Error - generate initial answer")
|
||||
|
||||
if agent_error:
|
||||
write_custom_event(
|
||||
"initial_agent_answer",
|
||||
StreamingError(
|
||||
error=AGENT_LLM_TIMEOUT_MESSAGE,
|
||||
),
|
||||
writer,
|
||||
)
|
||||
return InitialAnswerUpdate(
|
||||
initial_answer=None,
|
||||
answer_error=AgentErrorLog(
|
||||
error_message=agent_error.error_message or "An LLM error occurred",
|
||||
error_type=agent_error.error_type,
|
||||
error_result=agent_error.error_result,
|
||||
),
|
||||
initial_agent_stats=None,
|
||||
generated_sub_questions=sub_questions,
|
||||
agent_base_end_time=None,
|
||||
agent_base_metrics=None,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="initial - generate initial answer",
|
||||
node_name="generate initial answer",
|
||||
node_start_time=node_start_time,
|
||||
result=agent_error.error_result or "An LLM error occurred",
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Average dispatch time for initial answer: {sum(dispatch_timings) / len(dispatch_timings)}"
|
||||
)
|
||||
|
||||
dispatch_main_answer_stop_info(0, writer)
|
||||
response = merge_content(*streamed_tokens)
|
||||
answer = cast(str, response)
|
||||
|
||||
initial_agent_stats = calculate_initial_agent_stats(
|
||||
state.sub_question_results, state.orig_question_retrieval_stats
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"\n\nYYYYY--Sub-Questions:\n\n{sub_question_answer_str}\n\nStats:\n\n"
|
||||
)
|
||||
|
||||
if initial_agent_stats:
|
||||
logger.debug(initial_agent_stats.original_question)
|
||||
logger.debug(initial_agent_stats.sub_questions)
|
||||
logger.debug(initial_agent_stats.agent_effectiveness)
|
||||
|
||||
agent_base_end_time = datetime.now()
|
||||
|
||||
if agent_base_end_time and state.agent_start_time:
|
||||
duration_s = (agent_base_end_time - state.agent_start_time).total_seconds()
|
||||
else:
|
||||
duration_s = None
|
||||
|
||||
agent_base_metrics = AgentBaseMetrics(
|
||||
num_verified_documents_total=len(relevant_docs),
|
||||
num_verified_documents_core=state.orig_question_retrieval_stats.verified_count,
|
||||
verified_avg_score_core=state.orig_question_retrieval_stats.verified_avg_scores,
|
||||
num_verified_documents_base=initial_agent_stats.sub_questions.get(
|
||||
"num_verified_documents"
|
||||
),
|
||||
verified_avg_score_base=initial_agent_stats.sub_questions.get(
|
||||
"verified_avg_score"
|
||||
),
|
||||
base_doc_boost_factor=initial_agent_stats.agent_effectiveness.get(
|
||||
"utilized_chunk_ratio"
|
||||
),
|
||||
support_boost_factor=initial_agent_stats.agent_effectiveness.get(
|
||||
"support_ratio"
|
||||
),
|
||||
duration_s=duration_s,
|
||||
)
|
||||
|
||||
return InitialAnswerUpdate(
|
||||
initial_answer=answer,
|
||||
initial_agent_stats=initial_agent_stats,
|
||||
generated_sub_questions=sub_questions,
|
||||
agent_base_end_time=agent_base_end_time,
|
||||
agent_base_metrics=agent_base_metrics,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="initial - generate initial answer",
|
||||
node_name="generate initial answer",
|
||||
node_start_time=node_start_time,
|
||||
result="",
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -1,42 +0,0 @@
|
||||
from datetime import datetime
|
||||
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.states import (
|
||||
SubQuestionRetrievalState,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.operations import logger
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
InitialAnswerQualityUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def validate_initial_answer(
|
||||
state: SubQuestionRetrievalState,
|
||||
) -> InitialAnswerQualityUpdate:
|
||||
"""
|
||||
Check whether the initial answer sufficiently addresses the original user question.
|
||||
"""
|
||||
|
||||
node_start_time = datetime.now()
|
||||
|
||||
logger.debug(
|
||||
f"--------{node_start_time}--------Checking for base answer validity - for not set True/False manually"
|
||||
)
|
||||
|
||||
verdict = True # not actually required as already streamed out. Refinement will do similar
|
||||
|
||||
return InitialAnswerQualityUpdate(
|
||||
initial_answer_quality_eval=verdict,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="initial - generate initial answer",
|
||||
node_name="validate initial answer",
|
||||
node_start_time=node_start_time,
|
||||
result="",
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -1,51 +0,0 @@
|
||||
from operator import add
|
||||
from typing import Annotated
|
||||
from typing import TypedDict
|
||||
|
||||
from onyx.agents.agent_search.core_state import CoreState
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
ExploratorySearchUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
InitialAnswerQualityUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
InitialAnswerUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
InitialQuestionDecompositionUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
OrigQuestionRetrievalUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
SubQuestionResultsUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.models import (
|
||||
QuestionRetrievalResult,
|
||||
)
|
||||
from onyx.context.search.models import InferenceSection
|
||||
|
||||
|
||||
### States ###
|
||||
class SubQuestionRetrievalInput(CoreState):
|
||||
exploratory_search_results: list[InferenceSection]
|
||||
|
||||
|
||||
## Graph State
|
||||
class SubQuestionRetrievalState(
|
||||
# This includes the core state
|
||||
SubQuestionRetrievalInput,
|
||||
InitialQuestionDecompositionUpdate,
|
||||
InitialAnswerUpdate,
|
||||
SubQuestionResultsUpdate,
|
||||
OrigQuestionRetrievalUpdate,
|
||||
InitialAnswerQualityUpdate,
|
||||
ExploratorySearchUpdate,
|
||||
):
|
||||
base_raw_search_result: Annotated[list[QuestionRetrievalResult], add]
|
||||
|
||||
|
||||
## Graph Output State
|
||||
class SubQuestionRetrievalOutput(TypedDict):
|
||||
log_messages: list[str]
|
||||
@@ -1,48 +0,0 @@
|
||||
from collections.abc import Hashable
|
||||
from datetime import datetime
|
||||
|
||||
from langgraph.types import Send
|
||||
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
|
||||
AnswerQuestionOutput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
|
||||
SubQuestionAnsweringInput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.states import (
|
||||
SubQuestionRetrievalState,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
|
||||
|
||||
|
||||
def parallelize_initial_sub_question_answering(
|
||||
state: SubQuestionRetrievalState,
|
||||
) -> list[Send | Hashable]:
|
||||
"""
|
||||
LangGraph edge to parallelize the initial sub-question answering.
|
||||
"""
|
||||
edge_start_time = datetime.now()
|
||||
if len(state.initial_sub_questions) > 0:
|
||||
return [
|
||||
Send(
|
||||
"answer_sub_question_subgraphs",
|
||||
SubQuestionAnsweringInput(
|
||||
question=question,
|
||||
question_id=make_question_id(0, question_num + 1),
|
||||
log_messages=[
|
||||
f"{edge_start_time} -- Main Edge - Parallelize Initial Sub-question Answering"
|
||||
],
|
||||
),
|
||||
)
|
||||
for question_num, question in enumerate(state.initial_sub_questions)
|
||||
]
|
||||
|
||||
else:
|
||||
return [
|
||||
Send(
|
||||
"ingest_answers",
|
||||
AnswerQuestionOutput(
|
||||
answer_results=[],
|
||||
),
|
||||
)
|
||||
]
|
||||
@@ -1,81 +0,0 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.graph_builder import (
|
||||
answer_query_graph_builder,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_sub_answers.edges import (
|
||||
parallelize_initial_sub_question_answering,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_sub_answers.nodes.decompose_orig_question import (
|
||||
decompose_orig_question,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_sub_answers.nodes.format_initial_sub_answers import (
|
||||
format_initial_sub_answers,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_sub_answers.states import (
|
||||
SubQuestionAnsweringInput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_sub_answers.states import (
|
||||
SubQuestionAnsweringState,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
test_mode = False
|
||||
|
||||
|
||||
def generate_sub_answers_graph_builder() -> StateGraph:
|
||||
"""
|
||||
LangGraph graph builder for the initial sub-answer generation process.
|
||||
It generates the initial sub-questions and produces the answers.
|
||||
"""
|
||||
|
||||
graph = StateGraph(
|
||||
state_schema=SubQuestionAnsweringState,
|
||||
input=SubQuestionAnsweringInput,
|
||||
)
|
||||
|
||||
# Decompose the original question into sub-questions
|
||||
graph.add_node(
|
||||
node="decompose_orig_question",
|
||||
action=decompose_orig_question,
|
||||
)
|
||||
|
||||
# The sub-graph that executes the initial sub-question answering for
|
||||
# each of the sub-questions.
|
||||
answer_sub_question_subgraphs = answer_query_graph_builder().compile()
|
||||
graph.add_node(
|
||||
node="answer_sub_question_subgraphs",
|
||||
action=answer_sub_question_subgraphs,
|
||||
)
|
||||
|
||||
# Node that collects and formats the initial sub-question answers
|
||||
graph.add_node(
|
||||
node="format_initial_sub_question_answers",
|
||||
action=format_initial_sub_answers,
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key=START,
|
||||
end_key="decompose_orig_question",
|
||||
)
|
||||
|
||||
graph.add_conditional_edges(
|
||||
source="decompose_orig_question",
|
||||
path=parallelize_initial_sub_question_answering,
|
||||
path_map=["answer_sub_question_subgraphs"],
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key=["answer_sub_question_subgraphs"],
|
||||
end_key="format_initial_sub_question_answers",
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key="format_initial_sub_question_answers",
|
||||
end_key=END,
|
||||
)
|
||||
|
||||
return graph
|
||||
@@ -1,190 +0,0 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import merge_content
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.states import (
|
||||
SubQuestionRetrievalState,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.models import (
|
||||
AgentRefinedMetrics,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.operations import dispatch_subquestion
|
||||
from onyx.agents.agent_search.deep_search.main.operations import (
|
||||
dispatch_subquestion_sep,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
InitialQuestionDecompositionUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
build_history_prompt,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import BaseMessage_Content
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.chat.models import StreamStopInfo
|
||||
from onyx.chat.models import StreamStopReason
|
||||
from onyx.chat.models import StreamType
|
||||
from onyx.chat.models import SubQuestionPiece
|
||||
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_SUBQUESTION_GENERATION
|
||||
from onyx.configs.agent_configs import AGENT_NUM_DOCS_FOR_DECOMPOSITION
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION,
|
||||
)
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_LLM_SUBQUESTION_GENERATION,
|
||||
)
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
from onyx.llm.chat_llm import LLMTimeoutError
|
||||
from onyx.prompts.agent_search import (
|
||||
INITIAL_DECOMPOSITION_PROMPT_QUESTIONS_AFTER_SEARCH_ASSUMING_REFINEMENT,
|
||||
)
|
||||
from onyx.prompts.agent_search import (
|
||||
INITIAL_QUESTION_DECOMPOSITION_PROMPT_ASSUMING_REFINEMENT,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_llm_node_error_strings = LLMNodeErrorStrings(
|
||||
timeout="LLM Timeout Error. Sub-questions could not be generated.",
|
||||
rate_limit="LLM Rate Limit Error. Sub-questions could not be generated.",
|
||||
general_error="General LLM Error. Sub-questions could not be generated.",
|
||||
)
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def decompose_orig_question(
|
||||
state: SubQuestionRetrievalState,
|
||||
config: RunnableConfig,
|
||||
writer: StreamWriter = lambda _: None,
|
||||
) -> InitialQuestionDecompositionUpdate:
|
||||
"""
|
||||
LangGraph node to decompose the original question into sub-questions.
|
||||
"""
|
||||
node_start_time = datetime.now()
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
question = graph_config.inputs.prompt_builder.raw_user_query
|
||||
perform_initial_search_decomposition = (
|
||||
graph_config.behavior.perform_initial_search_decomposition
|
||||
)
|
||||
# Get the rewritten queries in a defined format
|
||||
model = graph_config.tooling.fast_llm
|
||||
|
||||
history = build_history_prompt(graph_config, question)
|
||||
|
||||
# Use the initial search results to inform the decomposition
|
||||
agent_start_time = datetime.now()
|
||||
|
||||
# Initial search to inform decomposition. Just get top 3 fits
|
||||
|
||||
if perform_initial_search_decomposition:
|
||||
# Due to unfortunate state representation in LangGraph, we need here to double check that the retrieval has
|
||||
# happened prior to this point, allowing silent failure here since it is not critical for decomposition in
|
||||
# all queries.
|
||||
if not state.exploratory_search_results:
|
||||
logger.error("Initial search for decomposition failed")
|
||||
|
||||
sample_doc_str = "\n\n".join(
|
||||
[
|
||||
doc.combined_content
|
||||
for doc in state.exploratory_search_results[
|
||||
:AGENT_NUM_DOCS_FOR_DECOMPOSITION
|
||||
]
|
||||
]
|
||||
)
|
||||
|
||||
decomposition_prompt = INITIAL_DECOMPOSITION_PROMPT_QUESTIONS_AFTER_SEARCH_ASSUMING_REFINEMENT.format(
|
||||
question=question, sample_doc_str=sample_doc_str, history=history
|
||||
)
|
||||
|
||||
else:
|
||||
decomposition_prompt = (
|
||||
INITIAL_QUESTION_DECOMPOSITION_PROMPT_ASSUMING_REFINEMENT.format(
|
||||
question=question, history=history
|
||||
)
|
||||
)
|
||||
|
||||
# Start decomposition
|
||||
|
||||
msg = [HumanMessage(content=decomposition_prompt)]
|
||||
|
||||
# Send the initial question as a subquestion with number 0
|
||||
write_custom_event(
|
||||
"decomp_qs",
|
||||
SubQuestionPiece(
|
||||
sub_question=question,
|
||||
level=0,
|
||||
level_question_num=0,
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
# dispatches custom events for subquestion tokens, adding in subquestion ids.
|
||||
|
||||
streamed_tokens: list[BaseMessage_Content] = []
|
||||
|
||||
try:
|
||||
streamed_tokens = run_with_timeout(
|
||||
AGENT_TIMEOUT_LLM_SUBQUESTION_GENERATION,
|
||||
dispatch_separated,
|
||||
model.stream(
|
||||
msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION,
|
||||
max_tokens=AGENT_MAX_TOKENS_SUBQUESTION_GENERATION,
|
||||
),
|
||||
dispatch_subquestion(0, writer),
|
||||
sep_callback=dispatch_subquestion_sep(0, writer),
|
||||
)
|
||||
|
||||
decomposition_response = merge_content(*streamed_tokens)
|
||||
|
||||
list_of_subqs = cast(str, decomposition_response).split("\n")
|
||||
|
||||
initial_sub_questions = [sq.strip() for sq in list_of_subqs if sq.strip() != ""]
|
||||
log_result = f"decomposed original question into {len(initial_sub_questions)} subquestions"
|
||||
|
||||
stop_event = StreamStopInfo(
|
||||
stop_reason=StreamStopReason.FINISHED,
|
||||
stream_type=StreamType.SUB_QUESTIONS,
|
||||
level=0,
|
||||
)
|
||||
write_custom_event("stream_finished", stop_event, writer)
|
||||
|
||||
except (LLMTimeoutError, TimeoutError) as e:
|
||||
logger.error("LLM Timeout Error - decompose orig question")
|
||||
raise e # fail loudly on this critical step
|
||||
except LLMRateLimitError as e:
|
||||
logger.error("LLM Rate Limit Error - decompose orig question")
|
||||
raise e
|
||||
|
||||
return InitialQuestionDecompositionUpdate(
|
||||
initial_sub_questions=initial_sub_questions,
|
||||
agent_start_time=agent_start_time,
|
||||
agent_refined_start_time=None,
|
||||
agent_refined_end_time=None,
|
||||
agent_refined_metrics=AgentRefinedMetrics(
|
||||
refined_doc_boost_factor=None,
|
||||
refined_question_boost_factor=None,
|
||||
duration_s=None,
|
||||
),
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="initial - generate sub answers",
|
||||
node_name="decompose original question",
|
||||
node_start_time=node_start_time,
|
||||
result=log_result,
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -1,50 +0,0 @@
|
||||
from datetime import datetime
|
||||
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
|
||||
AnswerQuestionOutput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
SubQuestionResultsUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.operators import (
|
||||
dedup_inference_sections,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
|
||||
|
||||
def format_initial_sub_answers(
|
||||
state: AnswerQuestionOutput,
|
||||
) -> SubQuestionResultsUpdate:
|
||||
"""
|
||||
LangGraph node to format the answers to the initial sub-questions, including
|
||||
deduping verified documents and context documents.
|
||||
"""
|
||||
node_start_time = datetime.now()
|
||||
|
||||
documents = []
|
||||
context_documents = []
|
||||
cited_documents = []
|
||||
answer_results = state.answer_results
|
||||
for answer_result in answer_results:
|
||||
documents.extend(answer_result.verified_reranked_documents)
|
||||
context_documents.extend(answer_result.context_documents)
|
||||
cited_documents.extend(answer_result.cited_documents)
|
||||
|
||||
return SubQuestionResultsUpdate(
|
||||
# Deduping is done by the documents operator for the main graph
|
||||
# so we might not need to dedup here
|
||||
verified_reranked_documents=dedup_inference_sections(documents, []),
|
||||
context_documents=dedup_inference_sections(context_documents, []),
|
||||
cited_documents=dedup_inference_sections(cited_documents, []),
|
||||
sub_question_results=answer_results,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="initial - generate sub answers",
|
||||
node_name="format initial sub answers",
|
||||
node_start_time=node_start_time,
|
||||
result="",
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -1,34 +0,0 @@
|
||||
from typing import TypedDict
|
||||
|
||||
from onyx.agents.agent_search.core_state import CoreState
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
InitialAnswerUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
InitialQuestionDecompositionUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
SubQuestionResultsUpdate,
|
||||
)
|
||||
from onyx.context.search.models import InferenceSection
|
||||
|
||||
|
||||
### States ###
|
||||
class SubQuestionAnsweringInput(CoreState):
|
||||
exploratory_search_results: list[InferenceSection]
|
||||
|
||||
|
||||
## Graph State
|
||||
class SubQuestionAnsweringState(
|
||||
# This includes the core state
|
||||
SubQuestionAnsweringInput,
|
||||
InitialQuestionDecompositionUpdate,
|
||||
InitialAnswerUpdate,
|
||||
SubQuestionResultsUpdate,
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
## Graph Output State
|
||||
class SubQuestionAnsweringOutput(TypedDict):
|
||||
log_messages: list[str]
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user