Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-04-05 15:02:43 +00:00)

# Compare commits: cli/v0.2.0...jamison/re (58 commits)
Commits in this range:

670b4f37cb, 79a81f37d5, 5b8af95007, b40935339f, 4a50bfc7ae, 4c9135ecdf, fe496da134, 7eb8b335c0, 183e3b5ec3, 9c5c42479c, 514f8eedb8, eb6bd42c1e, 953cc28625, de0f42f6cc, 7ecefdc90f, 21fc013893, a1c3a68ba4, 4fb175ae65, 800ad326df, 6b920e8a3e, ef3760796d, fa5b90df92, 53953ac4fa, 26bb5c990c, 27b4ed301f, 93ec270ccc, 9e2d6c8a1d, fc934214d0, 48fc45a0cd, 009266e53e, ffb9df7308, b0f5e0b8d9, 43aea5d614, 593d82f431, adf5691b5f, c1a8a5bd83, 8fd486da99, 4bda4d3637, 13c25eadad, 1f244e6388, 18b0416d30, 4bc0bc1efb, 1555217061, d177a833f0, 086997d3c5, dccec78397, 0123133621, 0b9d154a73, 6e65e55bf5, 3f9e208759, fb8edda14a, 58decd8a6b, e97204d9cc, 44ab02c94f, a98cc30f25, a709dcb8fa, a3dfe6aa1b, 23e4d55fb1
### .github/workflows/deployment.yml (287 changes)
```
@@ -704,6 +704,9 @@ jobs:
          NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
          NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
          NODE_OPTIONS=--max-old-space-size=8192
          SENTRY_RELEASE=${{ github.sha }}
        secrets: |
          sentry_auth_token=${{ secrets.SENTRY_AUTH_TOKEN }}
        cache-from: |
          type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-amd64
          type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest

@@ -786,6 +789,9 @@ jobs:
          NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
          NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
          NODE_OPTIONS=--max-old-space-size=8192
          SENTRY_RELEASE=${{ github.sha }}
        secrets: |
          sentry_auth_token=${{ secrets.SENTRY_AUTH_TOKEN }}
        cache-from: |
          type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-arm64
          type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest

@@ -1503,232 +1509,105 @@ jobs:
        $(printf '%s\n' "${META_TAGS}" | xargs -I {} echo -t {}) \
        $IMAGES

  trivy-scan-web:
  trivy-scan:
    needs:
      - determine-builds
      - merge-web
    if: needs.merge-web.result == 'success'
    runs-on:
      - runs-on
      - runner=2cpu-linux-arm64
      - run-id=${{ github.run_id }}-trivy-scan-web
      - extras=ecr-cache
    timeout-minutes: 90
    environment: release
    env:
      REGISTRY_IMAGE: onyxdotapp/onyx-web-server
    steps:
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2

      - name: Get AWS Secrets
        uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
        with:
          secret-ids: |
            DOCKER_USERNAME, deploy/docker-username
            DOCKER_TOKEN, deploy/docker-token
          parse-json-secrets: true

      - name: Run Trivy vulnerability scanner
        uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
        with:
          timeout_minutes: 30
          max_attempts: 3
          retry_wait_seconds: 10
          command: |
            if [ "${{ needs.determine-builds.outputs.is-test-run }}" == "true" ]; then
              SCAN_IMAGE="${{ env.RUNS_ON_ECR_CACHE }}:web-${{ needs.determine-builds.outputs.sanitized-tag }}"
            else
              SCAN_IMAGE="docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}"
            fi
            docker run --rm -v $HOME/.cache/trivy:/root/.cache/trivy \
              -e TRIVY_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-db:2" \
              -e TRIVY_JAVA_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-java-db:1" \
              -e TRIVY_USERNAME="${{ env.DOCKER_USERNAME }}" \
              -e TRIVY_PASSWORD="${{ env.DOCKER_TOKEN }}" \
              aquasec/trivy@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436 \
              image \
              --skip-version-check \
              --timeout 20m \
              --severity CRITICAL,HIGH \
              ${SCAN_IMAGE}

  trivy-scan-web-cloud:
    needs:
      - determine-builds
      - merge-web-cloud
    if: needs.merge-web-cloud.result == 'success'
    runs-on:
      - runs-on
      - runner=2cpu-linux-arm64
      - run-id=${{ github.run_id }}-trivy-scan-web-cloud
      - extras=ecr-cache
    timeout-minutes: 90
    environment: release
    env:
      REGISTRY_IMAGE: onyxdotapp/onyx-web-server-cloud
    steps:
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2

      - name: Get AWS Secrets
        uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
        with:
          secret-ids: |
            DOCKER_USERNAME, deploy/docker-username
            DOCKER_TOKEN, deploy/docker-token
          parse-json-secrets: true

      - name: Run Trivy vulnerability scanner
        uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
        with:
          timeout_minutes: 30
          max_attempts: 3
          retry_wait_seconds: 10
          command: |
            if [ "${{ needs.determine-builds.outputs.is-test-run }}" == "true" ]; then
              SCAN_IMAGE="${{ env.RUNS_ON_ECR_CACHE }}:web-cloud-${{ needs.determine-builds.outputs.sanitized-tag }}"
            else
              SCAN_IMAGE="docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}"
            fi
            docker run --rm -v $HOME/.cache/trivy:/root/.cache/trivy \
              -e TRIVY_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-db:2" \
              -e TRIVY_JAVA_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-java-db:1" \
              -e TRIVY_USERNAME="${{ env.DOCKER_USERNAME }}" \
              -e TRIVY_PASSWORD="${{ env.DOCKER_TOKEN }}" \
              aquasec/trivy@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436 \
              image \
              --skip-version-check \
              --timeout 20m \
              --severity CRITICAL,HIGH \
              ${SCAN_IMAGE}

  trivy-scan-backend:
    needs:
      - determine-builds
      - merge-backend
    if: needs.merge-backend.result == 'success'
      - merge-model-server
    if: >-
      always() && !cancelled() &&
      (needs.merge-web.result == 'success' ||
      needs.merge-web-cloud.result == 'success' ||
      needs.merge-backend.result == 'success' ||
      needs.merge-model-server.result == 'success')
    runs-on:
      - runs-on
      - runner=2cpu-linux-arm64
      - run-id=${{ github.run_id }}-trivy-scan-backend
      - run-id=${{ github.run_id }}-trivy-scan-${{ matrix.component }}
      - extras=ecr-cache
    timeout-minutes: 90
    environment: release
    env:
      REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-backend-cloud' || 'onyxdotapp/onyx-backend' }}
    permissions:
      security-events: write # needed for SARIF uploads
    timeout-minutes: 10
    strategy:
      fail-fast: false
      matrix:
        include:
          - component: web
            registry-image: onyxdotapp/onyx-web-server
          - component: web-cloud
            registry-image: onyxdotapp/onyx-web-server-cloud
          - component: backend
            registry-image: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-backend-cloud' || 'onyxdotapp/onyx-backend' }}
            trivyignore: backend/.trivyignore
          - component: model-server
            registry-image: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-model-server-cloud' || 'onyxdotapp/onyx-model-server' }}
    steps:
      - name: Check if this scan should run
        id: should-run
        run: |
          case "$COMPONENT" in
            web) RESULT="$MERGE_WEB" ;;
            web-cloud) RESULT="$MERGE_WEB_CLOUD" ;;
            backend) RESULT="$MERGE_BACKEND" ;;
            model-server) RESULT="$MERGE_MODEL_SERVER" ;;
          esac
          if [ "$RESULT" == "success" ]; then
            echo "run=true" >> "$GITHUB_OUTPUT"
          else
            echo "run=false" >> "$GITHUB_OUTPUT"
          fi
        env:
          COMPONENT: ${{ matrix.component }}
          MERGE_WEB: ${{ needs.merge-web.result }}
          MERGE_WEB_CLOUD: ${{ needs.merge-web-cloud.result }}
          MERGE_BACKEND: ${{ needs.merge-backend.result }}
          MERGE_MODEL_SERVER: ${{ needs.merge-model-server.result }}

      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
        if: steps.should-run.outputs.run == 'true'

      - name: Checkout
        if: steps.should-run.outputs.run == 'true' && matrix.trivyignore != ''
        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
        with:
          persist-credentials: false

      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2

      - name: Get AWS Secrets
        uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
        with:
          secret-ids: |
            DOCKER_USERNAME, deploy/docker-username
            DOCKER_TOKEN, deploy/docker-token
          parse-json-secrets: true
      - name: Determine scan image
        if: steps.should-run.outputs.run == 'true'
        id: scan-image
        run: |
          if [ "$IS_TEST_RUN" == "true" ]; then
            echo "image=${RUNS_ON_ECR_CACHE}:${TAG_PREFIX}-${SANITIZED_TAG}" >> "$GITHUB_OUTPUT"
          else
            echo "image=docker.io/${REGISTRY_IMAGE}:${REF_NAME}" >> "$GITHUB_OUTPUT"
          fi
        env:
          IS_TEST_RUN: ${{ needs.determine-builds.outputs.is-test-run }}
          TAG_PREFIX: ${{ matrix.component }}
          SANITIZED_TAG: ${{ needs.determine-builds.outputs.sanitized-tag }}
          REGISTRY_IMAGE: ${{ matrix.registry-image }}
          REF_NAME: ${{ github.ref_name }}

      - name: Run Trivy vulnerability scanner
        uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
        if: steps.should-run.outputs.run == 'true'
        uses: aquasecurity/trivy-action@57a97c7e7821a5776cebc9bb87c984fa69cba8f1 # ratchet:aquasecurity/trivy-action@v0.35.0
        with:
          timeout_minutes: 30
          max_attempts: 3
          retry_wait_seconds: 10
          command: |
            if [ "${{ needs.determine-builds.outputs.is-test-run }}" == "true" ]; then
              SCAN_IMAGE="${{ env.RUNS_ON_ECR_CACHE }}:backend-${{ needs.determine-builds.outputs.sanitized-tag }}"
            else
              SCAN_IMAGE="docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}"
            fi
            docker run --rm -v $HOME/.cache/trivy:/root/.cache/trivy \
              -v ${{ github.workspace }}/backend/.trivyignore:/tmp/.trivyignore:ro \
              -e TRIVY_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-db:2" \
              -e TRIVY_JAVA_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-java-db:1" \
              -e TRIVY_USERNAME="${{ env.DOCKER_USERNAME }}" \
              -e TRIVY_PASSWORD="${{ env.DOCKER_TOKEN }}" \
              aquasec/trivy@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436 \
              image \
              --skip-version-check \
              --timeout 20m \
              --severity CRITICAL,HIGH \
              --ignorefile /tmp/.trivyignore \
              ${SCAN_IMAGE}
          image-ref: ${{ steps.scan-image.outputs.image }}
          severity: CRITICAL,HIGH
          format: "sarif"
          output: "trivy-results.sarif"
          trivyignores: ${{ matrix.trivyignore }}
        env:
          TRIVY_USERNAME: ${{ secrets.DOCKER_USERNAME }}
          TRIVY_PASSWORD: ${{ secrets.DOCKER_TOKEN }}

  trivy-scan-model-server:
    needs:
      - determine-builds
      - merge-model-server
    if: needs.merge-model-server.result == 'success'
    runs-on:
      - runs-on
      - runner=2cpu-linux-arm64
      - run-id=${{ github.run_id }}-trivy-scan-model-server
      - extras=ecr-cache
    timeout-minutes: 90
    environment: release
    env:
      REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-model-server-cloud' || 'onyxdotapp/onyx-model-server' }}
    steps:
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
      - name: Upload Trivy scan results to GitHub Security tab
        if: steps.should-run.outputs.run == 'true'
        uses: github/codeql-action/upload-sarif@ba454b8ab46733eb6145342877cd148270bb77ab
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2

      - name: Get AWS Secrets
        uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
        with:
          secret-ids: |
            DOCKER_USERNAME, deploy/docker-username
            DOCKER_TOKEN, deploy/docker-token
          parse-json-secrets: true

      - name: Run Trivy vulnerability scanner
        uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
        with:
          timeout_minutes: 30
          max_attempts: 3
          retry_wait_seconds: 10
          command: |
            if [ "${{ needs.determine-builds.outputs.is-test-run }}" == "true" ]; then
              SCAN_IMAGE="${{ env.RUNS_ON_ECR_CACHE }}:model-server-${{ needs.determine-builds.outputs.sanitized-tag }}"
            else
              SCAN_IMAGE="docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}"
            fi
            docker run --rm -v $HOME/.cache/trivy:/root/.cache/trivy \
              -e TRIVY_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-db:2" \
              -e TRIVY_JAVA_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-java-db:1" \
              -e TRIVY_USERNAME="${{ env.DOCKER_USERNAME }}" \
              -e TRIVY_PASSWORD="${{ env.DOCKER_TOKEN }}" \
              aquasec/trivy@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436 \
              image \
              --skip-version-check \
              --timeout 20m \
              --severity CRITICAL,HIGH \
              ${SCAN_IMAGE}
          sarif_file: "trivy-results.sarif"

  notify-slack-on-failure:
    needs:
```
```
@@ -35,6 +35,7 @@ jobs:
    needs: [provider-chat-test]
    if: failure() && github.event_name == 'schedule'
    runs-on: ubuntu-slim
    environment: ci-protected
    timeout-minutes: 5
    steps:
      - name: Checkout

@@ -183,6 +183,7 @@ jobs:
      - cherry-pick-to-latest-release
    if: needs.resolve-cherry-pick-request.outputs.should_cherrypick == 'true' && needs.resolve-cherry-pick-request.result == 'success' && needs.cherry-pick-to-latest-release.result == 'success'
    runs-on: ubuntu-slim
    environment: ci-protected
    timeout-minutes: 10
    steps:
      - name: Checkout

@@ -232,6 +233,7 @@ jobs:
      - cherry-pick-to-latest-release
    if: always() && needs.resolve-cherry-pick-request.outputs.should_cherrypick == 'true' && (needs.resolve-cherry-pick-request.result == 'failure' || needs.cherry-pick-to-latest-release.result == 'failure')
    runs-on: ubuntu-slim
    environment: ci-protected
    timeout-minutes: 10
    steps:
      - name: Checkout
```
### .github/workflows/pr-desktop-build.yml (2 changes)
```
@@ -63,7 +63,7 @@ jobs:
          targets: ${{ matrix.target }}

      - name: Cache Cargo registry and build
        uses: actions/cache@cdf6c1fa76f9f475f3d7449005a359c84ca0f306 # zizmor: ignore[cache-poisoning]
        uses: actions/cache@668228422ae6a00e4ad889ee87cd7109ec5666a7 # zizmor: ignore[cache-poisoning]
        with:
          path: |
            ~/.cargo/bin/
```
### .github/workflows/pr-helm-chart-testing.yml (2 changes)
```
@@ -41,7 +41,7 @@ jobs:
          version: v3.19.0

      - name: Set up chart-testing
        uses: helm/chart-testing-action@b5eebdd9998021f29756c53432f48dab66394810
        uses: helm/chart-testing-action@2e2940618cb426dce2999631d543b53cdcfc8527
        with:
          uv_version: "0.9.9"
```
### .github/workflows/pr-playwright-tests.yml (4 changes)
```
@@ -284,7 +284,7 @@ jobs:

      - name: Cache playwright cache
        # zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts
        uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
        uses: runs-on/cache@a5f51d6f3fece787d03b7b4e981c82538a0654ed # ratchet:runs-on/cache@v4
        with:
          path: ~/.cache/ms-playwright
          key: ${{ runner.os }}-playwright-npm-${{ hashFiles('web/package-lock.json') }}

@@ -626,7 +626,7 @@ jobs:

      - name: Cache playwright cache
        # zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts
        uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
        uses: runs-on/cache@a5f51d6f3fece787d03b7b4e981c82538a0654ed # ratchet:runs-on/cache@v4
        with:
          path: ~/.cache/ms-playwright
          key: ${{ runner.os }}-playwright-npm-${{ hashFiles('web/package-lock.json') }}
```
### .github/workflows/pr-python-checks.yml (2 changes)
```
@@ -56,7 +56,7 @@ jobs:

      - name: Cache mypy cache
        if: ${{ vars.DISABLE_MYPY_CACHE != 'true' }}
        uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
        uses: runs-on/cache@a5f51d6f3fece787d03b7b4e981c82538a0654ed # ratchet:runs-on/cache@v4
        with:
          path: .mypy_cache
          key: mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-${{ hashFiles('**/*.py', '**/*.pyi', 'pyproject.toml') }}
```
### .github/workflows/pr-python-connector-tests.yml (174 changes)
```
@@ -22,132 +22,40 @@ on:
    - cron: "0 16 * * *"

permissions:
  id-token: write # Required for OIDC-based AWS credential exchange
  contents: read

env:
  # AWS
  AWS_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS: ${{ secrets.AWS_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS }}
  AWS_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS: ${{ secrets.AWS_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS }}

  # Cloudflare R2
  PYTHONPATH: ./backend
  DISABLE_TELEMETRY: "true"
  R2_ACCOUNT_ID_DAILY_CONNECTOR_TESTS: ${{ vars.R2_ACCOUNT_ID_DAILY_CONNECTOR_TESTS }}
  R2_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS: ${{ secrets.R2_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS }}
  R2_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS: ${{ secrets.R2_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS }}

  # Google Cloud Storage
  GCS_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS: ${{ secrets.GCS_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS }}
  GCS_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS: ${{ secrets.GCS_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS }}

  # Confluence
  CONFLUENCE_TEST_SPACE_URL: ${{ vars.CONFLUENCE_TEST_SPACE_URL }}
  CONFLUENCE_TEST_SPACE: ${{ vars.CONFLUENCE_TEST_SPACE }}
  CONFLUENCE_TEST_PAGE_ID: ${{ secrets.CONFLUENCE_TEST_PAGE_ID }}
  CONFLUENCE_USER_NAME: ${{ vars.CONFLUENCE_USER_NAME }}
  CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
  CONFLUENCE_ACCESS_TOKEN_SCOPED: ${{ secrets.CONFLUENCE_ACCESS_TOKEN_SCOPED }}

  # Jira
  JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }}
  JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
  JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
  JIRA_API_TOKEN_SCOPED: ${{ secrets.JIRA_API_TOKEN_SCOPED }}

  # Gong
  GONG_ACCESS_KEY: ${{ secrets.GONG_ACCESS_KEY }}
  GONG_ACCESS_KEY_SECRET: ${{ secrets.GONG_ACCESS_KEY_SECRET }}

  # Google
  GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR: ${{ secrets.GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR }}
  GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR_TEST_USER_1: ${{ secrets.GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR_TEST_USER_1 }}
  GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR }}
  GOOGLE_GMAIL_SERVICE_ACCOUNT_JSON_STR: ${{ secrets.GOOGLE_GMAIL_SERVICE_ACCOUNT_JSON_STR }}
  GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR }}

  # Slab
  SLAB_BOT_TOKEN: ${{ secrets.SLAB_BOT_TOKEN }}

  # Zendesk
  ZENDESK_SUBDOMAIN: ${{ secrets.ZENDESK_SUBDOMAIN }}
  ZENDESK_EMAIL: ${{ secrets.ZENDESK_EMAIL }}
  ZENDESK_TOKEN: ${{ secrets.ZENDESK_TOKEN }}

  # Salesforce
  SF_USERNAME: ${{ vars.SF_USERNAME }}
  SF_PASSWORD: ${{ secrets.SF_PASSWORD }}
  SF_SECURITY_TOKEN: ${{ secrets.SF_SECURITY_TOKEN }}

  # Hubspot
  HUBSPOT_ACCESS_TOKEN: ${{ secrets.HUBSPOT_ACCESS_TOKEN }}

  # IMAP
  IMAP_HOST: ${{ vars.IMAP_HOST }}
  IMAP_USERNAME: ${{ vars.IMAP_USERNAME }}
  IMAP_PASSWORD: ${{ secrets.IMAP_PASSWORD }}
  IMAP_MAILBOXES: ${{ vars.IMAP_MAILBOXES }}

  # Airtable
  AIRTABLE_TEST_BASE_ID: ${{ vars.AIRTABLE_TEST_BASE_ID }}
  AIRTABLE_TEST_TABLE_ID: ${{ vars.AIRTABLE_TEST_TABLE_ID }}
  AIRTABLE_TEST_TABLE_NAME: ${{ vars.AIRTABLE_TEST_TABLE_NAME }}
  AIRTABLE_ACCESS_TOKEN: ${{ secrets.AIRTABLE_ACCESS_TOKEN }}

  # Sharepoint
  SHAREPOINT_CLIENT_ID: ${{ vars.SHAREPOINT_CLIENT_ID }}
  SHAREPOINT_CLIENT_SECRET: ${{ secrets.SHAREPOINT_CLIENT_SECRET }}
  SHAREPOINT_CLIENT_DIRECTORY_ID: ${{ vars.SHAREPOINT_CLIENT_DIRECTORY_ID }}
  SHAREPOINT_SITE: ${{ vars.SHAREPOINT_SITE }}
  PERM_SYNC_SHAREPOINT_CLIENT_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_CLIENT_ID }}
  PERM_SYNC_SHAREPOINT_PRIVATE_KEY: ${{ secrets.PERM_SYNC_SHAREPOINT_PRIVATE_KEY }}
  PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD: ${{ secrets.PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD }}
  PERM_SYNC_SHAREPOINT_DIRECTORY_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_DIRECTORY_ID }}

  # Github
  ACCESS_TOKEN_GITHUB: ${{ secrets.ACCESS_TOKEN_GITHUB }}

  # Gitlab
  GITLAB_ACCESS_TOKEN: ${{ secrets.GITLAB_ACCESS_TOKEN }}

  # Gitbook
  GITBOOK_SPACE_ID: ${{ secrets.GITBOOK_SPACE_ID }}
  GITBOOK_API_KEY: ${{ secrets.GITBOOK_API_KEY }}

  # Notion
  NOTION_INTEGRATION_TOKEN: ${{ secrets.NOTION_INTEGRATION_TOKEN }}

  # Highspot
  HIGHSPOT_KEY: ${{ secrets.HIGHSPOT_KEY }}
  HIGHSPOT_SECRET: ${{ secrets.HIGHSPOT_SECRET }}

  # Slack
  SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}

  # Discord
  DISCORD_CONNECTOR_BOT_TOKEN: ${{ secrets.DISCORD_CONNECTOR_BOT_TOKEN }}

  # Teams
  TEAMS_APPLICATION_ID: ${{ secrets.TEAMS_APPLICATION_ID }}
  TEAMS_DIRECTORY_ID: ${{ secrets.TEAMS_DIRECTORY_ID }}
  TEAMS_SECRET: ${{ secrets.TEAMS_SECRET }}

  # Bitbucket
  BITBUCKET_WORKSPACE: ${{ secrets.BITBUCKET_WORKSPACE }}
  BITBUCKET_REPOSITORIES: ${{ secrets.BITBUCKET_REPOSITORIES }}
  BITBUCKET_PROJECTS: ${{ secrets.BITBUCKET_PROJECTS }}
  BITBUCKET_EMAIL: ${{ vars.BITBUCKET_EMAIL }}
  BITBUCKET_API_TOKEN: ${{ secrets.BITBUCKET_API_TOKEN }}

  # Fireflies
  FIREFLIES_API_KEY: ${{ secrets.FIREFLIES_API_KEY }}

jobs:
  connectors-check:
    # See https://runs-on.com/runners/linux/
    runs-on: [runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}-connectors-check", "extras=s3-cache"]
    runs-on:
      [
        runs-on,
        runner=8cpu-linux-x64,
        "run-id=${{ github.run_id }}-connectors-check",
        "extras=s3-cache",
      ]
    timeout-minutes: 45

    env:
      PYTHONPATH: ./backend
      DISABLE_TELEMETRY: "true"
    environment: ci-protected

    steps:
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

@@ -188,6 +96,66 @@ jobs:
          - 'backend/onyx/file_processing/**'
          - 'uv.lock'

      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 # ratchet:aws-actions/configure-aws-credentials@v4
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2

      - name: Get connector test secrets from AWS Secrets Manager
        uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802 # ratchet:aws-actions/aws-secretsmanager-get-secrets@v2
        with:
          parse-json-secrets: false
          secret-ids: |
            AWS_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS, test/aws-access-key-id
            AWS_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS, test/aws-secret-access-key
            R2_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS, test/r2-access-key-id
            R2_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS, test/r2-secret-access-key
            GCS_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS, test/gcs-access-key-id
            GCS_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS, test/gcs-secret-access-key
            CONFLUENCE_ACCESS_TOKEN, test/confluence-access-token
            CONFLUENCE_ACCESS_TOKEN_SCOPED, test/confluence-access-token-scoped
            JIRA_BASE_URL, test/jira-base-url
            JIRA_USER_EMAIL, test/jira-user-email
            JIRA_API_TOKEN, test/jira-api-token
            JIRA_API_TOKEN_SCOPED, test/jira-api-token-scoped
            GONG_ACCESS_KEY, test/gong-access-key
            GONG_ACCESS_KEY_SECRET, test/gong-access-key-secret
            GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR, test/google-drive-service-account-json
            GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR_TEST_USER_1, test/google-drive-oauth-creds-test-user-1
            GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR, test/google-drive-oauth-creds
            GOOGLE_GMAIL_SERVICE_ACCOUNT_JSON_STR, test/google-gmail-service-account-json
            GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR, test/google-gmail-oauth-creds
            SLAB_BOT_TOKEN, test/slab-bot-token
            ZENDESK_SUBDOMAIN, test/zendesk-subdomain
            ZENDESK_EMAIL, test/zendesk-email
            ZENDESK_TOKEN, test/zendesk-token
            SF_PASSWORD, test/sf-password
            SF_SECURITY_TOKEN, test/sf-security-token
            HUBSPOT_ACCESS_TOKEN, test/hubspot-access-token
            IMAP_PASSWORD, test/imap-password
            AIRTABLE_ACCESS_TOKEN, test/airtable-access-token
            SHAREPOINT_CLIENT_SECRET, test/sharepoint-client-secret
            PERM_SYNC_SHAREPOINT_CLIENT_ID, test/perm-sync-sharepoint-client-id
            PERM_SYNC_SHAREPOINT_PRIVATE_KEY, test/perm-sync-sharepoint-private-key
            PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD, test/perm-sync-sharepoint-cert-password
            PERM_SYNC_SHAREPOINT_DIRECTORY_ID, test/perm-sync-sharepoint-directory-id
            ACCESS_TOKEN_GITHUB, test/github-access-token
            GITLAB_ACCESS_TOKEN, test/gitlab-access-token
            GITBOOK_SPACE_ID, test/gitbook-space-id
            GITBOOK_API_KEY, test/gitbook-api-key
            NOTION_INTEGRATION_TOKEN, test/notion-integration-token
            HIGHSPOT_KEY, test/highspot-key
            HIGHSPOT_SECRET, test/highspot-secret
            SLACK_BOT_TOKEN, test/slack-bot-token
            DISCORD_CONNECTOR_BOT_TOKEN, test/discord-bot-token
            TEAMS_APPLICATION_ID, test/teams-application-id
            TEAMS_DIRECTORY_ID, test/teams-directory-id
            TEAMS_SECRET, test/teams-secret
            BITBUCKET_WORKSPACE, test/bitbucket-workspace
            BITBUCKET_API_TOKEN, test/bitbucket-api-token
            FIREFLIES_API_KEY, test/fireflies-api-key

      - name: Run Tests (excluding HubSpot, Salesforce, GitHub, and Coda)
        shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
        run: |
```
### .github/workflows/pr-python-model-tests.yml (1 change)
```
@@ -31,6 +31,7 @@ jobs:
      - runner=4cpu-linux-arm64
      - "run-id=${{ github.run_id }}-model-check"
      - "extras=ecr-cache"
    environment: ci-protected
    timeout-minutes: 45

    env:
```
### .github/workflows/storybook-deploy.yml (2 changes)
```
@@ -25,6 +25,7 @@ permissions:
jobs:
  Deploy-Storybook:
    runs-on: ubuntu-latest
    environment: ci-protected
    timeout-minutes: 30
    steps:
      - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v4

@@ -54,6 +55,7 @@ jobs:
    needs: Deploy-Storybook
    if: always() && needs.Deploy-Storybook.result == 'failure'
    runs-on: ubuntu-latest
    environment: ci-protected
    timeout-minutes: 10
    steps:
      - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v4
```
### .github/workflows/sync_foss.yml (1 change)
```
@@ -9,6 +9,7 @@ on:
jobs:
  sync-foss:
    runs-on: ubuntu-latest
    environment: ci-protected
    timeout-minutes: 45
    permissions:
      contents: read
```
### .github/workflows/tag-nightly.yml (1 change)
```
@@ -11,6 +11,7 @@ permissions:
jobs:
  create-and-push-tag:
    runs-on: ubuntu-slim
    environment: ci-protected
    timeout-minutes: 45

    steps:
```
@@ -6,7 +6,7 @@ Use explicit type annotations for variables to enhance code clarity, especially

## Best Practices

Use the "Engineering Best Practices" section of `CONTRIBUTING.md` as core review context. Prefer consistency with existing patterns, fix issues in code you touch, avoid tacking new features onto muddy interfaces, fail loudly instead of silently swallowing errors, keep code strictly typed, preserve clear state boundaries, remove duplicate or dead logic, break up overly long functions, avoid hidden import-time side effects, respect module boundaries, and favor correctness-by-construction over relying on callers to use an API correctly.

## TODOs

@@ -27,6 +27,7 @@ Code changes must consider both multi-tenant and single-tenant deployments. In m

## Nginx Routing — New Backend Routes

Whenever a new backend route is added that does NOT start with `/api`, it must also be explicitly added to ALL nginx configs:

- `deployment/helm/charts/onyx/templates/nginx-conf.yaml` (Helm/k8s)
- `deployment/data/nginx/app.conf.template` (docker-compose dev)
- `deployment/data/nginx/app.conf.template.prod` (docker-compose prod)

@@ -37,3 +38,7 @@ Routes not starting with `/api` are not caught by the existing `^/(api|openapi\.`

## Full vs Lite Deployments

Code changes must consider both regular Onyx deployments and Onyx lite deployments. Lite deployments disable the vector DB, Redis, model servers, and background workers by default, use PostgreSQL-backed cache/auth/file storage, and rely on the API server to handle background work. Do not assume those services are available unless the code path is explicitly limited to full deployments.

## SWR Cache Keys — Always Use SWR_KEYS Registry

All `useSWR()` calls and `mutate()` calls in the frontend must reference the centralized `SWR_KEYS` registry in `web/src/lib/swr-keys.ts` instead of inline endpoint strings or local string constants. Never write `useSWR("/api/some/endpoint", ...)` or `mutate("/api/some/endpoint")` — always use the corresponding `SWR_KEYS.someEndpoint` constant. If the endpoint does not yet exist in the registry, add it there first. This applies to all variants of an endpoint (e.g. query-string variants like `?get_editable=true` must also be registered as their own key).

@@ -357,5 +357,5 @@ raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=e.respon

## Best Practices

In addition to the other content in this file, best practices for contributing to the codebase can be found in the "Engineering Best Practices" section of `CONTRIBUTING.md`. Understand its contents and follow them.
### CONTRIBUTING.md (481 changes)
# Contributing to Onyx

Hey there! We are so excited that you're interested in Onyx.

## Table of Contents

- [Contribution Opportunities](#contribution-opportunities)
- [Contribution Process](#contribution-process)
- [Development Setup](#development-setup)
  - [Prerequisites](#prerequisites)
  - [Backend: Python Requirements](#backend-python-requirements)
  - [Frontend: Node Dependencies](#frontend-node-dependencies)
  - [Formatting and Linting](#formatting-and-linting)
- [Running the Application](#running-the-application)
  - [VSCode Debugger (Recommended)](#vscode-debugger-recommended)
  - [Manually Running for Development](#manually-running-for-development)
  - [Running in Docker](#running-in-docker)
- [macOS-Specific Notes](#macos-specific-notes)
- [Engineering Best Practices](#engineering-best-practices)
  - [Principles and Collaboration](#principles-and-collaboration)
  - [Style and Maintainability](#style-and-maintainability)
  - [Performance and Correctness](#performance-and-correctness)
  - [Repository Conventions](#repository-conventions)
- [Release Process](#release-process)
- [Getting Help](#getting-help)
- [Enterprise Edition Contributions](#enterprise-edition-contributions)

---

## Contribution Opportunities

The [GitHub Issues](https://github.com/onyx-dot-app/onyx/issues) page is a great place to look for and share contribution ideas.

If you have your own feature that you would like to build, please create an issue and community members can provide feedback and upvote if they feel a common need.

---

## Contribution Process

To contribute, please follow the ["fork and pull request"](https://docs.github.com/en/get-started/quickstart/contributing-to-projects) workflow.

### 1. Get the feature or enhancement approved

Create a GitHub issue and see if there are upvotes. If you feel the feature is sufficiently value-additive and you would like approval to contribute it to the repo, tag [Yuhong](https://github.com/yuhongsun96) to review.

If you do not get a response within a week, feel free to email yuhong@onyx.app and include the issue in the message.

Not all small features and enhancements will be accepted, as there is a balance between feature richness and bloat. We strive to provide the best user experience possible, so we have to be intentional about what we include in the app.

### 2. Get the design approved

The Onyx team will either provide a design doc and PRD for the feature or request one from you, the contributor. The scope and detail of the design will depend on the individual feature.

### 3. IP attribution for EE contributions

If you are contributing features to Onyx Enterprise Edition, you are required to sign the [IP Assignment Agreement](contributor_ip_assignment/EE_Contributor_IP_Assignment_Agreement.md).

### 4. Review and testing

Your features must pass all tests and all comments must be addressed prior to merging.

### Implicit agreements

If we approve an issue, we are promising you the following:

- Your work will receive timely attention and we will put aside other important items to ensure you are not blocked.
- You will receive necessary coaching on eng quality, system design, etc. to ensure the feature is completed well.
- The Onyx team will pull resources and bandwidth from design, PM, and engineering to ensure that you have all the resources to build the feature to the quality required for merging.

Because this is a large investment from our team, we ask that you:

- Thoroughly read all the requirements of the design docs and engineering best practices, and try to minimize overhead for the Onyx team.
- Complete the feature in a timely manner to reduce context switching and an ongoing resource pull from the Onyx team.

---

## Development Setup

Onyx, being a fully functional app, relies on some external software, specifically:

- [Postgres](https://www.postgresql.org/) (Relational DB)
- [OpenSearch](https://opensearch.org/) (Vector DB/Search Engine)
- [Redis](https://redis.io/) (Cache)
- [MinIO](https://min.io/) (File Store)
- [Nginx](https://nginx.org/) (Not needed for development flows generally)

> **Note:**
> This guide provides instructions to build and run Onyx locally from source, with Docker containers providing the above external software.
> We believe this combination is easier for development purposes. If you prefer to use pre-built container images, see [Running in Docker](#running-in-docker) below.

### Prerequisites

- **Python 3.11** — If using a lower version, modifications will have to be made to the code. Higher versions may have library compatibility issues.
- **Docker** — Required for running external services (Postgres, OpenSearch, Redis, MinIO).
- **Node.js v22** — We recommend using [nvm](https://github.com/nvm-sh/nvm) to manage Node installations.

### Backend: Python Requirements

We use [uv](https://docs.astral.sh/uv/) and recommend creating a [virtual environment](https://docs.astral.sh/uv/pip/environments/#using-a-virtual-environment).

```bash
uv venv .venv --python 3.11
source .venv/bin/activate
```

_For Windows, activate the virtual environment using Command Prompt:_

```bash
.venv\Scripts\activate
```

If using PowerShell, the command slightly differs:

```powershell
.venv\Scripts\Activate.ps1
```

Install the required Python dependencies:

```bash
uv sync --all-extras
```

Install Playwright for Python (headless browser required by the Web Connector):

```bash
uv run playwright install
```

### Frontend: Node Dependencies

```bash
nvm install 22 && nvm use 22
node -v # verify your active version
```

Navigate to `onyx/web` and run:

```bash
npm i
```

### Formatting and Linting

#### Backend

Set up pre-commit hooks (black / reorder-python-imports):

```bash
uv run pre-commit install
```

We also use `mypy` for static type checking. Onyx is fully type-annotated, and we want to keep it that way! To run the mypy checks manually:

```bash
uv run mypy . # from onyx/backend
```

#### Frontend

We use `prettier` for formatting. The desired version will be installed via `npm i` from the `onyx/web` directory. To run the formatter:

```bash
npx prettier --write . # from onyx/web
```

Pre-commit will also run prettier automatically on files you've recently touched. If files are re-formatted, your commit will fail; re-stage your changes and commit again.

---

## Running the Application

### VSCode Debugger (Recommended)

We highly recommend using VSCode's debugger for development.

#### Initial Setup

1. Copy `.vscode/env_template.txt` to `.vscode/.env`
2. Fill in the necessary environment variables in `.vscode/.env`

#### Using the Debugger

Before starting, make sure the Docker Daemon is running.

1. Open the Debug view in VSCode (Cmd+Shift+D on macOS)
2. From the dropdown at the top, select "Clear and Restart External Volumes and Containers" and press the green play button
3. From the dropdown at the top, select "Run All Onyx Services" and press the green play button
4. Navigate to http://localhost:3000 in your browser to start using the app
5. Set breakpoints by clicking to the left of line numbers to help debug while the app is running
6. Use the debug toolbar to step through code, inspect variables, etc.

> **Note:** "Clear and Restart External Volumes and Containers" will reset your Postgres and OpenSearch (relational-db and index). Only run this if you are okay with wiping your data.

**Features:**

- Hot reload is enabled for the web server and API servers
- Python debugging is configured with debugpy
- Environment variables are loaded from `.vscode/.env`
- Console output is organized in the integrated terminal with labeled tabs

### Manually Running for Development

#### Docker containers for external software

You will need Docker installed to run these containers.

Navigate to `onyx/deployment/docker_compose`, then start up Postgres/OpenSearch/Redis/MinIO with:

```bash
docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d index relational_db cache minio
```

(`index` refers to OpenSearch, `relational_db` refers to Postgres, and `cache` refers to Redis.)

#### Running Onyx locally

To start the frontend, navigate to `onyx/web` and run:

```bash
npm run dev
```

Next, start the model server which runs the local NLP models. Navigate to `onyx/backend` and run:

```bash
uvicorn model_server.main:app --reload --port 9000
```

_For Windows (for compatibility with both PowerShell and Command Prompt):_

```bash
powershell -Command "uvicorn model_server.main:app --reload --port 9000"
```

The first time you run Onyx, you will need to run the DB migrations for Postgres. After the first time, this is no longer required unless the DB models change.

Navigate to `onyx/backend` and, with the venv active, run:

```bash
alembic upgrade head
```

Next, start the task queue which orchestrates the background jobs. Still in `onyx/backend`, run:

```bash
python ./scripts/dev_run_background_jobs.py
```

To run the backend API server, navigate back to `onyx/backend` and run:

```bash
AUTH_TYPE=basic uvicorn onyx.main:app --reload --port 8080
```

_For Windows (for compatibility with both PowerShell and Command Prompt):_

```bash
powershell -Command "
$env:AUTH_TYPE='basic'
uvicorn onyx.main:app --reload --port 8080
"
```

> **Note:** If you need finer logging, add the additional environment variable `LOG_LEVEL=DEBUG` to the relevant services.

#### Wrapping up

You should now have 4 servers running:

- Web server
- Backend API
- Model server
- Background jobs

Now, visit http://localhost:3000 in your browser. You should see the Onyx onboarding wizard where you can connect your external LLM provider to Onyx.

You've successfully set up a local Onyx instance!

### Running in Docker

You can run the full Onyx application stack from pre-built images, including all external software dependencies.

Navigate to `onyx/deployment/docker_compose` and run:

```bash
docker compose up -d
```

After Docker pulls and starts these containers, navigate to http://localhost:3000 to use Onyx.

If you want to make changes to Onyx and run those changes in Docker, you can also build a local version of the Onyx container images that incorporates your changes:

```bash
docker compose up -d --build
```

---

## macOS-Specific Notes

### Setting up Python

Ensure [Homebrew](https://brew.sh/) is already set up, then install Python 3.11:

```bash
brew install python@3.11
```

Add Python 3.11 to your path by adding the following line to `~/.zshrc`:

```
export PATH="$(brew --prefix)/opt/python@3.11/libexec/bin:$PATH"
```

> **Note:** You will need to open a new terminal for the path change above to take effect.

### Setting up Docker

On macOS, you will need to install [Docker Desktop](https://www.docker.com/products/docker-desktop/) and ensure it is running before continuing with the docker commands.

### Formatting and Linting

macOS will likely require you to remove quarantine attributes from some of the hooks for them to execute properly. After installing pre-commit, run the following command:

```bash
sudo xattr -r -d com.apple.quarantine ~/.cache/pre-commit
```

---

## Engineering Best Practices

> These are also what we adhere to as a team internally; we love to build in the open and to uplevel our community and each other by being transparent.

### Principles and Collaboration

- **Use 1-way vs 2-way doors.** For 2-way doors, move faster and iterate. For 1-way doors, be more deliberate.
- **Consistency > being "right."** Prefer consistent patterns across the codebase. If something is truly bad, fix it everywhere.
- **Fix what you touch (selectively).**
  - Don't feel obligated to fix every best-practice issue you notice.
  - Don't introduce new bad practices.
  - If your change touches code that violates best practices, fix it as part of the change.
- **Don't tack features on.** When adding functionality, restructure logically as needed to avoid muddying interfaces and accumulating tech debt.

### Style and Maintainability

#### Comments and readability

Add clear comments:

- At logical boundaries (e.g., interfaces) so the reader doesn't need to dig 10 layers deeper.
- Wherever assumptions are made or something non-obvious/unexpected is done.
- For complicated flows/functions.
- Wherever it saves time (e.g., nontrivial regex patterns).

#### Errors and exceptions

- **Fail loudly** rather than silently skipping work, as in the sketch below.
  - Example: raise and let exceptions propagate instead of silently dropping a document.
- **Don't overuse `try/except`.**
  - Put `try/except` at the correct logical level.
  - Do not mask exceptions unless it is clearly appropriate.
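A minimal sketch of the fail-loudly principle; the function and document shapes here are hypothetical, not code from the repo:

```python
# Hypothetical illustration of the fail-loudly principle.

def index_documents(documents: list[dict]) -> None:
    for document in documents:
        # Bad: silently dropping a document hides data loss from operators.
        # try:
        #     index_single_document(document)
        # except Exception:
        #     continue

        # Good: let the exception propagate so the failure is visible and the
        # caller (which owns the retry/alerting policy) can decide what to do.
        index_single_document(document)


def index_single_document(document: dict) -> None:
    if "id" not in document:
        # Raise with context instead of skipping the malformed input.
        raise ValueError(f"Document is missing an 'id' field: {document}")
    ...  # actual indexing work
```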
#### Typing

- Everything should be **as strictly typed as possible**.
- Use `cast` for annoying/loose-typed interfaces (e.g., results of `run_functions_tuples_in_parallel`).
  - Only `cast` when the type checker sees `Any` or types are too loose.
- Prefer types that are easy to read.
  - Avoid dense types like `dict[tuple[str, str], list[list[float]]]`.
  - Prefer domain models, e.g.:
    - `EmbeddingModel(provider_name, model_name)` as a Pydantic model
    - `dict[EmbeddingModel, list[EmbeddingVector]]`
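A sketch of what that preference can look like in practice; `EmbeddingVector` is an assumed alias, and the values are illustrative:

```python
# Sketch of the domain-model preference above; names follow the examples in
# this section rather than any real module in the repo.
from pydantic import BaseModel

# Hard to read: what do the tuple fields and nested lists mean?
EmbeddingsByProviderAndModel = dict[tuple[str, str], list[list[float]]]

EmbeddingVector = list[float]  # assumed alias for readability


class EmbeddingModel(BaseModel):
    provider_name: str
    model_name: str

    # Pydantic models are not hashable by default; freezing the model makes
    # instances immutable and usable as dict keys.
    model_config = {"frozen": True}


# Self-documenting: the key and value types carry their own meaning.
embeddings: dict[EmbeddingModel, list[EmbeddingVector]] = {
    EmbeddingModel(provider_name="openai", model_name="text-embedding-3-small"): [
        [0.1, 0.2, 0.3],
    ],
}
```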
#### State, objects, and boundaries

- Keep **clear logical boundaries** for state containers and objects (see the sketch after this list).
  - A **config** object should never contain things like a `db_session`.
  - Avoid state containers that are overly nested, or huge + flat (use judgment).
- Prefer **composition and functional style** over inheritance/OOP.
- Prefer **no mutation** unless there's a strong reason.
  - State objects should be **intentional and explicit**, ideally nonmutating.
- Use interfaces/objects to create clear separation of responsibility.
- Prefer simplicity when there's no clear gain.
  - Avoid overcomplicated mechanisms like semaphores.
- Prefer **hash maps (dicts)** over tree structures unless there's a strong reason.
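A hypothetical sketch of the config-vs-runtime-state boundary:

```python
# Hypothetical sketch of the boundary rule above: configuration holds only
# settings, and runtime resources like a db_session are passed separately.
from pydantic import BaseModel
from sqlalchemy.orm import Session


class IndexingConfig(BaseModel):
    # Pure settings: safe to construct anywhere, serialize, or log.
    batch_size: int = 100
    max_retries: int = 3


def run_indexing(config: IndexingConfig, db_session: Session) -> None:
    # The session is a runtime resource owned by the caller, not smuggled
    # inside the config object.
    ...
```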
#### Naming

- Name variables carefully and intentionally.
  - Prefer long, explicit names when undecided.
  - Avoid single-character variables except for small, self-contained utilities (or not at all).
- Keep the same object/name consistent through the call stack and within functions when reasonable.
  - Good: `for token in tokens:`
  - Bad: `for msg in tokens:` (if iterating tokens)
- Function names should bias toward **long + descriptive** for codebase search.
  - IntelliSense can miss call sites; search works best with unique names.
#### Correctness by construction

- Prefer self-contained correctness — don't rely on callers to "use it right" if you can make misuse hard.
- Avoid redundancies: if a function takes an arg, it shouldn't also take a state object that contains that same arg (see the sketch after this list).
- No dead code (unless there's a very good reason).
- No commented-out code in main or feature branches (unless there's a very good reason).
- No duplicate logic:
  - Don't copy/paste into branches when shared logic can live above the conditional.
  - If you're afraid to touch the original, you don't understand it well enough.
  - LLMs often create subtle duplicate logic — review carefully and remove it.
- Avoid "nearly identical" objects that confuse when to use which.
- Avoid extremely long functions with chained logic:
  - Encapsulate steps into helpers for readability, even if not reused.
  - "Pythonic" multi-step expressions are OK in moderation; don't trade clarity for cleverness.
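A minimal sketch of the redundant-argument rule, with hypothetical names:

```python
# Hypothetical sketch of the "no redundant args" rule above.
from pydantic import BaseModel


class SearchRequest(BaseModel):
    query: str
    max_hits: int


# Bad: the query is passed twice, so the two values can silently disagree.
def search_bad(query: str, request: SearchRequest) -> list[str]: ...


# Good: one source of truth; misuse is hard by construction.
def search(request: SearchRequest) -> list[str]: ...
```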
### Performance and Correctness

- Avoid holding resources for extended periods (DB sessions, locks/semaphores).
- Validate objects on creation and right before use.
- Connector code (data to Onyx documents):
  - Any in-memory structure that can grow without bound based on input must be periodically size-checked (see the sketch after this list).
  - If a connector is OOMing (often shows up as "missing celery tasks"), this is a top thing to check retroactively.
- Async and event loops:
  - Never introduce new async/event loop Python code, and, where it makes sense, convert existing async code to synchronous code.
  - Writing async code without fully understanding it, and without a concrete reason to use it, is likely to introduce bugs without adding any meaningful performance gains.
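One hypothetical way to keep a connector's in-memory batch bounded; `MAX_BATCH_SIZE` is an assumed limit, not a repo constant:

```python
# Hypothetical sketch of bounding an in-memory batch in a connector.
from collections.abc import Iterator

MAX_BATCH_SIZE = 100  # assumed limit; tune per connector


def load_documents(raw_items: Iterator[dict]) -> Iterator[list[dict]]:
    batch: list[dict] = []
    for item in raw_items:
        batch.append(item)
        # Periodic size check: flush before the batch can grow without bound.
        if len(batch) >= MAX_BATCH_SIZE:
            yield batch
            batch = []
    if batch:
        yield batch
```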
### Repository Conventions

#### Where code lives

- Pydantic + data models: `models.py` files.
- DB interface functions (excluding lazy loading): `db/` directory.
- LLM prompts: `prompts/` directory, roughly mirroring the code layout that uses them.
- API routes: `server/` directory.

#### Pydantic and modeling

- Prefer **Pydantic** over dataclasses.
- If arbitrary (non-Pydantic) field types are absolutely required, use `arbitrary_types_allowed`.
#### Data conventions

- Prefer explicit `None` over sentinel empty strings (usually; depends on intent).
- Prefer explicit identifiers: use string enums instead of integer codes (see the sketch after this list).
- Avoid magic numbers (co-location is good when necessary). **Always avoid magic strings.**
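A small sketch of the string-enum preference; the status values are hypothetical:

```python
# Sketch of the string-enum preference above (a hypothetical status field).
from enum import Enum


class ConnectorStatus(str, Enum):
    ACTIVE = "active"
    PAUSED = "paused"
    DELETED = "deleted"


# With an integer code, `status == 2` is a magic number whose meaning lives in
# someone's head. A string enum is self-describing in logs, the DB, and APIs.
def is_runnable(status: ConnectorStatus) -> bool:
    return status == ConnectorStatus.ACTIVE
```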
#### Logging

- Log messages where they are created.
- Don't propagate log messages around just to log them elsewhere.

#### Encapsulation

- Don't use private attributes/methods/properties from other classes/modules.
- "Private" is private — respect that boundary.
#### SQLAlchemy guidance

- Lazy loading is often bad at scale, especially across multiple list relationships.
- Be careful when accessing SQLAlchemy object attributes:
  - It can help avoid redundant DB queries,
  - but it can also fail if accessed outside an active session,
  - and lazy loading can add hidden DB dependencies to otherwise "simple" functions.
- Reference: https://www.reddit.com/r/SQLAlchemy/comments/138f248/joinedload_vs_selectinload/
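A hypothetical sketch of eager loading instead of hidden per-row lazy loads; the models are stand-ins, not Onyx's real schema:

```python
# Hypothetical sketch: selectinload issues one extra query for the whole
# relationship instead of one hidden lazy query per parent row.
from sqlalchemy import ForeignKey, select
from sqlalchemy.orm import (
    DeclarativeBase,
    Mapped,
    Session,
    mapped_column,
    relationship,
    selectinload,
)


class Base(DeclarativeBase):
    pass


class User(Base):
    __tablename__ = "user"
    id: Mapped[int] = mapped_column(primary_key=True)
    chat_sessions: Mapped[list["ChatSession"]] = relationship(back_populates="user")


class ChatSession(Base):
    __tablename__ = "chat_session"
    id: Mapped[int] = mapped_column(primary_key=True)
    user_id: Mapped[int] = mapped_column(ForeignKey("user.id"))
    user: Mapped[User] = relationship(back_populates="chat_sessions")


def fetch_users_with_sessions(db_session: Session) -> list[User]:
    # Relationships are loaded up front, so accessing user.chat_sessions later
    # cannot fail outside a session or trigger surprise queries.
    stmt = select(User).options(selectinload(User.chat_sessions))
    return list(db_session.scalars(stmt))
```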
#### Trunk-based development and feature flags

- **PRs should contain no more than 500 lines of real change.**
- **Merge to main frequently.** Avoid long-lived feature branches — they create merge conflicts and integration pain.
- **Use feature flags for incremental rollout.**
  - Large features should be merged in small, shippable increments behind a flag.
  - This allows continuous integration without exposing incomplete functionality.
- **Keep flags short-lived.** Once a feature is fully rolled out, remove the flag and dead code paths promptly.
- **Flag at the right level.** Prefer flagging at API/UI entry points rather than deep in business logic (see the sketch after this list).
- **Test both flag states.** Ensure the codebase works correctly with the flag on and off.
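A minimal sketch of flagging at the entry point; `ENABLE_NEW_RANKING` is an assumed flag name, not one used by Onyx:

```python
# Hypothetical sketch of checking a feature flag once, at the boundary,
# rather than scattering checks deep in business logic.
import os

ENABLE_NEW_RANKING = os.environ.get("ENABLE_NEW_RANKING", "false").lower() == "true"


def rank_results(results: list[str]) -> list[str]:
    # Single decision point; both paths remain independently testable.
    if ENABLE_NEW_RANKING:
        return _rank_results_v2(results)
    return _rank_results_v1(results)


def _rank_results_v1(results: list[str]) -> list[str]:
    return results


def _rank_results_v2(results: list[str]) -> list[str]:
    return sorted(results)
```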
#### Miscellaneous

- Any TODOs you add in the code must be accompanied by either the name/username of the owner of that TODO, or an issue number for an issue referencing that piece of work.
- Avoid module-level logic that runs on import, which leads to import-time side effects. Essentially every piece of meaningful logic should exist within some function that has to be explicitly invoked. Acceptable exceptions may include loading environment variables or setting up loggers.
  - If you find yourself needing something like this, you may want that logic to exist in a file dedicated to manual execution (contains `if __name__ == "__main__":`) which should not be imported by anything else.
- Do not conflate Python scripts you intend to run from the command line (contains `if __name__ == "__main__":`) with modules you intend to import from elsewhere. If for some unlikely reason they have to be the same file, any logic specific to executing the file (including imports) should be contained in the `if __name__ == "__main__":` block.
  - Generally these executable files exist in `backend/scripts/`. A sketch of this layout follows.
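A hypothetical layout for such a script; the function and CLI flag are illustrative:

```python
# Hypothetical layout for a file in backend/scripts/ meant for manual
# execution only, per the guidance above.
import argparse


def backfill_document_metadata(dry_run: bool) -> None:
    # Meaningful logic lives inside a function, never at module import time.
    ...


if __name__ == "__main__":
    # Everything specific to running the file from the command line stays
    # inside this block.
    parser = argparse.ArgumentParser()
    parser.add_argument("--dry-run", action="store_true")
    args = parser.parse_args()
    backfill_document_metadata(dry_run=args.dry_run)
```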
---

## Release Process

Onyx loosely follows the SemVer versioning standard.
Major changes are released with a "minor" version bump. Currently we use patch release versions to indicate small feature changes.
A set of Docker containers is pushed automatically to DockerHub with every tag.
You can see the containers [here](https://hub.docker.com/search?q=onyx%2F).

---

## Getting Help 🙋

We have support channels and generally interesting discussions on our [Discord](https://discord.gg/4NA5SbzrWb).

See you there!

---

## Enterprise Edition Contributions

If you are contributing features to Onyx Enterprise Edition (code under any `ee/` directory), you are required to sign the [IP Assignment Agreement](contributor_ip_assignment/EE_Contributor_IP_Assignment_Agreement.md) ([PDF version](contributor_ip_assignment/EE_Contributor_IP_Assignment_Agreement.pdf)).

102
README.md
@@ -4,8 +4,6 @@
|
||||
<a href="https://www.onyx.app/?utm_source=onyx_repo&utm_medium=github&utm_campaign=readme"> <img width="50%" src="https://github.com/onyx-dot-app/onyx/blob/logo/OnyxLogoCropped.jpg?raw=true" /></a>
|
||||
</h2>
|
||||
|
||||
<p align="center">Open Source AI Platform</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://discord.gg/TDJ59cGV2X" target="_blank">
|
||||
<img src="https://img.shields.io/badge/discord-join-blue.svg?logo=discord&logoColor=white" alt="Discord" />
|
||||
@@ -27,82 +25,94 @@
|
||||
</a>
|
||||
</p>
|
||||
|
||||
# Onyx - The Open Source AI Platform
|
||||
|
||||
**[Onyx](https://www.onyx.app/?utm_source=onyx_repo&utm_medium=github&utm_campaign=readme)** is a feature-rich, self-hostable Chat UI that works with any LLM. It is easy to deploy and can run in a completely airgapped environment.
|
||||
**[Onyx](https://www.onyx.app/?utm_source=onyx_repo&utm_medium=github&utm_campaign=readme)** is the application layer for LLMs - bringing a feature-rich interface that can be easily hosted by anyone.
|
||||
Onyx extends LLMs with advanced capabilities like RAG, web search, code execution, file creation, deep research, and more.
|
||||
|
||||
Onyx comes loaded with advanced features like Agents, Web Search, RAG, MCP, Deep Research, Connectors to 40+ knowledge sources, and more.
|
||||
Connect your applications with 50+ indexing-based connectors provided out of the box, or via MCP.
|
||||
|
||||
> [!TIP]
|
||||
> Run Onyx with one command (or see deployment section below):
|
||||
> Deploy with a single command:
|
||||
> ```
|
||||
> curl -fsSL https://onyx.app/install_onyx.sh | bash
|
||||
> ```
|
||||
|
||||
|
||||
|
||||

|
||||
|
||||

|
||||
|
||||
---
|
||||
|
||||
## ⭐ Features
|
||||
- **🤖 Custom Agents:** Build AI Agents with unique instructions, knowledge and actions.
|
||||
- **🌍 Web Search:** Browse the web with Google PSE, Exa, and Serper as well as an in-house scraper or Firecrawl.
|
||||
- **🔍 RAG:** Best-in-class hybrid search + knowledge graph for uploaded files and documents ingested from connectors.
|
||||
- **🔄 Connectors:** Pull knowledge, metadata, and access information from over 40 applications.
|
||||
- **🔬 Deep Research:** Get in-depth answers with an agentic multi-step search.
|
||||
- **▶️ Actions & MCP:** Give AI Agents the ability to interact with external systems.
|
||||
- **💻 Code Interpreter:** Execute code to analyze data, render graphs and create files.
|
||||
|
||||
- **🔍 Agentic RAG:** Get best-in-class search and answer quality from a hybrid index plus AI Agents for information retrieval
|
||||
- Benchmark to release soon!
|
||||
- **🔬 Deep Research:** Get in-depth reports with a multi-step research flow.
|
||||
- Top of [leaderboard](https://github.com/onyx-dot-app/onyx_deep_research_bench) as of Feb 2026.
|
||||
- **🤖 Custom Agents:** Build AI Agents with unique instructions, knowledge, and actions.
|
||||
- **🌍 Web Search:** Browse the web to get up to date information.
|
||||
- Supports Serper, Google PSE, Brave, SearXNG, and others.
|
||||
- Comes with an in-house web crawler and support for Firecrawl/Exa.
|
||||
- **📄 Artifacts:** Generate documents, graphics, and other downloadable artifacts.
|
||||
- **▶️ Actions & MCP:** Let Onyx agents interact with external applications; comes with flexible Auth options.
|
||||
- **💻 Code Execution:** Execute code in a sandbox to analyze data, render graphs, or modify files.
|
||||
- **🎙️ Voice Mode:** Chat with Onyx via text-to-speech and speech-to-text.
|
||||
- **🎨 Image Generation:** Generate images based on user prompts.
|
||||
- **👥 Collaboration:** Chat sharing, feedback gathering, user management, usage analytics, and more.
|
||||
|
||||
Onyx works with all LLMs (like OpenAI, Anthropic, Gemini, etc.) and self-hosted LLMs (like Ollama, vLLM, etc.)
|
||||
Onyx supports all major LLM providers, both self-hosted (like Ollama, LiteLLM, vLLM, etc.) and proprietary (like Anthropic, OpenAI, Gemini, etc.).
|
||||
|
||||
To learn more about the features, check out our [documentation](https://docs.onyx.app/welcome?utm_source=onyx_repo&utm_medium=github&utm_campaign=readme)!
|
||||
To learn more - check out our [docs](https://docs.onyx.app/welcome?utm_source=onyx_repo&utm_medium=github&utm_campaign=readme)!
|
||||
|
||||
---
|
||||
|
||||
## 🚀 Deployment Modes
|
||||
|
||||
## 🚀 Deployment
|
||||
Onyx supports deployments via Docker, Kubernetes, and Terraform, along with guides for major cloud providers.
|
||||
> Onyx supports deployments in Docker, Kubernetes, Helm/Terraform and provides guides for major cloud providers.
|
||||
> Detailed deployment guides found [here](https://docs.onyx.app/deployment/overview).
|
||||
|
||||
See guides below:
|
||||
- [Docker](https://docs.onyx.app/deployment/local/docker?utm_source=onyx_repo&utm_medium=github&utm_campaign=readme) or [Quickstart](https://docs.onyx.app/deployment/getting_started/quickstart?utm_source=onyx_repo&utm_medium=github&utm_campaign=readme) (best for most users)
|
||||
- [Kubernetes](https://docs.onyx.app/deployment/local/kubernetes?utm_source=onyx_repo&utm_medium=github&utm_campaign=readme) (best for large teams)
|
||||
- [Terraform](https://docs.onyx.app/deployment/local/terraform?utm_source=onyx_repo&utm_medium=github&utm_campaign=readme) (best for teams already using Terraform)
|
||||
- Cloud specific guides (best if specifically using [AWS EKS](https://docs.onyx.app/deployment/cloud/aws/eks?utm_source=onyx_repo&utm_medium=github&utm_campaign=readme), [Azure VMs](https://docs.onyx.app/deployment/cloud/azure?utm_source=onyx_repo&utm_medium=github&utm_campaign=readme), etc.)
|
||||
Onyx supports two separate deployment options: standard and lite.
|
||||
|
||||
#### Onyx Lite
|
||||
|
||||
The Lite mode can be thought of as a lightweight Chat UI. It requires fewer resources (under 1 GB of memory) and runs a less complex stack.
|
||||
It is great for users who want to test out Onyx quickly, or for teams only interested in the Chat UI and Agents functionality.
|
||||
|
||||
#### Standard Onyx
|
||||
|
||||
The complete feature set of Onyx, recommended for serious users and larger teams. Additional components not included in Lite mode:
|
||||
- Vector + Keyword index for RAG.
|
||||
- Background containers to run job queues and workers for syncing knowledge from connectors.
|
||||
- AI model inference servers to run deep learning models used during indexing and inference.
|
||||
- Performance optimizations for large-scale use via an in-memory cache (Redis) and a blob store (MinIO).
|
||||
|
||||
> [!TIP]
|
||||
> **To try Onyx for free without deploying, check out [Onyx Cloud](https://cloud.onyx.app/signup?utm_source=onyx_repo&utm_medium=github&utm_campaign=readme)**.
|
||||
> **To try Onyx for free without deploying, visit [Onyx Cloud](https://cloud.onyx.app/signup?utm_source=onyx_repo&utm_medium=github&utm_campaign=readme)**.
|
||||
|
||||
---
|
||||
|
||||
## 🏢 Onyx for Enterprise
|
||||
|
||||
## 🔍 Other Notable Benefits
|
||||
Onyx is built for teams of all sizes, from individual users to the largest global enterprises.
|
||||
|
||||
- **Enterprise Search**: far more than simple RAG, Onyx has custom indexing and retrieval that remains performant and accurate at scales of up to tens of millions of documents.
|
||||
- **Security**: SSO (OIDC/SAML/OAuth2), RBAC, encryption of credentials, etc.
|
||||
- **Management UI**: different user roles such as basic, curator, and admin.
|
||||
- **Document Permissioning**: mirrors user access from external apps for RAG use cases.
|
||||
|
||||
|
||||
|
||||
## 🚧 Roadmap
|
||||
To see ongoing and upcoming projects, check out our [roadmap](https://github.com/orgs/onyx-dot-app/projects/2)!
|
||||
|
||||
|
||||
Onyx is built for teams of all sizes, from individual users to the largest global enterprises:
|
||||
- 👥 Collaboration: Share chats and agents with other members of your organization.
|
||||
- 🔐 Single Sign On: SSO via Google OAuth, OIDC, or SAML. Group syncing and user provisioning via SCIM.
|
||||
- 🛡️ Role Based Access Control: RBAC for sensitive resources like access to agents, actions, etc.
|
||||
- 📊 Analytics: Usage graphs broken down by teams, LLMs, or agents.
|
||||
- 🕵️ Query History: Audit usage to ensure safe adoption of AI in your organization.
|
||||
- 💻 Custom code: Run custom code to remove PII, reject sensitive queries, or run custom analysis.
|
||||
- 🎨 Whitelabeling: Customize the look and feel of Onyx with custom naming, icons, banners, and more.
|
||||
|
||||
## 📚 Licensing
|
||||
|
||||
There are two editions of Onyx:
|
||||
|
||||
- Onyx Community Edition (CE) is available freely under the MIT license.
|
||||
- Onyx Community Edition (CE) is available freely under the MIT license and covers all of the core features for Chat, RAG, Agents, and Actions.
|
||||
- Onyx Enterprise Edition (EE) includes extra features that are primarily useful for larger organizations.
|
||||
|
||||
For feature details, check out [our website](https://www.onyx.app/pricing?utm_source=onyx_repo&utm_medium=github&utm_campaign=readme).
|
||||
|
||||
|
||||
|
||||
## 👪 Community
|
||||
|
||||
Join our open source community on **[Discord](https://discord.gg/TDJ59cGV2X)**!
|
||||
|
||||
|
||||
|
||||
## 💡 Contributing
|
||||
|
||||
Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md) for more details.
|
||||
|
||||
@@ -0,0 +1,54 @@
"""csv to tabular chat file type

Revision ID: 8188861f4e92
Revises: d8cdfee5df80
Create Date: 2026-03-31 19:23:05.753184

"""

from alembic import op


# revision identifiers, used by Alembic.
revision = "8188861f4e92"
down_revision = "d8cdfee5df80"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.execute(
        """
        UPDATE chat_message
        SET files = (
            SELECT jsonb_agg(
                CASE
                    WHEN elem->>'type' = 'csv'
                    THEN jsonb_set(elem, '{type}', '"tabular"')
                    ELSE elem
                END
            )
            FROM jsonb_array_elements(files) AS elem
        )
        WHERE files::text LIKE '%"type": "csv"%'
        """
    )


def downgrade() -> None:
    op.execute(
        """
        UPDATE chat_message
        SET files = (
            SELECT jsonb_agg(
                CASE
                    WHEN elem->>'type' = 'tabular'
                    THEN jsonb_set(elem, '{type}', '"csv"')
                    ELSE elem
                END
            )
            FROM jsonb_array_elements(files) AS elem
        )
        WHERE files::text LIKE '%"type": "tabular"%'
        """
    )

@@ -0,0 +1,55 @@
"""add skipped to userfilestatus

Revision ID: d8cdfee5df80
Revises: 1d78c0ca7853
Create Date: 2026-04-01 10:47:12.593950

"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "d8cdfee5df80"
down_revision = "1d78c0ca7853"
branch_labels = None
depends_on = None


TABLE = "user_file"
COLUMN = "status"
CONSTRAINT_NAME = "ck_user_file_status"

OLD_VALUES = ("PROCESSING", "INDEXING", "COMPLETED", "FAILED", "CANCELED", "DELETING")
NEW_VALUES = (
    "PROCESSING",
    "INDEXING",
    "COMPLETED",
    "SKIPPED",
    "FAILED",
    "CANCELED",
    "DELETING",
)


def _drop_status_check_constraint() -> None:
    inspector = sa.inspect(op.get_bind())
    for constraint in inspector.get_check_constraints(TABLE):
        if COLUMN in constraint.get("sqltext", ""):
            constraint_name = constraint["name"]
            if constraint_name is not None:
                op.drop_constraint(constraint_name, TABLE, type_="check")


def upgrade() -> None:
    _drop_status_check_constraint()
    in_clause = ", ".join(f"'{v}'" for v in NEW_VALUES)
    op.create_check_constraint(CONSTRAINT_NAME, TABLE, f"{COLUMN} IN ({in_clause})")


def downgrade() -> None:
    op.execute(f"UPDATE {TABLE} SET {COLUMN} = 'COMPLETED' WHERE {COLUMN} = 'SKIPPED'")
    _drop_status_check_constraint()
    in_clause = ", ".join(f"'{v}'" for v in OLD_VALUES)
    op.create_check_constraint(CONSTRAINT_NAME, TABLE, f"{COLUMN} IN ({in_clause})")

@@ -5,6 +5,7 @@ from onyx.background.celery.apps.primary import celery_app
|
||||
celery_app.autodiscover_tasks(
|
||||
app_base.filter_task_modules(
|
||||
[
|
||||
"ee.onyx.background.celery.tasks.hooks",
|
||||
"ee.onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"ee.onyx.background.celery.tasks.external_group_syncing",
|
||||
"ee.onyx.background.celery.tasks.cloud",
|
||||
|
||||
@@ -55,6 +55,15 @@ ee_tasks_to_schedule: list[dict] = []
|
||||
|
||||
if not MULTI_TENANT:
|
||||
ee_tasks_to_schedule = [
|
||||
{
|
||||
"name": "hook-execution-log-cleanup",
|
||||
"task": OnyxCeleryTask.HOOK_EXECUTION_LOG_CLEANUP_TASK,
|
||||
"schedule": timedelta(days=1),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "autogenerate-usage-report",
|
||||
"task": OnyxCeleryTask.GENERATE_USAGE_REPORT_TASK,
|
||||
|
||||
@@ -13,6 +13,7 @@ from redis.lock import Lock as RedisLock
|
||||
from ee.onyx.server.tenants.provisioning import setup_tenant
|
||||
from ee.onyx.server.tenants.schema_management import create_schema_if_not_exists
|
||||
from ee.onyx.server.tenants.schema_management import get_current_alembic_version
|
||||
from ee.onyx.server.tenants.schema_management import run_alembic_migrations
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.configs.app_configs import TARGET_AVAILABLE_TENANTS
|
||||
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
|
||||
@@ -29,9 +30,10 @@ from shared_configs.configs import TENANT_ID_PREFIX
|
||||
# Each tenant takes ~80s (alembic migrations), so 5 tenants ≈ 7 minutes.
|
||||
_MAX_TENANTS_PER_RUN = 5
|
||||
|
||||
# Time limits sized for worst-case batch: _MAX_TENANTS_PER_RUN × ~90s + buffer.
|
||||
_TENANT_PROVISIONING_SOFT_TIME_LIMIT = 60 * 10 # 10 minutes
|
||||
_TENANT_PROVISIONING_TIME_LIMIT = 60 * 15 # 15 minutes
|
||||
# Time limits sized for worst-case: provisioning up to _MAX_TENANTS_PER_RUN new tenants
|
||||
# (~90s each) plus migrating up to TARGET_AVAILABLE_TENANTS pool tenants (~90s each).
|
||||
_TENANT_PROVISIONING_SOFT_TIME_LIMIT = 60 * 20 # 20 minutes
|
||||
_TENANT_PROVISIONING_TIME_LIMIT = 60 * 25 # 25 minutes
|
||||
|
||||
|
||||
@shared_task(
|
||||
@@ -91,8 +93,7 @@ def check_available_tenants(self: Task) -> None: # noqa: ARG001
|
||||
batch_size = min(tenants_to_provision, _MAX_TENANTS_PER_RUN)
|
||||
if batch_size < tenants_to_provision:
|
||||
task_logger.info(
|
||||
f"Capping batch to {batch_size} "
|
||||
f"(need {tenants_to_provision}, will catch up next cycle)"
|
||||
f"Capping batch to {batch_size} (need {tenants_to_provision}, will catch up next cycle)"
|
||||
)
|
||||
|
||||
provisioned = 0
|
||||
@@ -103,12 +104,14 @@ def check_available_tenants(self: Task) -> None: # noqa: ARG001
|
||||
provisioned += 1
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
f"Failed to provision tenant {i + 1}/{batch_size}, "
|
||||
"continuing with remaining tenants"
|
||||
f"Failed to provision tenant {i + 1}/{batch_size}, continuing with remaining tenants"
|
||||
)
|
||||
|
||||
task_logger.info(f"Provisioning complete: {provisioned}/{batch_size} succeeded")
|
||||
|
||||
# Migrate any pool tenants that were provisioned before a new migration was deployed
|
||||
_migrate_stale_pool_tenants()
|
||||
|
||||
except Exception:
|
||||
task_logger.exception("Error in check_available_tenants task")
|
||||
|
||||
@@ -121,6 +124,46 @@ def check_available_tenants(self: Task) -> None: # noqa: ARG001
|
||||
)
|
||||
|
||||
|
||||
def _migrate_stale_pool_tenants() -> None:
|
||||
"""
|
||||
Run alembic upgrade head on all pool tenants. Since alembic upgrade head is
|
||||
idempotent, tenants already at head are a fast no-op. This ensures pool
|
||||
tenants are always current so that signup doesn't hit schema mismatches
|
||||
(e.g. missing columns added after the tenant was pre-provisioned).
|
||||
"""
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
pool_tenants = db_session.query(AvailableTenant).all()
|
||||
tenant_ids = [t.tenant_id for t in pool_tenants]
|
||||
|
||||
if not tenant_ids:
|
||||
return
|
||||
|
||||
task_logger.info(
|
||||
f"Checking {len(tenant_ids)} pool tenant(s) for pending migrations"
|
||||
)
|
||||
|
||||
for tenant_id in tenant_ids:
|
||||
try:
|
||||
run_alembic_migrations(tenant_id)
|
||||
new_version = get_current_alembic_version(tenant_id)
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
tenant = (
|
||||
db_session.query(AvailableTenant)
|
||||
.filter_by(tenant_id=tenant_id)
|
||||
.first()
|
||||
)
|
||||
if tenant and tenant.alembic_version != new_version:
|
||||
task_logger.info(
|
||||
f"Migrated pool tenant {tenant_id}: {tenant.alembic_version} -> {new_version}"
|
||||
)
|
||||
tenant.alembic_version = new_version
|
||||
db_session.commit()
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
f"Failed to migrate pool tenant {tenant_id}, skipping"
|
||||
)
|
||||
|
||||
|
||||
def pre_provision_tenant() -> bool:
|
||||
"""
|
||||
Pre-provision a new tenant and store it in the NewAvailableTenant table.
|
||||
|
||||
@@ -69,5 +69,7 @@ EE_ONLY_PATH_PREFIXES: frozenset[str] = frozenset(
|
||||
"/admin/token-rate-limits",
|
||||
# Evals
|
||||
"/evals",
|
||||
# Hook extensions
|
||||
"/admin/hooks",
|
||||
}
|
||||
)
|
||||
|
||||
0
backend/ee/onyx/hooks/__init__.py
Normal file
385
backend/ee/onyx/hooks/executor.py
Normal file
@@ -0,0 +1,385 @@
|
||||
"""Hook executor — calls a customer's external HTTP endpoint for a given hook point.
|
||||
|
||||
Usage (Celery tasks and FastAPI handlers):
|
||||
result = execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload={"query": "...", "user_email": "...", "chat_session_id": "..."},
|
||||
response_type=QueryProcessingResponse,
|
||||
)
|
||||
|
||||
if isinstance(result, HookSkipped):
|
||||
# no active hook configured — continue with original behavior
|
||||
...
|
||||
elif isinstance(result, HookSoftFailed):
|
||||
# hook failed but fail strategy is SOFT — continue with original behavior
|
||||
...
|
||||
else:
|
||||
# result is a validated Pydantic model instance (response_type)
|
||||
...
|
||||
|
||||
is_reachable update policy
|
||||
--------------------------
|
||||
``is_reachable`` on the Hook row is updated selectively — only when the outcome
|
||||
carries meaningful signal about physical reachability:
|
||||
|
||||
NetworkError (DNS, connection refused) → False (cannot reach the server)
|
||||
HTTP 401 / 403 → False (api_key revoked or invalid)
|
||||
TimeoutException → None (server may be slow, skip write)
|
||||
Other HTTP errors (4xx / 5xx) → None (server responded, skip write)
|
||||
Unknown exception → None (no signal, skip write)
|
||||
Non-JSON / non-dict response → None (server responded, skip write)
|
||||
Success (2xx, valid dict) → True (confirmed reachable)
|
||||
|
||||
None means "leave the current value unchanged" — no DB round-trip is made.
|
||||
|
||||
DB session design
|
||||
-----------------
|
||||
The executor uses three sessions:
|
||||
|
||||
1. Caller's session (db_session) — used only for the hook lookup read. All
|
||||
needed fields are extracted from the Hook object before the HTTP call, so
|
||||
the caller's session is not held open during the external HTTP request.
|
||||
|
||||
2. Log session — a separate short-lived session opened after the HTTP call
|
||||
completes to write the HookExecutionLog row on failure. Success runs are
|
||||
not recorded. Committed independently of everything else.
|
||||
|
||||
3. Reachable session — a second short-lived session to update is_reachable on
|
||||
the Hook. Kept separate from the log session so a concurrent hook deletion
|
||||
(which causes update_hook__no_commit to raise OnyxError(NOT_FOUND)) cannot
|
||||
prevent the execution log from being written. This update is best-effort.
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from typing import Any
|
||||
from typing import TypeVar
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import HookFailStrategy
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.db.hook import create_hook_execution_log__no_commit
|
||||
from onyx.db.hook import get_non_deleted_hook_by_hook_point
|
||||
from onyx.db.hook import update_hook__no_commit
|
||||
from onyx.db.models import Hook
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.hooks.executor import HookSkipped
|
||||
from onyx.hooks.executor import HookSoftFailed
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Private helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _HttpOutcome(BaseModel):
|
||||
"""Structured result of an HTTP hook call, returned by _process_response."""
|
||||
|
||||
is_success: bool
|
||||
updated_is_reachable: (
|
||||
bool | None
|
||||
) # True/False = write to DB, None = unchanged (skip write)
|
||||
status_code: int | None
|
||||
error_message: str | None
|
||||
response_payload: dict[str, Any] | None
|
||||
|
||||
|
||||
def _lookup_hook(
|
||||
db_session: Session,
|
||||
hook_point: HookPoint,
|
||||
) -> Hook | HookSkipped:
|
||||
"""Return the active Hook or HookSkipped if hooks are unavailable/unconfigured.
|
||||
|
||||
No HTTP call is made and no DB writes are performed for any HookSkipped path.
|
||||
There is nothing to log and no reachability information to update.
|
||||
"""
|
||||
if MULTI_TENANT:
|
||||
return HookSkipped()
|
||||
hook = get_non_deleted_hook_by_hook_point(
|
||||
db_session=db_session, hook_point=hook_point
|
||||
)
|
||||
if hook is None or not hook.is_active:
|
||||
return HookSkipped()
|
||||
if not hook.endpoint_url:
|
||||
return HookSkipped()
|
||||
return hook
|
||||
|
||||
|
||||
def _process_response(
|
||||
*,
|
||||
response: httpx.Response | None,
|
||||
exc: Exception | None,
|
||||
timeout: float,
|
||||
) -> _HttpOutcome:
|
||||
"""Process the result of an HTTP call and return a structured outcome.
|
||||
|
||||
Called after the client.post() try/except. If post() raised, exc is set and
|
||||
response is None. Otherwise response is set and exc is None. Handles
|
||||
raise_for_status(), JSON decoding, and the dict shape check.
|
||||
"""
|
||||
if exc is not None:
|
||||
if isinstance(exc, httpx.NetworkError):
|
||||
msg = f"Hook network error (endpoint unreachable): {exc}"
|
||||
logger.warning(msg, exc_info=exc)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=False,
|
||||
status_code=None,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
if isinstance(exc, httpx.TimeoutException):
|
||||
msg = f"Hook timed out after {timeout}s: {exc}"
|
||||
logger.warning(msg, exc_info=exc)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=None, # timeout doesn't indicate unreachability
|
||||
status_code=None,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
msg = f"Hook call failed: {exc}"
|
||||
logger.exception(msg, exc_info=exc)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=None, # unknown error — don't make assumptions
|
||||
status_code=None,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
|
||||
if response is None:
|
||||
raise ValueError(
|
||||
"exactly one of response or exc must be non-None; both are None"
|
||||
)
|
||||
status_code = response.status_code
|
||||
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as e:
|
||||
msg = f"Hook returned HTTP {e.response.status_code}: {e.response.text}"
|
||||
logger.warning(msg, exc_info=e)
|
||||
# 401/403 means the api_key has been revoked or is invalid — mark unreachable
|
||||
# so the operator knows to update it. All other HTTP errors keep is_reachable
|
||||
# as-is (server is up, the request just failed for application reasons).
|
||||
auth_failed = e.response.status_code in (401, 403)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=False if auth_failed else None,
|
||||
status_code=status_code,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
|
||||
try:
|
||||
response_payload = response.json()
|
||||
except (json.JSONDecodeError, httpx.DecodingError) as e:
|
||||
msg = f"Hook returned non-JSON response: {e}"
|
||||
logger.warning(msg, exc_info=e)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=None, # server responded — reachability unchanged
|
||||
status_code=status_code,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
|
||||
if not isinstance(response_payload, dict):
|
||||
msg = f"Hook returned non-dict JSON (got {type(response_payload).__name__})"
|
||||
logger.warning(msg)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=None, # server responded — reachability unchanged
|
||||
status_code=status_code,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
|
||||
return _HttpOutcome(
|
||||
is_success=True,
|
||||
updated_is_reachable=True,
|
||||
status_code=status_code,
|
||||
error_message=None,
|
||||
response_payload=response_payload,
|
||||
)
|
||||
|
||||
|
||||
def _persist_result(
|
||||
*,
|
||||
hook_id: int,
|
||||
outcome: _HttpOutcome,
|
||||
duration_ms: int,
|
||||
) -> None:
|
||||
"""Write the execution log on failure and optionally update is_reachable, each
|
||||
in its own session so a failure in one does not affect the other."""
|
||||
# Only write the execution log on failure — success runs are not recorded.
|
||||
# Must not be skipped if the is_reachable update fails (e.g. hook concurrently
|
||||
# deleted between the initial lookup and here).
|
||||
if not outcome.is_success:
|
||||
try:
|
||||
with get_session_with_current_tenant() as log_session:
|
||||
create_hook_execution_log__no_commit(
|
||||
db_session=log_session,
|
||||
hook_id=hook_id,
|
||||
is_success=False,
|
||||
error_message=outcome.error_message,
|
||||
status_code=outcome.status_code,
|
||||
duration_ms=duration_ms,
|
||||
)
|
||||
log_session.commit()
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Failed to persist hook execution log for hook_id={hook_id}"
|
||||
)
|
||||
|
||||
# Update is_reachable separately — best-effort, non-critical.
|
||||
# None means the value is unchanged (set by the caller to skip the no-op write).
|
||||
# update_hook__no_commit can raise OnyxError(NOT_FOUND) if the hook was
|
||||
# concurrently deleted, so keep this isolated from the log write above.
|
||||
if outcome.updated_is_reachable is not None:
|
||||
try:
|
||||
with get_session_with_current_tenant() as reachable_session:
|
||||
update_hook__no_commit(
|
||||
db_session=reachable_session,
|
||||
hook_id=hook_id,
|
||||
is_reachable=outcome.updated_is_reachable,
|
||||
)
|
||||
reachable_session.commit()
|
||||
except Exception:
|
||||
logger.warning(f"Failed to update is_reachable for hook_id={hook_id}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _execute_hook_inner(
|
||||
hook: Hook,
|
||||
payload: dict[str, Any],
|
||||
response_type: type[T],
|
||||
) -> T | HookSoftFailed:
|
||||
"""Make the HTTP call, validate the response, and return a typed model.
|
||||
|
||||
Raises OnyxError on HARD failure. Returns HookSoftFailed on SOFT failure.
|
||||
"""
|
||||
timeout = hook.timeout_seconds
|
||||
hook_id = hook.id
|
||||
fail_strategy = hook.fail_strategy
|
||||
endpoint_url = hook.endpoint_url
|
||||
current_is_reachable: bool | None = hook.is_reachable
|
||||
|
||||
if not endpoint_url:
|
||||
raise ValueError(
|
||||
f"hook_id={hook_id} is active but has no endpoint_url — "
|
||||
"active hooks without an endpoint_url must be rejected by _lookup_hook"
|
||||
)
|
||||
|
||||
start = time.monotonic()
|
||||
response: httpx.Response | None = None
|
||||
exc: Exception | None = None
|
||||
try:
|
||||
api_key: str | None = (
|
||||
hook.api_key.get_value(apply_mask=False) if hook.api_key else None
|
||||
)
|
||||
headers: dict[str, str] = {"Content-Type": "application/json"}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
with httpx.Client(
|
||||
timeout=timeout, follow_redirects=False
|
||||
) as client: # SSRF guard: never follow redirects
|
||||
response = client.post(endpoint_url, json=payload, headers=headers)
|
||||
except Exception as e:
|
||||
exc = e
|
||||
duration_ms = int((time.monotonic() - start) * 1000)
|
||||
|
||||
outcome = _process_response(response=response, exc=exc, timeout=timeout)
|
||||
|
||||
# Validate the response payload against response_type.
|
||||
# A validation failure downgrades the outcome to a failure so it is logged,
|
||||
# is_reachable is left unchanged (server responded — just a bad payload),
|
||||
# and fail_strategy is respected below.
|
||||
validated_model: T | None = None
|
||||
if outcome.is_success and outcome.response_payload is not None:
|
||||
try:
|
||||
validated_model = response_type.model_validate(outcome.response_payload)
|
||||
except ValidationError as e:
|
||||
msg = (
|
||||
f"Hook response failed validation against {response_type.__name__}: {e}"
|
||||
)
|
||||
outcome = _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=None, # server responded — reachability unchanged
|
||||
status_code=outcome.status_code,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
|
||||
# Skip the is_reachable write when the value would not change — avoids a
|
||||
# no-op DB round-trip on every call when the hook is already in the expected state.
|
||||
if outcome.updated_is_reachable == current_is_reachable:
|
||||
outcome = outcome.model_copy(update={"updated_is_reachable": None})
|
||||
_persist_result(hook_id=hook_id, outcome=outcome, duration_ms=duration_ms)
|
||||
|
||||
if not outcome.is_success:
|
||||
if fail_strategy == HookFailStrategy.HARD:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.HOOK_EXECUTION_FAILED,
|
||||
outcome.error_message or "Hook execution failed.",
|
||||
)
|
||||
logger.warning(
|
||||
f"Hook execution failed (soft fail) for hook_id={hook_id}: {outcome.error_message}"
|
||||
)
|
||||
return HookSoftFailed()
|
||||
|
||||
if validated_model is None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
f"validated_model is None for successful hook call (hook_id={hook_id})",
|
||||
)
|
||||
return validated_model
|
||||
|
||||
|
||||
def _execute_hook_impl(
|
||||
*,
|
||||
db_session: Session,
|
||||
hook_point: HookPoint,
|
||||
payload: dict[str, Any],
|
||||
response_type: type[T],
|
||||
) -> T | HookSkipped | HookSoftFailed:
|
||||
"""EE implementation — loaded by CE's execute_hook via fetch_versioned_implementation.
|
||||
|
||||
Returns HookSkipped if no active hook is configured, HookSoftFailed if the
|
||||
hook failed with SOFT fail strategy, or a validated response model on success.
|
||||
Raises OnyxError on HARD failure or if the hook is misconfigured.
|
||||
"""
|
||||
hook = _lookup_hook(db_session, hook_point)
|
||||
if isinstance(hook, HookSkipped):
|
||||
return hook
|
||||
|
||||
fail_strategy = hook.fail_strategy
|
||||
hook_id = hook.id
|
||||
|
||||
try:
|
||||
return _execute_hook_inner(hook, payload, response_type)
|
||||
except Exception:
|
||||
if fail_strategy == HookFailStrategy.SOFT:
|
||||
logger.exception(
|
||||
f"Unexpected error in hook execution (soft fail) for hook_id={hook_id}"
|
||||
)
|
||||
return HookSoftFailed()
|
||||
raise
|
||||
@@ -15,6 +15,7 @@ from ee.onyx.server.enterprise_settings.api import (
|
||||
basic_router as enterprise_settings_router,
|
||||
)
|
||||
from ee.onyx.server.evals.api import router as evals_router
|
||||
from ee.onyx.server.features.hooks.api import router as hook_router
|
||||
from ee.onyx.server.license.api import router as license_router
|
||||
from ee.onyx.server.manage.standard_answer import router as standard_answer_router
|
||||
from ee.onyx.server.middleware.license_enforcement import (
|
||||
@@ -138,6 +139,7 @@ def get_application() -> FastAPI:
|
||||
include_router_with_global_prefix_prepended(application, ee_oauth_router)
|
||||
include_router_with_global_prefix_prepended(application, ee_document_cc_pair_router)
|
||||
include_router_with_global_prefix_prepended(application, evals_router)
|
||||
include_router_with_global_prefix_prepended(application, hook_router)
|
||||
|
||||
# Enterprise-only global settings
|
||||
include_router_with_global_prefix_prepended(
|
||||
|
||||
0
backend/ee/onyx/server/features/__init__.py
Normal file
0
backend/ee/onyx/server/features/hooks/__init__.py
Normal file
@@ -99,6 +99,26 @@ async def get_or_provision_tenant(
|
||||
tenant_id = await get_available_tenant()
|
||||
|
||||
if tenant_id:
|
||||
# Run migrations to ensure the pre-provisioned tenant schema is current.
|
||||
# Pool tenants may have been created before a new migration was deployed.
|
||||
# Capture as a non-optional local so mypy can type the lambda correctly.
|
||||
_tenant_id: str = tenant_id
|
||||
loop = asyncio.get_running_loop()
|
||||
try:
|
||||
await loop.run_in_executor(
|
||||
None, lambda: run_alembic_migrations(_tenant_id)
|
||||
)
|
||||
except Exception:
|
||||
# The tenant was already dequeued from the pool — roll it back so
|
||||
# it doesn't end up orphaned (schema exists, but not assigned to anyone).
|
||||
logger.exception(
|
||||
f"Migration failed for pre-provisioned tenant {_tenant_id}; rolling back"
|
||||
)
|
||||
try:
|
||||
await rollback_tenant_provisioning(_tenant_id)
|
||||
except Exception:
|
||||
logger.exception(f"Failed to rollback orphaned tenant {_tenant_id}")
|
||||
raise
|
||||
# If we have a pre-provisioned tenant, assign it to the user
|
||||
await assign_tenant_to_user(tenant_id, email, referral_source)
|
||||
logger.info(f"Assigned pre-provisioned tenant {tenant_id} to user {email}")
|
||||
|
||||
@@ -100,6 +100,7 @@ def get_model_app() -> FastAPI:
|
||||
dsn=SENTRY_DSN,
|
||||
integrations=[StarletteIntegration(), FastApiIntegration()],
|
||||
traces_sample_rate=0.1,
|
||||
release=__version__,
|
||||
)
|
||||
logger.info("Sentry initialized")
|
||||
else:
|
||||
|
||||
@@ -20,6 +20,7 @@ from sentry_sdk.integrations.celery import CeleryIntegration
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx import __version__
|
||||
from onyx.background.celery.apps.task_formatters import CeleryTaskColoredFormatter
|
||||
from onyx.background.celery.apps.task_formatters import CeleryTaskPlainFormatter
|
||||
from onyx.background.celery.celery_utils import celery_is_worker_primary
|
||||
@@ -65,6 +66,7 @@ if SENTRY_DSN:
|
||||
dsn=SENTRY_DSN,
|
||||
integrations=[CeleryIntegration()],
|
||||
traces_sample_rate=0.1,
|
||||
release=__version__,
|
||||
)
|
||||
logger.info("Sentry initialized")
|
||||
else:
|
||||
@@ -515,7 +517,8 @@ def reset_tenant_id(
|
||||
|
||||
|
||||
def wait_for_vespa_or_shutdown(
|
||||
sender: Any, **kwargs: Any # noqa: ARG001
|
||||
sender: Any, # noqa: ARG001
|
||||
**kwargs: Any, # noqa: ARG001
|
||||
) -> None:
|
||||
"""Waits for Vespa to become ready subject to a timeout.
|
||||
Raises WorkerShutdown if the timeout is reached."""
|
||||
|
||||
@@ -317,7 +317,6 @@ celery_app.autodiscover_tasks(
|
||||
"onyx.background.celery.tasks.docprocessing",
|
||||
"onyx.background.celery.tasks.evals",
|
||||
"onyx.background.celery.tasks.hierarchyfetching",
|
||||
"onyx.background.celery.tasks.hooks",
|
||||
"onyx.background.celery.tasks.periodic",
|
||||
"onyx.background.celery.tasks.pruning",
|
||||
"onyx.background.celery.tasks.shared",
|
||||
|
||||
@@ -14,7 +14,6 @@ from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.hooks.utils import HOOKS_AVAILABLE
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
# choosing 15 minutes because it roughly gives us enough time to process many tasks
|
||||
@@ -362,19 +361,6 @@ if not MULTI_TENANT:
|
||||
|
||||
tasks_to_schedule.extend(beat_task_templates)
|
||||
|
||||
if HOOKS_AVAILABLE:
|
||||
tasks_to_schedule.append(
|
||||
{
|
||||
"name": "hook-execution-log-cleanup",
|
||||
"task": OnyxCeleryTask.HOOK_EXECUTION_LOG_CLEANUP_TASK,
|
||||
"schedule": timedelta(days=1),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def generate_cloud_tasks(
|
||||
beat_tasks: list[dict], beat_templates: list[dict], beat_multiplier: float
|
||||
|
||||
@@ -9,6 +9,7 @@ from celery import Celery
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
|
||||
from onyx import __version__
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.memory_monitoring import emit_process_memory
|
||||
from onyx.background.celery.tasks.docprocessing.heartbeat import start_heartbeat
|
||||
@@ -137,6 +138,7 @@ def _docfetching_task(
|
||||
sentry_sdk.init(
|
||||
dsn=SENTRY_DSN,
|
||||
traces_sample_rate=0.1,
|
||||
release=__version__,
|
||||
)
|
||||
logger.info("Sentry initialized")
|
||||
else:
|
||||
|
||||
@@ -319,6 +319,11 @@ def monitor_indexing_attempt_progress(
|
||||
)
|
||||
|
||||
current_db_time = get_db_current_time(db_session)
|
||||
total_batches: int | str = (
|
||||
coordination_status.total_batches
|
||||
if coordination_status.total_batches is not None
|
||||
else "?"
|
||||
)
|
||||
if coordination_status.found:
|
||||
task_logger.info(
|
||||
f"Indexing attempt progress: "
|
||||
@@ -326,7 +331,7 @@ def monitor_indexing_attempt_progress(
|
||||
f"cc_pair={attempt.connector_credential_pair_id} "
|
||||
f"search_settings={attempt.search_settings_id} "
|
||||
f"completed_batches={coordination_status.completed_batches} "
|
||||
f"total_batches={coordination_status.total_batches or '?'} "
|
||||
f"total_batches={total_batches} "
|
||||
f"total_docs={coordination_status.total_docs} "
|
||||
f"total_failures={coordination_status.total_failures}"
|
||||
f"elapsed={(current_db_time - attempt.time_created).seconds}"
|
||||
@@ -410,7 +415,7 @@ def check_indexing_completion(
|
||||
logger.info(
|
||||
f"Indexing status: "
|
||||
f"indexing_completed={indexing_completed} "
|
||||
f"batches_processed={batches_processed}/{batches_total or '?'} "
|
||||
f"batches_processed={batches_processed}/{batches_total if batches_total is not None else '?'} "
|
||||
f"total_docs={coordination_status.total_docs} "
|
||||
f"total_chunks={coordination_status.total_chunks} "
|
||||
f"total_failures={coordination_status.total_failures}"
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import cast
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi.datastructures import Headers
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.models import ChatHistoryResult
|
||||
@@ -51,6 +52,60 @@ logger = setup_logger()
|
||||
IMAGE_GENERATION_TOOL_NAME = "generate_image"
|
||||
|
||||
|
||||
class FileContextResult(BaseModel):
|
||||
"""Result of building a file's LLM context representation."""
|
||||
|
||||
message: ChatMessageSimple
|
||||
tool_metadata: FileToolMetadata
|
||||
|
||||
|
||||
def build_file_context(
|
||||
tool_file_id: str,
|
||||
filename: str,
|
||||
file_type: ChatFileType,
|
||||
content_text: str | None = None,
|
||||
token_count: int = 0,
|
||||
approx_char_count: int | None = None,
|
||||
) -> FileContextResult:
|
||||
"""Build the LLM context representation for a single file.
|
||||
|
||||
Centralises how files should appear in the LLM prompt, and which file ID
is exposed — the ID that FileReaderTool accepts (``UserFile.id`` for user files).
|
||||
"""
|
||||
if file_type.use_metadata_only():
|
||||
message_text = (
|
||||
f"File: {filename} (id={tool_file_id})\n"
|
||||
"Use the file_reader or python tools to access "
|
||||
"this file's contents."
|
||||
)
|
||||
message = ChatMessageSimple(
|
||||
message=message_text,
|
||||
token_count=max(1, len(message_text) // 4),
|
||||
message_type=MessageType.USER,
|
||||
file_id=tool_file_id,
|
||||
)
|
||||
else:
|
||||
message_text = f"File: {filename}\n{content_text or ''}\nEnd of File"
|
||||
message = ChatMessageSimple(
|
||||
message=message_text,
|
||||
token_count=token_count,
|
||||
message_type=MessageType.USER,
|
||||
file_id=tool_file_id,
|
||||
)
|
||||
|
||||
metadata = FileToolMetadata(
|
||||
file_id=tool_file_id,
|
||||
filename=filename,
|
||||
approx_char_count=(
|
||||
approx_char_count
|
||||
if approx_char_count is not None
|
||||
else len(content_text or "")
|
||||
),
|
||||
)
|
||||
|
||||
return FileContextResult(message=message, tool_metadata=metadata)
|
||||
|
||||
|
||||
def create_chat_session_from_request(
|
||||
chat_session_request: ChatSessionCreationRequest,
|
||||
user_id: UUID | None,
|
||||
@@ -538,7 +593,7 @@ def convert_chat_history(
|
||||
for idx, chat_message in enumerate(chat_history):
|
||||
if chat_message.message_type == MessageType.USER:
|
||||
# Process files attached to this message
|
||||
text_files: list[ChatLoadedFile] = []
|
||||
text_files: list[tuple[ChatLoadedFile, FileDescriptor]] = []
|
||||
image_files: list[ChatLoadedFile] = []
|
||||
|
||||
if chat_message.files:
|
||||
@@ -549,34 +604,26 @@ def convert_chat_history(
|
||||
if loaded_file.file_type == ChatFileType.IMAGE:
|
||||
image_files.append(loaded_file)
|
||||
else:
|
||||
# Text files (DOC, PLAIN_TEXT, CSV) are added as separate messages
|
||||
text_files.append(loaded_file)
|
||||
# Text files (DOC, PLAIN_TEXT, TABULAR) are added as separate messages
|
||||
text_files.append((loaded_file, file_descriptor))
|
||||
|
||||
# Add text files as separate messages before the user message.
|
||||
# Each message is tagged with ``file_id`` so that forgotten files
|
||||
# can be detected after context-window truncation.
|
||||
for text_file in text_files:
|
||||
file_text = text_file.content_text or ""
|
||||
filename = text_file.filename
|
||||
message = (
|
||||
f"File: {filename}\n{file_text}\nEnd of File"
|
||||
if filename
|
||||
else file_text
|
||||
)
|
||||
simple_messages.append(
|
||||
ChatMessageSimple(
|
||||
message=message,
|
||||
token_count=text_file.token_count,
|
||||
message_type=MessageType.USER,
|
||||
image_files=None,
|
||||
file_id=text_file.file_id,
|
||||
)
|
||||
)
|
||||
all_injected_file_metadata[text_file.file_id] = FileToolMetadata(
|
||||
file_id=text_file.file_id,
|
||||
filename=filename or "unknown",
|
||||
approx_char_count=len(file_text),
|
||||
for text_file, fd in text_files:
|
||||
# Use user_file_id, since that is what FileReaderTool accepts.
|
||||
# Fall back to the file-store path id.
|
||||
tool_id = fd.get("user_file_id") or text_file.file_id
|
||||
filename = text_file.filename or "unknown"
|
||||
ctx = build_file_context(
|
||||
tool_file_id=tool_id,
|
||||
filename=filename,
|
||||
file_type=text_file.file_type,
|
||||
content_text=text_file.content_text,
|
||||
token_count=text_file.token_count,
|
||||
)
|
||||
simple_messages.append(ctx.message)
|
||||
all_injected_file_metadata[tool_id] = ctx.tool_metadata
|
||||
|
||||
# Sum token counts from image files (excluding project image files)
|
||||
image_token_count = (
|
||||
|
||||
@@ -18,6 +18,7 @@ from onyx.cache.interface import CacheBackend
|
||||
from onyx.chat.chat_processing_checker import set_processing_status
|
||||
from onyx.chat.chat_state import ChatStateContainer
|
||||
from onyx.chat.chat_state import run_chat_loop_with_state_containers
|
||||
from onyx.chat.chat_utils import build_file_context
|
||||
from onyx.chat.chat_utils import convert_chat_history
|
||||
from onyx.chat.chat_utils import create_chat_history_chain
|
||||
from onyx.chat.chat_utils import create_chat_session_from_request
|
||||
@@ -90,6 +91,7 @@ from onyx.llm.request_context import reset_llm_mock_response
|
||||
from onyx.llm.request_context import set_llm_mock_response
|
||||
from onyx.llm.utils import litellm_exception_to_error_msg
|
||||
from onyx.onyxbot.slack.models import SlackContext
|
||||
from onyx.server.query_and_chat.chat_utils import mime_type_to_chat_file_type
|
||||
from onyx.server.query_and_chat.models import AUTO_PLACE_AFTER_LATEST_MESSAGE
|
||||
from onyx.server.query_and_chat.models import MessageResponseIDInfo
|
||||
from onyx.server.query_and_chat.models import SendMessageRequest
|
||||
@@ -117,6 +119,8 @@ from shared_configs.contextvars import get_current_tenant_id
|
||||
logger = setup_logger()
|
||||
ERROR_TYPE_CANCELLED = "cancelled"
|
||||
|
||||
APPROX_CHARS_PER_TOKEN = 4
|
||||
|
||||
|
||||
class _AvailableFiles(BaseModel):
|
||||
"""Separated file IDs for the FileReaderTool so it knows which loader to use."""
|
||||
@@ -301,16 +305,27 @@ def extract_context_files(
|
||||
if not user_files:
|
||||
return _empty_extracted_context_files()
|
||||
|
||||
aggregate_tokens = sum(uf.token_count or 0 for uf in user_files)
|
||||
# Aggregate tokens for the file content that will be added
|
||||
# Skip tokens for those with metadata only
|
||||
aggregate_tokens = sum(
|
||||
uf.token_count or 0
|
||||
for uf in user_files
|
||||
if not mime_type_to_chat_file_type(uf.file_type).use_metadata_only()
|
||||
)
|
||||
max_actual_tokens = (
|
||||
llm_max_context_window - reserved_token_count
|
||||
) * max_llm_context_percentage
|
||||
|
||||
if aggregate_tokens >= max_actual_tokens:
|
||||
tool_metadata = []
|
||||
use_as_search_filter = not DISABLE_VECTOR_DB
|
||||
if DISABLE_VECTOR_DB:
|
||||
tool_metadata = _build_file_tool_metadata_for_user_files(user_files)
|
||||
overflow_tool_metadata = [_build_tool_metadata(uf) for uf in user_files]
|
||||
else:
|
||||
overflow_tool_metadata = [
|
||||
_build_tool_metadata(uf)
|
||||
for uf in user_files
|
||||
if mime_type_to_chat_file_type(uf.file_type).use_metadata_only()
|
||||
]
|
||||
return ExtractedContextFiles(
|
||||
file_texts=[],
|
||||
image_files=[],
|
||||
@@ -318,11 +333,11 @@ def extract_context_files(
|
||||
total_token_count=0,
|
||||
file_metadata=[],
|
||||
uncapped_token_count=aggregate_tokens,
|
||||
file_metadata_for_tool=tool_metadata,
|
||||
file_metadata_for_tool=overflow_tool_metadata,
|
||||
)
|
||||
|
||||
# Files fit — load them into context
|
||||
user_file_map = {str(uf.id): uf for uf in user_files}
|
||||
user_file_map = {uf.file_id: uf for uf in user_files}
|
||||
in_memory_files = load_in_memory_chat_files(
|
||||
user_file_ids=[uf.id for uf in user_files],
|
||||
db_session=db_session,
|
||||
@@ -331,23 +346,38 @@ def extract_context_files(
|
||||
file_texts: list[str] = []
|
||||
image_files: list[ChatLoadedFile] = []
|
||||
file_metadata: list[ContextFileMetadata] = []
|
||||
tool_metadata: list[FileToolMetadata] = []
|
||||
total_token_count = 0
|
||||
|
||||
for f in in_memory_files:
|
||||
uf = user_file_map.get(str(f.file_id))
|
||||
if f.file_type.is_text_file():
|
||||
filename = f.filename or f"file_{f.file_id}"
|
||||
|
||||
if f.file_type.use_metadata_only():
|
||||
# Metadata-only files are not injected as full text.
|
||||
# Only the metadata is provided; the LLM reads contents via tools
|
||||
if not uf:
|
||||
logger.error(
|
||||
f"File with id={f.file_id} in metadata-only path with no associated user file"
|
||||
)
|
||||
continue
|
||||
tool_metadata.append(_build_tool_metadata(uf))
|
||||
elif f.file_type.is_text_file():
|
||||
text_content = _extract_text_from_in_memory_file(f)
|
||||
if not text_content:
|
||||
continue
|
||||
if not uf:
|
||||
logger.warning(f"No user file for file_id={f.file_id}")
|
||||
continue
|
||||
file_texts.append(text_content)
|
||||
file_metadata.append(
|
||||
ContextFileMetadata(
|
||||
file_id=str(f.file_id),
|
||||
filename=f.filename or f"file_{f.file_id}",
|
||||
file_id=str(uf.id),
|
||||
filename=filename,
|
||||
file_content=text_content,
|
||||
)
|
||||
)
|
||||
if uf and uf.token_count:
|
||||
if uf.token_count:
|
||||
total_token_count += uf.token_count
|
||||
elif f.file_type == ChatFileType.IMAGE:
|
||||
token_count = uf.token_count if uf and uf.token_count else 0
|
||||
@@ -370,24 +400,22 @@ def extract_context_files(
|
||||
total_token_count=total_token_count,
|
||||
file_metadata=file_metadata,
|
||||
uncapped_token_count=aggregate_tokens,
|
||||
file_metadata_for_tool=tool_metadata,
|
||||
)
|
||||
|
||||
|
||||
APPROX_CHARS_PER_TOKEN = 4
|
||||
def _build_tool_metadata(user_file: UserFile) -> FileToolMetadata:
|
||||
"""Build lightweight FileToolMetadata from a UserFile record.
|
||||
|
||||
|
||||
def _build_file_tool_metadata_for_user_files(
|
||||
user_files: list[UserFile],
|
||||
) -> list[FileToolMetadata]:
|
||||
"""Build lightweight FileToolMetadata from a list of UserFile records."""
|
||||
return [
|
||||
FileToolMetadata(
|
||||
file_id=str(uf.id),
|
||||
filename=uf.name,
|
||||
approx_char_count=(uf.token_count or 0) * APPROX_CHARS_PER_TOKEN,
|
||||
)
|
||||
for uf in user_files
|
||||
]
|
||||
Delegates to ``build_file_context`` so that the file ID exposed to the
|
||||
LLM is always consistent with what FileReaderTool expects.
|
||||
"""
|
||||
return build_file_context(
|
||||
tool_file_id=str(user_file.id),
|
||||
filename=user_file.name,
|
||||
file_type=mime_type_to_chat_file_type(user_file.file_type),
|
||||
approx_char_count=(user_file.token_count or 0) * APPROX_CHARS_PER_TOKEN,
|
||||
).tool_metadata
|
||||
|
||||
|
||||
def determine_search_params(
|
||||
|
||||
@@ -805,6 +805,10 @@ MINI_CHUNK_SIZE = 150
|
||||
# This is the number of regular chunks per large chunk
|
||||
LARGE_CHUNK_RATIO = 4
|
||||
|
||||
# The maximum number of chunks that can be held for 1 document processing batch
|
||||
# The purpose of this is to set an upper bound on memory usage
|
||||
MAX_CHUNKS_PER_DOC_BATCH = int(os.environ.get("MAX_CHUNKS_PER_DOC_BATCH") or 1000)
|
||||
|
||||
# Include the document level metadata in each chunk. If the metadata is too long, then it is thrown out
|
||||
# We don't want the metadata to overwhelm the actual contents of the chunk
|
||||
SKIP_METADATA_IN_CHUNK = os.environ.get("SKIP_METADATA_IN_CHUNK", "").lower() == "true"
|
||||
@@ -1075,7 +1079,6 @@ POD_NAMESPACE = os.environ.get("POD_NAMESPACE")
|
||||
|
||||
DEV_MODE = os.environ.get("DEV_MODE", "").lower() == "true"
|
||||
|
||||
HOOK_ENABLED = os.environ.get("HOOK_ENABLED", "").lower() == "true"
|
||||
|
||||
INTEGRATION_TESTS_MODE = os.environ.get("INTEGRATION_TESTS_MODE", "").lower() == "true"
|
||||
|
||||
|
||||
@@ -212,6 +212,7 @@ class DocumentSource(str, Enum):
|
||||
PRODUCTBOARD = "productboard"
|
||||
FILE = "file"
|
||||
CODA = "coda"
|
||||
CANVAS = "canvas"
|
||||
NOTION = "notion"
|
||||
ZULIP = "zulip"
|
||||
LINEAR = "linear"
|
||||
@@ -672,6 +673,7 @@ DocumentSourceDescription: dict[DocumentSource, str] = {
|
||||
DocumentSource.SLAB: "slab data",
|
||||
DocumentSource.PRODUCTBOARD: "productboard data (boards, etc.)",
|
||||
DocumentSource.FILE: "files",
|
||||
DocumentSource.CANVAS: "canvas lms - courses, pages, assignments, and announcements",
|
||||
DocumentSource.CODA: "coda - team workspace with docs, tables, and pages",
|
||||
DocumentSource.NOTION: "notion data - a workspace that combines note-taking, \
|
||||
project management, and collaboration tools into a single, customizable platform",
|
||||
|
||||
32
backend/onyx/connectors/canvas/access.py
Normal file
@@ -0,0 +1,32 @@
|
||||
"""
|
||||
Permissioning / AccessControl logic for Canvas courses.
|
||||
|
||||
CE stub — returns None (no permissions). The EE implementation is loaded
|
||||
at runtime via ``fetch_versioned_implementation``.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import cast
|
||||
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.connectors.canvas.client import CanvasApiClient
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
|
||||
|
||||
def get_course_permissions(
|
||||
canvas_client: CanvasApiClient,
|
||||
course_id: int,
|
||||
) -> ExternalAccess | None:
|
||||
if not global_version.is_ee_version():
|
||||
return None
|
||||
|
||||
ee_get_course_permissions = cast(
|
||||
Callable[[CanvasApiClient, int], ExternalAccess | None],
|
||||
fetch_versioned_implementation(
|
||||
"onyx.external_permissions.canvas.access",
|
||||
"get_course_permissions",
|
||||
),
|
||||
)
|
||||
|
||||
return ee_get_course_permissions(canvas_client, course_id)
|
||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from collections.abc import Iterator
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
@@ -190,3 +191,22 @@ class CanvasApiClient:
|
||||
if clean_endpoint:
|
||||
final_url += "/" + clean_endpoint
|
||||
return final_url
|
||||
|
||||
def paginate(
|
||||
self,
|
||||
endpoint: str,
|
||||
params: dict[str, Any] | None = None,
|
||||
) -> Iterator[list[Any]]:
|
||||
"""Yield each page of results, following Link-header pagination.
|
||||
|
||||
Makes the first request with endpoint + params, then follows
|
||||
next_url from Link headers for subsequent pages.
|
||||
"""
|
||||
response, next_url = self.get(endpoint, params=params)
|
||||
while True:
|
||||
if not response:
|
||||
break
|
||||
yield response
|
||||
if not next_url:
|
||||
break
|
||||
response, next_url = self.get(full_url=next_url)
|
||||
|
||||
@@ -1,17 +1,82 @@
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast
from typing import Literal
from typing import NoReturn
from typing import TypeAlias

from pydantic import BaseModel
from retry import retry
from typing_extensions import override

from onyx.access.models import ExternalAccess
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.canvas.access import get_course_permissions
from onyx.connectors.canvas.client import CanvasApiClient
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.interfaces import CheckpointedConnectorWithPermSync
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.error_handling.exceptions import OnyxError
from onyx.file_processing.html_utils import parse_html_page_basic
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger

logger = setup_logger()


def _handle_canvas_api_error(e: OnyxError) -> NoReturn:
    """Map Canvas API errors to connector framework exceptions."""
    if e.status_code == 401:
        raise CredentialExpiredError(
            "Canvas API token is invalid or expired (HTTP 401)."
        )
    elif e.status_code == 403:
        raise InsufficientPermissionsError(
            "Canvas API token does not have sufficient permissions (HTTP 403)."
        )
    elif e.status_code == 429:
        raise ConnectorValidationError(
            "Canvas rate-limit exceeded (HTTP 429). Please try again later."
        )
    elif e.status_code >= 500:
        raise UnexpectedValidationError(
            f"Unexpected Canvas HTTP error (status={e.status_code}): {e}"
        )
    else:
        raise ConnectorValidationError(
            f"Canvas API error (status={e.status_code}): {e}"
        )


class CanvasCourse(BaseModel):
    id: int
    name: str
    course_code: str
    created_at: str
    workflow_state: str
    name: str | None = None
    course_code: str | None = None
    created_at: str | None = None
    workflow_state: str | None = None

    @classmethod
    def from_api(cls, payload: dict[str, Any]) -> "CanvasCourse":
        return cls(
            id=payload["id"],
            name=payload.get("name"),
            course_code=payload.get("course_code"),
            created_at=payload.get("created_at"),
            workflow_state=payload.get("workflow_state"),
        )


class CanvasPage(BaseModel):

@@ -19,10 +84,22 @@ class CanvasPage(BaseModel):
    url: str
    title: str
    body: str | None = None
    created_at: str
    updated_at: str
    created_at: str | None = None
    updated_at: str | None = None
    course_id: int

    @classmethod
    def from_api(cls, payload: dict[str, Any], course_id: int) -> "CanvasPage":
        return cls(
            page_id=payload["page_id"],
            url=payload["url"],
            title=payload["title"],
            body=payload.get("body"),
            created_at=payload.get("created_at"),
            updated_at=payload.get("updated_at"),
            course_id=course_id,
        )


class CanvasAssignment(BaseModel):
    id: int

@@ -30,10 +107,23 @@ class CanvasAssignment(BaseModel):
    description: str | None = None
    html_url: str
    course_id: int
    created_at: str
    updated_at: str
    created_at: str | None = None
    updated_at: str | None = None
    due_at: str | None = None

    @classmethod
    def from_api(cls, payload: dict[str, Any], course_id: int) -> "CanvasAssignment":
        return cls(
            id=payload["id"],
            name=payload["name"],
            description=payload.get("description"),
            html_url=payload["html_url"],
            course_id=course_id,
            created_at=payload.get("created_at"),
            updated_at=payload.get("updated_at"),
            due_at=payload.get("due_at"),
        )


class CanvasAnnouncement(BaseModel):
    id: int

@@ -43,6 +133,17 @@ class CanvasAnnouncement(BaseModel):
    posted_at: str | None = None
    course_id: int

    @classmethod
    def from_api(cls, payload: dict[str, Any], course_id: int) -> "CanvasAnnouncement":
        return cls(
            id=payload["id"],
            title=payload["title"],
            message=payload.get("message"),
            html_url=payload["html_url"],
            posted_at=payload.get("posted_at"),
            course_id=course_id,
        )


CanvasStage: TypeAlias = Literal["pages", "assignments", "announcements"]

@@ -72,3 +173,286 @@ class CanvasConnectorCheckpoint(ConnectorCheckpoint):
        self.current_course_index += 1
        self.stage = "pages"
        self.next_url = None


class CanvasConnector(
    CheckpointedConnectorWithPermSync[CanvasConnectorCheckpoint],
    SlimConnectorWithPermSync,
):
    def __init__(
        self,
        canvas_base_url: str,
        batch_size: int = INDEX_BATCH_SIZE,
    ) -> None:
        self.canvas_base_url = canvas_base_url.rstrip("/").removesuffix("/api/v1")
        self.batch_size = batch_size
        self._canvas_client: CanvasApiClient | None = None
        self._course_permissions_cache: dict[int, ExternalAccess | None] = {}

    @property
    def canvas_client(self) -> CanvasApiClient:
        if self._canvas_client is None:
            raise ConnectorMissingCredentialError("Canvas")
        return self._canvas_client

    def _get_course_permissions(self, course_id: int) -> ExternalAccess | None:
        """Get course permissions with caching."""
        if course_id not in self._course_permissions_cache:
            self._course_permissions_cache[course_id] = get_course_permissions(
                canvas_client=self.canvas_client,
                course_id=course_id,
            )
        return self._course_permissions_cache[course_id]

    @retry(tries=3, delay=1, backoff=2)
    def _list_courses(self) -> list[CanvasCourse]:
        """Fetch all courses accessible to the authenticated user."""
        logger.debug("Fetching Canvas courses")

        courses: list[CanvasCourse] = []
        for page in self.canvas_client.paginate(
            "courses", params={"per_page": "100", "state[]": "available"}
        ):
            courses.extend(CanvasCourse.from_api(c) for c in page)
        return courses

    @retry(tries=3, delay=1, backoff=2)
    def _list_pages(self, course_id: int) -> list[CanvasPage]:
        """Fetch all pages for a given course."""
        logger.debug(f"Fetching pages for course {course_id}")

        pages: list[CanvasPage] = []
        for page in self.canvas_client.paginate(
            f"courses/{course_id}/pages",
            params={"per_page": "100", "include[]": "body", "published": "true"},
        ):
            pages.extend(CanvasPage.from_api(p, course_id=course_id) for p in page)
        return pages

    @retry(tries=3, delay=1, backoff=2)
    def _list_assignments(self, course_id: int) -> list[CanvasAssignment]:
        """Fetch all assignments for a given course."""
        logger.debug(f"Fetching assignments for course {course_id}")

        assignments: list[CanvasAssignment] = []
        for page in self.canvas_client.paginate(
            f"courses/{course_id}/assignments",
            params={"per_page": "100", "published": "true"},
        ):
            assignments.extend(
                CanvasAssignment.from_api(a, course_id=course_id) for a in page
            )
        return assignments

    @retry(tries=3, delay=1, backoff=2)
    def _list_announcements(self, course_id: int) -> list[CanvasAnnouncement]:
        """Fetch all announcements for a given course."""
        logger.debug(f"Fetching announcements for course {course_id}")

        announcements: list[CanvasAnnouncement] = []
        for page in self.canvas_client.paginate(
            "announcements",
            params={
                "per_page": "100",
                "context_codes[]": f"course_{course_id}",
                "active_only": "true",
            },
        ):
            announcements.extend(
                CanvasAnnouncement.from_api(a, course_id=course_id) for a in page
            )
        return announcements

    def _build_document(
        self,
        doc_id: str,
        link: str,
        text: str,
        semantic_identifier: str,
        doc_updated_at: datetime | None,
        course_id: int,
        doc_type: str,
    ) -> Document:
        """Build a Document with standard Canvas fields."""
        return Document(
            id=doc_id,
            sections=cast(
                list[TextSection | ImageSection],
                [TextSection(link=link, text=text)],
            ),
            source=DocumentSource.CANVAS,
            semantic_identifier=semantic_identifier,
            doc_updated_at=doc_updated_at,
            metadata={"course_id": str(course_id), "type": doc_type},
        )

    def _convert_page_to_document(self, page: CanvasPage) -> Document:
        """Convert a Canvas page to a Document."""
        link = f"{self.canvas_base_url}/courses/{page.course_id}/pages/{page.url}"

        text_parts = [page.title]
        body_text = parse_html_page_basic(page.body) if page.body else ""
        if body_text:
            text_parts.append(body_text)

        doc_updated_at = (
            datetime.fromisoformat(page.updated_at.replace("Z", "+00:00")).astimezone(
                timezone.utc
            )
            if page.updated_at
            else None
        )

        document = self._build_document(
            doc_id=f"canvas-page-{page.course_id}-{page.page_id}",
            link=link,
            text="\n\n".join(text_parts),
            semantic_identifier=page.title or f"Page {page.page_id}",
            doc_updated_at=doc_updated_at,
            course_id=page.course_id,
            doc_type="page",
        )
        return document

    def _convert_assignment_to_document(self, assignment: CanvasAssignment) -> Document:
        """Convert a Canvas assignment to a Document."""
        text_parts = [assignment.name]
        desc_text = (
            parse_html_page_basic(assignment.description)
            if assignment.description
            else ""
        )
        if desc_text:
            text_parts.append(desc_text)
        if assignment.due_at:
            due_dt = datetime.fromisoformat(
                assignment.due_at.replace("Z", "+00:00")
            ).astimezone(timezone.utc)
            text_parts.append(f"Due: {due_dt.strftime('%B %d, %Y %H:%M UTC')}")

        doc_updated_at = (
            datetime.fromisoformat(
                assignment.updated_at.replace("Z", "+00:00")
            ).astimezone(timezone.utc)
            if assignment.updated_at
            else None
        )

        document = self._build_document(
            doc_id=f"canvas-assignment-{assignment.course_id}-{assignment.id}",
            link=assignment.html_url,
            text="\n\n".join(text_parts),
            semantic_identifier=assignment.name or f"Assignment {assignment.id}",
            doc_updated_at=doc_updated_at,
            course_id=assignment.course_id,
            doc_type="assignment",
        )
        return document

    def _convert_announcement_to_document(
        self, announcement: CanvasAnnouncement
    ) -> Document:
        """Convert a Canvas announcement to a Document."""
        text_parts = [announcement.title]
        msg_text = (
            parse_html_page_basic(announcement.message) if announcement.message else ""
        )
        if msg_text:
            text_parts.append(msg_text)

        doc_updated_at = (
            datetime.fromisoformat(
                announcement.posted_at.replace("Z", "+00:00")
            ).astimezone(timezone.utc)
            if announcement.posted_at
            else None
        )

        document = self._build_document(
            doc_id=f"canvas-announcement-{announcement.course_id}-{announcement.id}",
            link=announcement.html_url,
            text="\n\n".join(text_parts),
            semantic_identifier=announcement.title or f"Announcement {announcement.id}",
            doc_updated_at=doc_updated_at,
            course_id=announcement.course_id,
            doc_type="announcement",
        )
        return document

    @override
    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
        """Load and validate Canvas credentials."""
        access_token = credentials.get("canvas_access_token")
        if not access_token:
            raise ConnectorMissingCredentialError("Canvas")

        try:
            client = CanvasApiClient(
                bearer_token=access_token,
                canvas_base_url=self.canvas_base_url,
            )
            client.get("courses", params={"per_page": "1"})
        except ValueError as e:
            raise ConnectorValidationError(f"Invalid Canvas base URL: {e}")
        except OnyxError as e:
            _handle_canvas_api_error(e)

        self._canvas_client = client
        return None

    @override
    def validate_connector_settings(self) -> None:
        """Validate Canvas connector settings by testing API access."""
        try:
            self.canvas_client.get("courses", params={"per_page": "1"})
            logger.info("Canvas connector settings validated successfully")
        except OnyxError as e:
            _handle_canvas_api_error(e)
        except ConnectorMissingCredentialError:
            raise
        except Exception as exc:
            raise UnexpectedValidationError(
                f"Unexpected error during Canvas settings validation: {exc}"
            )

    @override
    def load_from_checkpoint(
        self,
        start: SecondsSinceUnixEpoch,
        end: SecondsSinceUnixEpoch,
        checkpoint: CanvasConnectorCheckpoint,
    ) -> CheckpointOutput[CanvasConnectorCheckpoint]:
        # TODO(benwu408): implemented in PR3 (checkpoint)
        raise NotImplementedError

    @override
    def load_from_checkpoint_with_perm_sync(
        self,
        start: SecondsSinceUnixEpoch,
        end: SecondsSinceUnixEpoch,
        checkpoint: CanvasConnectorCheckpoint,
    ) -> CheckpointOutput[CanvasConnectorCheckpoint]:
        # TODO(benwu408): implemented in PR3 (checkpoint)
        raise NotImplementedError

    @override
    def build_dummy_checkpoint(self) -> CanvasConnectorCheckpoint:
        # TODO(benwu408): implemented in PR3 (checkpoint)
        raise NotImplementedError

    @override
    def validate_checkpoint_json(
        self, checkpoint_json: str
    ) -> CanvasConnectorCheckpoint:
        # TODO(benwu408): implemented in PR3 (checkpoint)
        raise NotImplementedError

    @override
    def retrieve_all_slim_docs_perm_sync(
        self,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
        callback: IndexingHeartbeatInterface | None = None,
    ) -> GenerateSlimDocumentOutput:
        # TODO(benwu408): implemented in PR4 (perm sync)
        raise NotImplementedError
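A minimal sketch of driving the connector as it stands in this PR; the base URL and token are placeholders, and since the checkpoint methods still raise NotImplementedError here (deferred to PR3), only credential loading and validation are exercised:

    connector = CanvasConnector(canvas_base_url="https://canvas.example.edu")
    connector.load_credentials({"canvas_access_token": "<token>"})
    connector.validate_connector_settings()  # raises a connector exception on bad credentials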
@@ -11,11 +11,13 @@ from discord import Client
from discord.channel import TextChannel
from discord.channel import Thread
from discord.enums import MessageType
from discord.errors import LoginFailure
from discord.flags import Intents
from discord.message import Message as DiscordMessage

from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.exceptions import CredentialInvalidError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector

@@ -209,8 +211,19 @@ def _manage_async_retrieval(
    intents = Intents.default()
    intents.message_content = True
    async with Client(intents=intents) as discord_client:
        asyncio.create_task(discord_client.start(token))
        await discord_client.wait_until_ready()
        start_task = asyncio.create_task(discord_client.start(token))
        ready_task = asyncio.create_task(discord_client.wait_until_ready())

        done, _ = await asyncio.wait(
            {start_task, ready_task},
            return_when=asyncio.FIRST_COMPLETED,
        )

        # start() runs indefinitely once connected, so it only lands
        # in `done` when login/connection failed — propagate the error.
        if start_task in done:
            ready_task.cancel()
            start_task.result()

        filtered_channels: list[TextChannel] = await _fetch_filtered_channels(
            discord_client=discord_client,
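The same race-the-startup-task pattern in isolation, as a generic asyncio sketch rather than Discord-specific code; the two coroutines are stand-ins:

    import asyncio

    async def serve_forever() -> None:
        await asyncio.sleep(0.01)
        raise RuntimeError("login failed")  # simulate a connection error

    async def wait_until_ready() -> None:
        await asyncio.sleep(10)  # readiness never arrives if startup dies first

    async def main() -> None:
        start_task = asyncio.create_task(serve_forever())
        ready_task = asyncio.create_task(wait_until_ready())
        done, _ = await asyncio.wait(
            {start_task, ready_task}, return_when=asyncio.FIRST_COMPLETED
        )
        if start_task in done:
            ready_task.cancel()
            start_task.result()  # re-raises the startup error instead of hanging

    asyncio.run(main())  # raises RuntimeError: login failed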
@@ -276,6 +289,19 @@ class DiscordConnector(PollConnector, LoadConnector):
        self._discord_bot_token = credentials["discord_bot_token"]
        return None

    def validate_connector_settings(self) -> None:
        loop = asyncio.new_event_loop()
        try:
            client = Client(intents=Intents.default())
            try:
                loop.run_until_complete(client.login(self.discord_bot_token))
            except LoginFailure as e:
                raise CredentialInvalidError(f"Invalid Discord bot token: {e}")
            finally:
                loop.run_until_complete(client.close())
        finally:
            loop.close()

    def _manage_doc_batching(
        self,
        start: datetime | None = None,
@@ -72,6 +72,10 @@ CONNECTOR_CLASS_MAP = {
        module_path="onyx.connectors.coda.connector",
        class_name="CodaConnector",
    ),
    DocumentSource.CANVAS: ConnectorMapping(
        module_path="onyx.connectors.canvas.connector",
        class_name="CanvasConnector",
    ),
    DocumentSource.NOTION: ConnectorMapping(
        module_path="onyx.connectors.notion.connector",
        class_name="NotionConnector",
@@ -8,7 +8,6 @@ from uuid import UUID
from fastapi import HTTPException
from sqlalchemy import delete
from sqlalchemy import desc
from sqlalchemy import exists
from sqlalchemy import func
from sqlalchemy import nullsfirst
from sqlalchemy import or_

@@ -132,32 +131,47 @@ def get_chat_sessions_by_user(
    if before is not None:
        stmt = stmt.where(ChatSession.time_updated < before)

    if limit:
        stmt = stmt.limit(limit)

    if project_id is not None:
        stmt = stmt.where(ChatSession.project_id == project_id)
    elif only_non_project_chats:
        stmt = stmt.where(ChatSession.project_id.is_(None))

    if not include_failed_chats:
        non_system_message_exists_subq = (
            exists()
            .where(ChatMessage.chat_session_id == ChatSession.id)
            .where(ChatMessage.message_type != MessageType.SYSTEM)
            .correlate(ChatSession)
        )

        # Leeway for newly created chats that don't have messages yet
        time = datetime.now(timezone.utc) - timedelta(minutes=5)
        recently_created = ChatSession.time_created >= time

        stmt = stmt.where(or_(non_system_message_exists_subq, recently_created))
    # When filtering out failed chats, we apply the limit in Python after
    # filtering rather than in SQL, since the post-filter may remove rows.
    if limit and include_failed_chats:
        stmt = stmt.limit(limit)

    result = db_session.execute(stmt)
    chat_sessions = result.scalars().all()
    chat_sessions = list(result.scalars().all())

    return list(chat_sessions)
    if not include_failed_chats and chat_sessions:
        # Filter out "failed" sessions (those with only SYSTEM messages)
        # using a separate efficient query instead of a correlated EXISTS
        # subquery, which causes full sequential scans of chat_message.
        leeway = datetime.now(timezone.utc) - timedelta(minutes=5)
        session_ids = [cs.id for cs in chat_sessions if cs.time_created < leeway]

        if session_ids:
            valid_session_ids_stmt = (
                select(ChatMessage.chat_session_id)
                .where(ChatMessage.chat_session_id.in_(session_ids))
                .where(ChatMessage.message_type != MessageType.SYSTEM)
                .distinct()
            )
            valid_session_ids = set(
                db_session.execute(valid_session_ids_stmt).scalars().all()
            )

            chat_sessions = [
                cs
                for cs in chat_sessions
                if cs.time_created >= leeway or cs.id in valid_session_ids
            ]

    if limit:
        chat_sessions = chat_sessions[:limit]

    return chat_sessions


def delete_orphaned_search_docs(db_session: Session) -> None:
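Why the SQL LIMIT has to wait until after the post-filter, in miniature (plain Python over toy data, not the SQLAlchemy code above):

    sessions = ["ok1", "failed", "ok2", "ok3"]  # newest-first query result

    def is_valid(s: str) -> bool:
        return not s.startswith("failed")

    wrong = [s for s in sessions[:2] if is_valid(s)]   # LIMIT 2 first: only 1 row survives
    right = [s for s in sessions if is_valid(s)][:2]   # filter first: 2 rows as requested
    assert wrong == ["ok1"] and right == ["ok1", "ok2"]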
@@ -215,6 +215,7 @@ class UserFileStatus(str, PyEnum):
    PROCESSING = "PROCESSING"
    INDEXING = "INDEXING"
    COMPLETED = "COMPLETED"
    SKIPPED = "SKIPPED"
    FAILED = "FAILED"
    CANCELED = "CANCELED"
    DELETING = "DELETING"
@@ -7,6 +7,7 @@ from fastapi import HTTPException
from fastapi import UploadFile
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field
from sqlalchemy import func
from sqlalchemy.orm import Session
from starlette.background import BackgroundTasks

@@ -17,6 +18,7 @@ from onyx.configs.constants import FileOrigin
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.enums import UserFileStatus
from onyx.db.models import Project__UserFile
from onyx.db.models import User
from onyx.db.models import UserFile

@@ -34,9 +36,19 @@ class CategorizedFilesResult(BaseModel):
    user_files: list[UserFile]
    rejected_files: list[RejectedFile]
    id_to_temp_id: dict[str, str]
    # Filenames that should be stored but not indexed.
    skip_indexing_filenames: set[str] = Field(default_factory=set)
    # Allow SQLAlchemy ORM models inside this result container
    model_config = ConfigDict(arbitrary_types_allowed=True)

    @property
    def indexable_files(self) -> list[UserFile]:
        return [
            uf
            for uf in self.user_files
            if (uf.name or "") not in self.skip_indexing_filenames
        ]


def build_hashed_file_key(file: UploadFile) -> str:
    name_prefix = (file.filename or "")[:50]

@@ -70,6 +82,7 @@ def create_user_files(
        )
        if new_temp_id is not None:
            id_to_temp_id[str(new_id)] = new_temp_id
        should_skip = (file.filename or "") in categorized_files.skip_indexing
        new_file = UserFile(
            id=new_id,
            user_id=user.id,

@@ -81,6 +94,7 @@ def create_user_files(
            link_url=link_url,
            content_type=file.content_type,
            file_type=file.content_type,
            status=UserFileStatus.SKIPPED if should_skip else UserFileStatus.PROCESSING,
            last_accessed_at=datetime.datetime.now(datetime.timezone.utc),
        )
        # Persist the UserFile first to satisfy FK constraints for the association table

@@ -98,6 +112,7 @@ def create_user_files(
        user_files=user_files,
        rejected_files=rejected_files,
        id_to_temp_id=id_to_temp_id,
        skip_indexing_filenames=categorized_files.skip_indexing,
    )
@@ -123,6 +138,7 @@ def upload_files_to_user_files_with_indexing(
    user_files = categorized_files_result.user_files
    rejected_files = categorized_files_result.rejected_files
    id_to_temp_id = categorized_files_result.id_to_temp_id
    indexable_files = categorized_files_result.indexable_files
    # Trigger per-file processing immediately for the current tenant
    tenant_id = get_current_tenant_id()
    for rejected_file in rejected_files:

@@ -134,12 +150,12 @@ def upload_files_to_user_files_with_indexing(
        from onyx.background.task_utils import drain_processing_loop

        background_tasks.add_task(drain_processing_loop, tenant_id)
        for user_file in user_files:
        for user_file in indexable_files:
            logger.info(f"Queued in-process processing for user_file_id={user_file.id}")
    else:
        from onyx.background.celery.versioned_apps.client import app as client_app

        for user_file in user_files:
        for user_file in indexable_files:
            task = client_app.send_task(
                OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
                kwargs={"user_file_id": user_file.id, "tenant_id": tenant_id},

@@ -155,6 +171,7 @@ def upload_files_to_user_files_with_indexing(
        user_files=user_files,
        rejected_files=rejected_files,
        id_to_temp_id=id_to_temp_id,
        skip_indexing_filenames=categorized_files_result.skip_indexing_filenames,
    )
@@ -932,7 +932,7 @@ class OpenSearchIndexClient(OpenSearchClient):
    def search_for_document_ids(
        self,
        body: dict[str, Any],
        search_type: OpenSearchSearchType = OpenSearchSearchType.DOCUMENT_IDS,
        search_type: OpenSearchSearchType = OpenSearchSearchType.UNKNOWN,
    ) -> list[str]:
        """Searches the index and returns only document chunk IDs.

@@ -60,8 +60,7 @@ class OpenSearchSearchType(str, Enum):
    KEYWORD = "keyword"
    SEMANTIC = "semantic"
    RANDOM = "random"
    ID_RETRIEVAL = "id_retrieval"
    DOCUMENT_IDS = "document_ids"
    DOC_ID_RETRIEVAL = "doc_id_retrieval"
    UNKNOWN = "unknown"
@@ -6,6 +6,7 @@ import httpx
from opensearchpy import NotFoundError

from onyx.access.models import DocumentAccess
from onyx.configs.app_configs import MAX_CHUNKS_PER_DOC_BATCH
from onyx.configs.app_configs import VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT
from onyx.configs.chat_configs import NUM_RETURNED_HITS
from onyx.configs.chat_configs import TITLE_CONTENT_RATIO

@@ -738,6 +739,9 @@ class OpenSearchDocumentIndex(DocumentIndex):
                _flush_chunks(current_chunks)
                current_doc_id = doc_id
                current_chunks = [chunk]
            elif len(current_chunks) >= MAX_CHUNKS_PER_DOC_BATCH:
                _flush_chunks(current_chunks)
                current_chunks = [chunk]
            else:
                current_chunks.append(chunk)

@@ -924,7 +928,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
        search_hits = self._client.search(
            body=query_body,
            search_pipeline_id=None,
            search_type=OpenSearchSearchType.ID_RETRIEVAL,
            search_type=OpenSearchSearchType.DOC_ID_RETRIEVAL,
        )
        inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [
            _convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
@@ -10,6 +10,7 @@ import httpx
from pydantic import BaseModel
from retry import retry

from onyx.configs.app_configs import MAX_CHUNKS_PER_DOC_BATCH
from onyx.configs.app_configs import RECENCY_BIAS_MULTIPLIER
from onyx.configs.app_configs import RERANK_COUNT
from onyx.configs.chat_configs import DOC_TIME_DECAY

@@ -427,7 +428,9 @@ class VespaDocumentIndex(DocumentIndex):
            new_document_id_to_original_document_id,
            all_cleaned_doc_ids,
        )
        for chunk_batch in batch_generator(cleaned_chunks, BATCH_SIZE):
        for chunk_batch in batch_generator(
            cleaned_chunks, min(BATCH_SIZE, MAX_CHUNKS_PER_DOC_BATCH)
        ):
            batch_index_vespa_chunks(
                chunks=chunk_batch,
                index_name=self._index_name,
@@ -15,6 +15,7 @@ PLAIN_TEXT_MIME_TYPE = "text/plain"
class OnyxMimeTypes:
    IMAGE_MIME_TYPES = {"image/jpg", "image/jpeg", "image/png", "image/webp"}
    CSV_MIME_TYPES = {"text/csv"}
    TABULAR_MIME_TYPES = CSV_MIME_TYPES | {SPREADSHEET_MIME_TYPE}
    TEXT_MIME_TYPES = {
        PLAIN_TEXT_MIME_TYPE,
        "text/markdown",

@@ -34,13 +35,12 @@ class OnyxMimeTypes:
        PDF_MIME_TYPE,
        WORD_PROCESSING_MIME_TYPE,
        PRESENTATION_MIME_TYPE,
        SPREADSHEET_MIME_TYPE,
        "message/rfc822",
        "application/epub+zip",
    }

    ALLOWED_MIME_TYPES = IMAGE_MIME_TYPES.union(
        TEXT_MIME_TYPES, DOCUMENT_MIME_TYPES, CSV_MIME_TYPES
        TEXT_MIME_TYPES, DOCUMENT_MIME_TYPES, TABULAR_MIME_TYPES
    )

    EXCLUDED_IMAGE_TYPES = {

@@ -53,6 +53,11 @@ class OnyxMimeTypes:


class OnyxFileExtensions:
    TABULAR_EXTENSIONS = {
        ".csv",
        ".tsv",
        ".xlsx",
    }
    PLAIN_TEXT_EXTENSIONS = {
        ".txt",
        ".md",
@@ -13,15 +13,21 @@ class ChatFileType(str, Enum):
    DOC = "document"
    # Plain text files only contain the text
    PLAIN_TEXT = "plain_text"
    CSV = "csv"
    # Tabular data files (CSV, XLSX)
    TABULAR = "tabular"

    def is_text_file(self) -> bool:
        return self in (
            ChatFileType.PLAIN_TEXT,
            ChatFileType.DOC,
            ChatFileType.CSV,
            ChatFileType.TABULAR,
        )

    def use_metadata_only(self) -> bool:
        """File types where we can ignore the file content
        and only use the metadata."""
        return self in (ChatFileType.TABULAR,)


class FileDescriptor(TypedDict):
    """NOTE: is a `TypedDict` so it can be used as a type hint for a JSONB column
@@ -110,16 +110,20 @@ def load_user_file(file_id: UUID, db_session: Session) -> InMemoryChatFile:
    # check for the plain text normalized version first, falling back to the original file
    try:
        file_io = file_store.read_file(plaintext_file_name, mode="b")
        # For plaintext versions, use PLAIN_TEXT type (unless it's an image which doesn't have plaintext)
        plaintext_chat_file_type = (
            ChatFileType.PLAIN_TEXT
            if chat_file_type != ChatFileType.IMAGE
            else chat_file_type
        )

        # if we have plaintext for image (which happens when image extraction is enabled), we use PLAIN_TEXT type
        if file_io is not None:
        # Metadata-only file types preserve their original type so
        # downstream injection paths can route them correctly.
        if chat_file_type.use_metadata_only():
            plaintext_chat_file_type = chat_file_type
        elif file_io is not None:
            # if we have plaintext for an image (which happens when image
            # extraction is enabled), we use PLAIN_TEXT type
            plaintext_chat_file_type = ChatFileType.PLAIN_TEXT
        else:
            plaintext_chat_file_type = (
                ChatFileType.PLAIN_TEXT
                if chat_file_type != ChatFileType.IMAGE
                else chat_file_type
            )

        chat_file = InMemoryChatFile(
            file_id=str(user_file.file_id),
@@ -1,4 +1,3 @@
from onyx.configs.app_configs import HOOK_ENABLED
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from shared_configs.configs import MULTI_TENANT

@@ -7,10 +6,7 @@ from shared_configs.configs import MULTI_TENANT
def require_hook_enabled() -> None:
    """FastAPI dependency that gates all hook management endpoints.

    Hooks are only available in single-tenant / self-hosted deployments with
    HOOK_ENABLED=true explicitly set. Two layers of protection:
    1. MULTI_TENANT check — rejects even if HOOK_ENABLED is accidentally set true
    2. HOOK_ENABLED flag — explicit opt-in by the operator
    Hooks are only available in single-tenant / self-hosted EE deployments.

    Use as: Depends(require_hook_enabled)
    """

@@ -19,8 +15,3 @@ def require_hook_enabled() -> None:
            OnyxErrorCode.SINGLE_TENANT_ONLY,
            "Hooks are not available in multi-tenant deployments",
        )
    if not HOOK_ENABLED:
        raise OnyxError(
            OnyxErrorCode.ENV_VAR_GATED,
            "Hooks are not enabled. Set HOOK_ENABLED=true to enable.",
        )
@@ -1,79 +1,22 @@
"""Hook executor — calls a customer's external HTTP endpoint for a given hook point.
"""CE hook executor.

Usage (Celery tasks and FastAPI handlers):
    result = execute_hook(
        db_session=db_session,
        hook_point=HookPoint.QUERY_PROCESSING,
        payload={"query": "...", "user_email": "...", "chat_session_id": "..."},
        response_type=QueryProcessingResponse,
    )
HookSkipped and HookSoftFailed are real classes kept here because
process_message.py (CE code) uses isinstance checks against them.

    if isinstance(result, HookSkipped):
        # no active hook configured — continue with original behavior
        ...
    elif isinstance(result, HookSoftFailed):
        # hook failed but fail strategy is SOFT — continue with original behavior
        ...
    else:
        # result is a validated Pydantic model instance (response_type)
        ...

is_reachable update policy
--------------------------
``is_reachable`` on the Hook row is updated selectively — only when the outcome
carries meaningful signal about physical reachability:

    NetworkError (DNS, connection refused) → False (cannot reach the server)
    HTTP 401 / 403                         → False (api_key revoked or invalid)
    TimeoutException                       → None (server may be slow, skip write)
    Other HTTP errors (4xx / 5xx)          → None (server responded, skip write)
    Unknown exception                      → None (no signal, skip write)
    Non-JSON / non-dict response           → None (server responded, skip write)
    Success (2xx, valid dict)              → True (confirmed reachable)

None means "leave the current value unchanged" — no DB round-trip is made.

DB session design
-----------------
The executor uses three sessions:

1. Caller's session (db_session) — used only for the hook lookup read. All
   needed fields are extracted from the Hook object before the HTTP call, so
   the caller's session is not held open during the external HTTP request.

2. Log session — a separate short-lived session opened after the HTTP call
   completes to write the HookExecutionLog row on failure. Success runs are
   not recorded. Committed independently of everything else.

3. Reachable session — a second short-lived session to update is_reachable on
   the Hook. Kept separate from the log session so a concurrent hook deletion
   (which causes update_hook__no_commit to raise OnyxError(NOT_FOUND)) cannot
   prevent the execution log from being written. This update is best-effort.
execute_hook is the public entry point. It dispatches to _execute_hook_impl
via fetch_versioned_implementation so that:
- CE: onyx.hooks.executor._execute_hook_impl → no-op, returns HookSkipped()
- EE: ee.onyx.hooks.executor._execute_hook_impl → real HTTP call
"""

import json
import time
from typing import Any
from typing import TypeVar

import httpx
from pydantic import BaseModel
from pydantic import ValidationError
from sqlalchemy.orm import Session

from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import HookFailStrategy
from onyx.db.enums import HookPoint
from onyx.db.hook import create_hook_execution_log__no_commit
from onyx.db.hook import get_non_deleted_hook_by_hook_point
from onyx.db.hook import update_hook__no_commit
from onyx.db.models import Hook
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.hooks.utils import HOOKS_AVAILABLE
from onyx.utils.logger import setup_logger

logger = setup_logger()
from onyx.utils.variable_functionality import fetch_versioned_implementation


class HookSkipped:

@@ -87,277 +30,15 @@ class HookSoftFailed:
T = TypeVar("T", bound=BaseModel)


# ---------------------------------------------------------------------------
# Private helpers
# ---------------------------------------------------------------------------


class _HttpOutcome(BaseModel):
    """Structured result of an HTTP hook call, returned by _process_response."""

    is_success: bool
    updated_is_reachable: (
        bool | None
    )  # True/False = write to DB, None = unchanged (skip write)
    status_code: int | None
    error_message: str | None
    response_payload: dict[str, Any] | None


def _lookup_hook(
    db_session: Session,
    hook_point: HookPoint,
) -> Hook | HookSkipped:
    """Return the active Hook or HookSkipped if hooks are unavailable/unconfigured.

    No HTTP call is made and no DB writes are performed for any HookSkipped path.
    There is nothing to log and no reachability information to update.
    """
    if not HOOKS_AVAILABLE:
        return HookSkipped()
    hook = get_non_deleted_hook_by_hook_point(
        db_session=db_session, hook_point=hook_point
    )
    if hook is None or not hook.is_active:
        return HookSkipped()
    if not hook.endpoint_url:
        return HookSkipped()
    return hook


def _process_response(
def _execute_hook_impl(
    *,
    response: httpx.Response | None,
    exc: Exception | None,
    timeout: float,
) -> _HttpOutcome:
    """Process the result of an HTTP call and return a structured outcome.

    Called after the client.post() try/except. If post() raised, exc is set and
    response is None. Otherwise response is set and exc is None. Handles
    raise_for_status(), JSON decoding, and the dict shape check.
    """
    if exc is not None:
        if isinstance(exc, httpx.NetworkError):
            msg = f"Hook network error (endpoint unreachable): {exc}"
            logger.warning(msg, exc_info=exc)
            return _HttpOutcome(
                is_success=False,
                updated_is_reachable=False,
                status_code=None,
                error_message=msg,
                response_payload=None,
            )
        if isinstance(exc, httpx.TimeoutException):
            msg = f"Hook timed out after {timeout}s: {exc}"
            logger.warning(msg, exc_info=exc)
            return _HttpOutcome(
                is_success=False,
                updated_is_reachable=None,  # timeout doesn't indicate unreachability
                status_code=None,
                error_message=msg,
                response_payload=None,
            )
        msg = f"Hook call failed: {exc}"
        logger.exception(msg, exc_info=exc)
        return _HttpOutcome(
            is_success=False,
            updated_is_reachable=None,  # unknown error — don't make assumptions
            status_code=None,
            error_message=msg,
            response_payload=None,
        )

    if response is None:
        raise ValueError(
            "exactly one of response or exc must be non-None; both are None"
        )
    status_code = response.status_code

    try:
        response.raise_for_status()
    except httpx.HTTPStatusError as e:
        msg = f"Hook returned HTTP {e.response.status_code}: {e.response.text}"
        logger.warning(msg, exc_info=e)
        # 401/403 means the api_key has been revoked or is invalid — mark unreachable
        # so the operator knows to update it. All other HTTP errors keep is_reachable
        # as-is (server is up, the request just failed for application reasons).
        auth_failed = e.response.status_code in (401, 403)
        return _HttpOutcome(
            is_success=False,
            updated_is_reachable=False if auth_failed else None,
            status_code=status_code,
            error_message=msg,
            response_payload=None,
        )

    try:
        response_payload = response.json()
    except (json.JSONDecodeError, httpx.DecodingError) as e:
        msg = f"Hook returned non-JSON response: {e}"
        logger.warning(msg, exc_info=e)
        return _HttpOutcome(
            is_success=False,
            updated_is_reachable=None,  # server responded — reachability unchanged
            status_code=status_code,
            error_message=msg,
            response_payload=None,
        )

    if not isinstance(response_payload, dict):
        msg = f"Hook returned non-dict JSON (got {type(response_payload).__name__})"
        logger.warning(msg)
        return _HttpOutcome(
            is_success=False,
            updated_is_reachable=None,  # server responded — reachability unchanged
            status_code=status_code,
            error_message=msg,
            response_payload=None,
        )

    return _HttpOutcome(
        is_success=True,
        updated_is_reachable=True,
        status_code=status_code,
        error_message=None,
        response_payload=response_payload,
    )


def _persist_result(
    *,
    hook_id: int,
    outcome: _HttpOutcome,
    duration_ms: int,
) -> None:
    """Write the execution log on failure and optionally update is_reachable, each
    in its own session so a failure in one does not affect the other."""
    # Only write the execution log on failure — success runs are not recorded.
    # Must not be skipped if the is_reachable update fails (e.g. hook concurrently
    # deleted between the initial lookup and here).
    if not outcome.is_success:
        try:
            with get_session_with_current_tenant() as log_session:
                create_hook_execution_log__no_commit(
                    db_session=log_session,
                    hook_id=hook_id,
                    is_success=False,
                    error_message=outcome.error_message,
                    status_code=outcome.status_code,
                    duration_ms=duration_ms,
                )
                log_session.commit()
        except Exception:
            logger.exception(
                f"Failed to persist hook execution log for hook_id={hook_id}"
            )

    # Update is_reachable separately — best-effort, non-critical.
    # None means the value is unchanged (set by the caller to skip the no-op write).
    # update_hook__no_commit can raise OnyxError(NOT_FOUND) if the hook was
    # concurrently deleted, so keep this isolated from the log write above.
    if outcome.updated_is_reachable is not None:
        try:
            with get_session_with_current_tenant() as reachable_session:
                update_hook__no_commit(
                    db_session=reachable_session,
                    hook_id=hook_id,
                    is_reachable=outcome.updated_is_reachable,
                )
                reachable_session.commit()
        except Exception:
            logger.warning(f"Failed to update is_reachable for hook_id={hook_id}")


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------


def _execute_hook_inner(
    hook: Hook,
    payload: dict[str, Any],
    response_type: type[T],
) -> T | HookSoftFailed:
    """Make the HTTP call, validate the response, and return a typed model.

    Raises OnyxError on HARD failure. Returns HookSoftFailed on SOFT failure.
    """
    timeout = hook.timeout_seconds
    hook_id = hook.id
    fail_strategy = hook.fail_strategy
    endpoint_url = hook.endpoint_url
    current_is_reachable: bool | None = hook.is_reachable

    if not endpoint_url:
        raise ValueError(
            f"hook_id={hook_id} is active but has no endpoint_url — "
            "active hooks without an endpoint_url must be rejected by _lookup_hook"
        )

    start = time.monotonic()
    response: httpx.Response | None = None
    exc: Exception | None = None
    try:
        api_key: str | None = (
            hook.api_key.get_value(apply_mask=False) if hook.api_key else None
        )
        headers: dict[str, str] = {"Content-Type": "application/json"}
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"
        with httpx.Client(
            timeout=timeout, follow_redirects=False
        ) as client:  # SSRF guard: never follow redirects
            response = client.post(endpoint_url, json=payload, headers=headers)
    except Exception as e:
        exc = e
    duration_ms = int((time.monotonic() - start) * 1000)

    outcome = _process_response(response=response, exc=exc, timeout=timeout)

    # Validate the response payload against response_type.
    # A validation failure downgrades the outcome to a failure so it is logged,
    # is_reachable is left unchanged (server responded — just a bad payload),
    # and fail_strategy is respected below.
    validated_model: T | None = None
    if outcome.is_success and outcome.response_payload is not None:
        try:
            validated_model = response_type.model_validate(outcome.response_payload)
        except ValidationError as e:
            msg = (
                f"Hook response failed validation against {response_type.__name__}: {e}"
            )
            outcome = _HttpOutcome(
                is_success=False,
                updated_is_reachable=None,  # server responded — reachability unchanged
                status_code=outcome.status_code,
                error_message=msg,
                response_payload=None,
            )

    # Skip the is_reachable write when the value would not change — avoids a
    # no-op DB round-trip on every call when the hook is already in the expected state.
    if outcome.updated_is_reachable == current_is_reachable:
        outcome = outcome.model_copy(update={"updated_is_reachable": None})
    _persist_result(hook_id=hook_id, outcome=outcome, duration_ms=duration_ms)

    if not outcome.is_success:
        if fail_strategy == HookFailStrategy.HARD:
            raise OnyxError(
                OnyxErrorCode.HOOK_EXECUTION_FAILED,
                outcome.error_message or "Hook execution failed.",
            )
        logger.warning(
            f"Hook execution failed (soft fail) for hook_id={hook_id}: {outcome.error_message}"
        )
        return HookSoftFailed()

    if validated_model is None:
        raise OnyxError(
            OnyxErrorCode.INTERNAL_ERROR,
            f"validated_model is None for successful hook call (hook_id={hook_id})",
        )
    return validated_model
    db_session: Session,  # noqa: ARG001
    hook_point: HookPoint,  # noqa: ARG001
    payload: dict[str, Any],  # noqa: ARG001
    response_type: type[T],  # noqa: ARG001
) -> T | HookSkipped | HookSoftFailed:
    """CE no-op — hooks are not available without EE."""
    return HookSkipped()


def execute_hook(

@@ -367,25 +48,15 @@ def execute_hook(
    payload: dict[str, Any],
    response_type: type[T],
) -> T | HookSkipped | HookSoftFailed:
    """Execute the hook for the given hook point synchronously.
    """Execute the hook for the given hook point.

    Returns HookSkipped if no active hook is configured, HookSoftFailed if the
    hook failed with SOFT fail strategy, or a validated response model on success.
    Raises OnyxError on HARD failure or if the hook is misconfigured.
    Dispatches to the versioned implementation so EE gets the real executor
    and CE gets the no-op stub, without any changes at the call site.
    """
    hook = _lookup_hook(db_session, hook_point)
    if isinstance(hook, HookSkipped):
        return hook

    fail_strategy = hook.fail_strategy
    hook_id = hook.id

    try:
        return _execute_hook_inner(hook, payload, response_type)
    except Exception:
        if fail_strategy == HookFailStrategy.SOFT:
            logger.exception(
                f"Unexpected error in hook execution (soft fail) for hook_id={hook_id}"
            )
            return HookSoftFailed()
        raise
    impl = fetch_versioned_implementation("onyx.hooks.executor", "_execute_hook_impl")
    return impl(
        db_session=db_session,
        hook_point=hook_point,
        payload=payload,
        response_type=response_type,
    )
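The CE/EE dispatch in miniature. The toy resolver below is an assumption about how fetch_versioned_implementation behaves (EE builds resolving an "ee."-prefixed module so it can shadow the CE stub), not its actual implementation:

    import importlib

    def fetch_impl(module: str, attribute: str, is_ee: bool):
        # EE builds look up ee.<module> so it can shadow the CE stub;
        # CE builds import the plain module and get the no-op.
        resolved = f"ee.{module}" if is_ee else module
        return getattr(importlib.import_module(resolved), attribute)

    # execute_hook(...) then boils down to:
    #   impl = fetch_impl("onyx.hooks.executor", "_execute_hook_impl", is_ee)
    #   return impl(db_session=..., hook_point=..., payload=..., response_type=...)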
@@ -1,5 +0,0 @@
|
||||
from onyx.configs.app_configs import HOOK_ENABLED
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
# True only when hooks are available: single-tenant deployment with HOOK_ENABLED=true.
|
||||
HOOKS_AVAILABLE: bool = HOOK_ENABLED and not MULTI_TENANT
|
||||
@@ -19,7 +19,8 @@ from onyx.db.document import update_docs_updated_at__no_commit
from onyx.db.document_set import fetch_document_sets_for_documents
from onyx.indexing.indexing_pipeline import DocumentBatchPrepareContext
from onyx.indexing.indexing_pipeline import index_doc_batch_prepare
from onyx.indexing.models import BuildMetadataAwareChunksResult
from onyx.indexing.models import ChunkEnrichmentContext
from onyx.indexing.models import DocAwareChunk
from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.indexing.models import IndexChunk
from onyx.indexing.models import UpdatableChunkData

@@ -85,14 +86,21 @@ class DocumentIndexingBatchAdapter:
        ) as transaction:
            yield transaction

    def build_metadata_aware_chunks(
    def prepare_enrichment(
        self,
        chunks_with_embeddings: list[IndexChunk],
        chunk_content_scores: list[float],
        tenant_id: str,
        context: DocumentBatchPrepareContext,
    ) -> BuildMetadataAwareChunksResult:
        """Enrich chunks with access, document sets, boosts, token counts, and hierarchy."""
        tenant_id: str,
        chunks: list[DocAwareChunk],
    ) -> "DocumentChunkEnricher":
        """Do all DB lookups once and return a per-chunk enricher."""
        updatable_ids = [doc.id for doc in context.updatable_docs]

        doc_id_to_new_chunk_cnt: dict[str, int] = {
            doc_id: 0 for doc_id in updatable_ids
        }
        for chunk in chunks:
            if chunk.source_document.id in doc_id_to_new_chunk_cnt:
                doc_id_to_new_chunk_cnt[chunk.source_document.id] += 1

        no_access = DocumentAccess.build(
            user_emails=[],

@@ -102,67 +110,30 @@ class DocumentIndexingBatchAdapter:
            is_public=False,
        )

        updatable_ids = [doc.id for doc in context.updatable_docs]

        doc_id_to_access_info = get_access_for_documents(
            document_ids=updatable_ids, db_session=self.db_session
        )
        doc_id_to_document_set = {
            document_id: document_sets
            for document_id, document_sets in fetch_document_sets_for_documents(
        return DocumentChunkEnricher(
            doc_id_to_access_info=get_access_for_documents(
                document_ids=updatable_ids, db_session=self.db_session
            )
        }

        doc_id_to_previous_chunk_cnt: dict[str, int] = {
            document_id: chunk_count
            for document_id, chunk_count in fetch_chunk_counts_for_documents(
                document_ids=updatable_ids,
                db_session=self.db_session,
            )
        }

        doc_id_to_new_chunk_cnt: dict[str, int] = {
            doc_id: 0 for doc_id in updatable_ids
        }
        for chunk in chunks_with_embeddings:
            if chunk.source_document.id in doc_id_to_new_chunk_cnt:
                doc_id_to_new_chunk_cnt[chunk.source_document.id] += 1

        # Get ancestor hierarchy node IDs for each document
        doc_id_to_ancestor_ids = self._get_ancestor_ids_for_documents(
            context.updatable_docs, tenant_id
        )

        access_aware_chunks = [
            DocMetadataAwareIndexChunk.from_index_chunk(
                index_chunk=chunk,
                access=doc_id_to_access_info.get(chunk.source_document.id, no_access),
                document_sets=set(
                    doc_id_to_document_set.get(chunk.source_document.id, [])
                ),
                user_project=[],
                personas=[],
                boost=(
                    context.id_to_boost_map[chunk.source_document.id]
                    if chunk.source_document.id in context.id_to_boost_map
                    else DEFAULT_BOOST
                ),
                tenant_id=tenant_id,
                aggregated_chunk_boost_factor=chunk_content_scores[chunk_num],
                ancestor_hierarchy_node_ids=doc_id_to_ancestor_ids[
                    chunk.source_document.id
                ],
            )
            for chunk_num, chunk in enumerate(chunks_with_embeddings)
        ]

        return BuildMetadataAwareChunksResult(
            chunks=access_aware_chunks,
            doc_id_to_previous_chunk_cnt=doc_id_to_previous_chunk_cnt,
            doc_id_to_new_chunk_cnt=doc_id_to_new_chunk_cnt,
            user_file_id_to_raw_text={},
            user_file_id_to_token_count={},
            ),
            doc_id_to_document_set={
                document_id: document_sets
                for document_id, document_sets in fetch_document_sets_for_documents(
                    document_ids=updatable_ids, db_session=self.db_session
                )
            },
            doc_id_to_ancestor_ids=self._get_ancestor_ids_for_documents(
                context.updatable_docs, tenant_id
            ),
            id_to_boost_map=context.id_to_boost_map,
            doc_id_to_previous_chunk_cnt={
                document_id: chunk_count
                for document_id, chunk_count in fetch_chunk_counts_for_documents(
                    document_ids=updatable_ids,
                    db_session=self.db_session,
                )
            },
            doc_id_to_new_chunk_cnt=dict(doc_id_to_new_chunk_cnt),
            no_access=no_access,
            tenant_id=tenant_id,
        )

    def _get_ancestor_ids_for_documents(

@@ -203,7 +174,7 @@ class DocumentIndexingBatchAdapter:
        context: DocumentBatchPrepareContext,
        updatable_chunk_data: list[UpdatableChunkData],
        filtered_documents: list[Document],
        result: BuildMetadataAwareChunksResult,
        enrichment: ChunkEnrichmentContext,
    ) -> None:
        """Finalize DB updates, store plaintext, and mark docs as indexed."""
        updatable_ids = [doc.id for doc in context.updatable_docs]

@@ -227,7 +198,7 @@ class DocumentIndexingBatchAdapter:

        update_docs_chunk_count__no_commit(
            document_ids=updatable_ids,
            doc_id_to_chunk_count=result.doc_id_to_new_chunk_cnt,
            doc_id_to_chunk_count=enrichment.doc_id_to_new_chunk_cnt,
            db_session=self.db_session,
        )

@@ -249,3 +220,52 @@ class DocumentIndexingBatchAdapter:
        )

        self.db_session.commit()


class DocumentChunkEnricher:
    """Pre-computed metadata for per-chunk enrichment of connector documents."""

    def __init__(
        self,
        doc_id_to_access_info: dict[str, DocumentAccess],
        doc_id_to_document_set: dict[str, list[str]],
        doc_id_to_ancestor_ids: dict[str, list[int]],
        id_to_boost_map: dict[str, int],
        doc_id_to_previous_chunk_cnt: dict[str, int],
        doc_id_to_new_chunk_cnt: dict[str, int],
        no_access: DocumentAccess,
        tenant_id: str,
    ) -> None:
        self._doc_id_to_access_info = doc_id_to_access_info
        self._doc_id_to_document_set = doc_id_to_document_set
        self._doc_id_to_ancestor_ids = doc_id_to_ancestor_ids
        self._id_to_boost_map = id_to_boost_map
        self._no_access = no_access
        self._tenant_id = tenant_id
        self.doc_id_to_previous_chunk_cnt = doc_id_to_previous_chunk_cnt
        self.doc_id_to_new_chunk_cnt = doc_id_to_new_chunk_cnt

    def enrich_chunk(
        self, chunk: IndexChunk, score: float
    ) -> DocMetadataAwareIndexChunk:
        return DocMetadataAwareIndexChunk.from_index_chunk(
            index_chunk=chunk,
            access=self._doc_id_to_access_info.get(
                chunk.source_document.id, self._no_access
            ),
            document_sets=set(
                self._doc_id_to_document_set.get(chunk.source_document.id, [])
            ),
            user_project=[],
            personas=[],
            boost=(
                self._id_to_boost_map[chunk.source_document.id]
                if chunk.source_document.id in self._id_to_boost_map
                else DEFAULT_BOOST
            ),
            tenant_id=self._tenant_id,
            aggregated_chunk_boost_factor=score,
            ancestor_hierarchy_node_ids=self._doc_id_to_ancestor_ids[
                chunk.source_document.id
            ],
        )
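The intended call pattern for the new enricher, sketched with illustrative variable names: one prepare_enrichment() per batch to front-load the DB lookups, then a cheap per-chunk call inside the loop with no DB work:

    enricher = adapter.prepare_enrichment(
        context=context, tenant_id=tenant_id, chunks=chunks
    )
    enriched = [
        enricher.enrich_chunk(chunk, score)
        for chunk, score in zip(chunks_with_embeddings, chunk_content_scores)
    ]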
@@ -1,6 +1,9 @@
from __future__ import annotations

import contextlib
import datetime
import time
from collections import defaultdict
from collections.abc import Generator
from uuid import UUID

@@ -24,7 +27,8 @@ from onyx.db.user_file import fetch_persona_ids_for_user_files
from onyx.db.user_file import fetch_user_project_ids_for_user_files
from onyx.file_store.utils import store_user_file_plaintext
from onyx.indexing.indexing_pipeline import DocumentBatchPrepareContext
-from onyx.indexing.models import BuildMetadataAwareChunksResult
+from onyx.indexing.models import ChunkEnrichmentContext
+from onyx.indexing.models import DocAwareChunk
from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.indexing.models import IndexChunk
from onyx.indexing.models import UpdatableChunkData

@@ -102,13 +106,20 @@ class UserFileIndexingAdapter:
                f"Failed to acquire locks after {_NUM_LOCK_ATTEMPTS} attempts for user files: {[doc.id for doc in documents]}"
            )

-    def build_metadata_aware_chunks(
+    def prepare_enrichment(
        self,
-        chunks_with_embeddings: list[IndexChunk],
-        chunk_content_scores: list[float],
-        tenant_id: str,
        context: DocumentBatchPrepareContext,
-    ) -> BuildMetadataAwareChunksResult:
+        tenant_id: str,
+        chunks: list[DocAwareChunk],
+    ) -> UserFileChunkEnricher:
+        """Do all DB lookups and pre-compute file metadata from chunks."""
+        updatable_ids = [doc.id for doc in context.updatable_docs]
+
+        doc_id_to_new_chunk_cnt: dict[str, int] = defaultdict(int)
+        content_by_file: dict[str, list[str]] = defaultdict(list)
+        for chunk in chunks:
+            doc_id_to_new_chunk_cnt[chunk.source_document.id] += 1
+            content_by_file[chunk.source_document.id].append(chunk.content)

        no_access = DocumentAccess.build(
            user_emails=[],

@@ -118,7 +129,6 @@ class UserFileIndexingAdapter:
            is_public=False,
        )

-        updatable_ids = [doc.id for doc in context.updatable_docs]
        user_file_id_to_project_ids = fetch_user_project_ids_for_user_files(
            user_file_ids=updatable_ids,
            db_session=self.db_session,

@@ -139,17 +149,6 @@ class UserFileIndexingAdapter:
            )
        }

-        user_file_id_to_new_chunk_cnt: dict[str, int] = {
-            user_file_id: len(
-                [
-                    chunk
-                    for chunk in chunks_with_embeddings
-                    if chunk.source_document.id == user_file_id
-                ]
-            )
-            for user_file_id in updatable_ids
-        }
-
        # Initialize tokenizer used for token count calculation
        try:
            llm = get_default_llm()

@@ -164,15 +163,9 @@ class UserFileIndexingAdapter:
        user_file_id_to_raw_text: dict[str, str] = {}
        user_file_id_to_token_count: dict[str, int | None] = {}
        for user_file_id in updatable_ids:
-            user_file_chunks = [
-                chunk
-                for chunk in chunks_with_embeddings
-                if chunk.source_document.id == user_file_id
-            ]
-            if user_file_chunks:
-                combined_content = " ".join(
-                    [chunk.content for chunk in user_file_chunks]
-                )
+            contents = content_by_file.get(user_file_id)
+            if contents:
+                combined_content = " ".join(contents)
                user_file_id_to_raw_text[str(user_file_id)] = combined_content
                token_count: int = (
                    count_tokens(combined_content, llm_tokenizer)

@@ -184,28 +177,16 @@ class UserFileIndexingAdapter:
                user_file_id_to_raw_text[str(user_file_id)] = ""
                user_file_id_to_token_count[str(user_file_id)] = None

-        access_aware_chunks = [
-            DocMetadataAwareIndexChunk.from_index_chunk(
-                index_chunk=chunk,
-                access=user_file_id_to_access.get(chunk.source_document.id, no_access),
-                document_sets=set(),
-                user_project=user_file_id_to_project_ids.get(
-                    chunk.source_document.id, []
-                ),
-                personas=user_file_id_to_persona_ids.get(chunk.source_document.id, []),
-                boost=DEFAULT_BOOST,
-                tenant_id=tenant_id,
-                aggregated_chunk_boost_factor=chunk_content_scores[chunk_num],
-            )
-            for chunk_num, chunk in enumerate(chunks_with_embeddings)
-        ]
-
-        return BuildMetadataAwareChunksResult(
-            chunks=access_aware_chunks,
+        return UserFileChunkEnricher(
            user_file_id_to_access=user_file_id_to_access,
            user_file_id_to_project_ids=user_file_id_to_project_ids,
            user_file_id_to_persona_ids=user_file_id_to_persona_ids,
            doc_id_to_previous_chunk_cnt=user_file_id_to_previous_chunk_cnt,
-            doc_id_to_new_chunk_cnt=user_file_id_to_new_chunk_cnt,
+            doc_id_to_new_chunk_cnt=dict(doc_id_to_new_chunk_cnt),
            user_file_id_to_raw_text=user_file_id_to_raw_text,
            user_file_id_to_token_count=user_file_id_to_token_count,
            no_access=no_access,
            tenant_id=tenant_id,
        )

    def _notify_assistant_owners_if_files_ready(

@@ -249,8 +230,9 @@ class UserFileIndexingAdapter:
        context: DocumentBatchPrepareContext,
        updatable_chunk_data: list[UpdatableChunkData],  # noqa: ARG002
        filtered_documents: list[Document],  # noqa: ARG002
-        result: BuildMetadataAwareChunksResult,
+        enrichment: ChunkEnrichmentContext,
    ) -> None:
+        assert isinstance(enrichment, UserFileChunkEnricher)
        user_file_ids = [doc.id for doc in context.updatable_docs]

        user_files = (

@@ -266,8 +248,10 @@ class UserFileIndexingAdapter:
                user_file.last_project_sync_at = datetime.datetime.now(
                    datetime.timezone.utc
                )
-                user_file.chunk_count = result.doc_id_to_new_chunk_cnt[str(user_file.id)]
-                user_file.token_count = result.user_file_id_to_token_count[
+                user_file.chunk_count = enrichment.doc_id_to_new_chunk_cnt.get(
+                    str(user_file.id), 0
+                )
+                user_file.token_count = enrichment.user_file_id_to_token_count[
                    str(user_file.id)
                ]

@@ -279,8 +263,54 @@ class UserFileIndexingAdapter:
        # Store the plaintext in the file store for faster retrieval
        # NOTE: this creates its own session to avoid committing the overall
        # transaction.
-        for user_file_id, raw_text in result.user_file_id_to_raw_text.items():
+        for user_file_id, raw_text in enrichment.user_file_id_to_raw_text.items():
            store_user_file_plaintext(
                user_file_id=UUID(user_file_id),
                plaintext_content=raw_text,
            )


class UserFileChunkEnricher:
    """Pre-computed metadata for per-chunk enrichment of user-uploaded files."""

    def __init__(
        self,
        user_file_id_to_access: dict[str, DocumentAccess],
        user_file_id_to_project_ids: dict[str, list[int]],
        user_file_id_to_persona_ids: dict[str, list[int]],
        doc_id_to_previous_chunk_cnt: dict[str, int],
        doc_id_to_new_chunk_cnt: dict[str, int],
        user_file_id_to_raw_text: dict[str, str],
        user_file_id_to_token_count: dict[str, int | None],
        no_access: DocumentAccess,
        tenant_id: str,
    ) -> None:
        self._user_file_id_to_access = user_file_id_to_access
        self._user_file_id_to_project_ids = user_file_id_to_project_ids
        self._user_file_id_to_persona_ids = user_file_id_to_persona_ids
        self._no_access = no_access
        self._tenant_id = tenant_id
        self.doc_id_to_previous_chunk_cnt = doc_id_to_previous_chunk_cnt
        self.doc_id_to_new_chunk_cnt = doc_id_to_new_chunk_cnt
        self.user_file_id_to_raw_text = user_file_id_to_raw_text
        self.user_file_id_to_token_count = user_file_id_to_token_count

    def enrich_chunk(
        self, chunk: IndexChunk, score: float
    ) -> DocMetadataAwareIndexChunk:
        return DocMetadataAwareIndexChunk.from_index_chunk(
            index_chunk=chunk,
            access=self._user_file_id_to_access.get(
                chunk.source_document.id, self._no_access
            ),
            document_sets=set(),
            user_project=self._user_file_id_to_project_ids.get(
                chunk.source_document.id, []
            ),
            personas=self._user_file_id_to_persona_ids.get(
                chunk.source_document.id, []
            ),
            boost=DEFAULT_BOOST,
            tenant_id=self._tenant_id,
            aggregated_chunk_boost_factor=score,
        )
backend/onyx/indexing/chunk_batch_store.py  (new file, 89 lines)
@@ -0,0 +1,89 @@
import pickle
import shutil
import tempfile
from collections.abc import Iterator
from pathlib import Path

from onyx.indexing.models import IndexChunk


class ChunkBatchStore:
    """Manages serialization of embedded chunks to a temporary directory.

    Owns the temp directory lifetime and provides save/load/stream/scrub
    operations.

    Use as a context manager to ensure cleanup::

        with ChunkBatchStore() as store:
            store.save(chunks, batch_idx=0)
            for chunk in store.stream():
                ...
    """

    _EXT = ".pkl"

    def __init__(self) -> None:
        self._tmpdir: Path | None = None

    # -- context manager -----------------------------------------------------

    def __enter__(self) -> "ChunkBatchStore":
        self._tmpdir = Path(tempfile.mkdtemp(prefix="onyx_embeddings_"))
        return self

    def __exit__(self, *_exc: object) -> None:
        if self._tmpdir is not None:
            shutil.rmtree(self._tmpdir, ignore_errors=True)
            self._tmpdir = None

    @property
    def _dir(self) -> Path:
        assert self._tmpdir is not None, "ChunkBatchStore used outside context manager"
        return self._tmpdir

    # -- storage primitives --------------------------------------------------

    def save(self, chunks: list[IndexChunk], batch_idx: int) -> None:
        """Serialize a batch of embedded chunks to disk."""
        with open(self._dir / f"batch_{batch_idx}{self._EXT}", "wb") as f:
            pickle.dump(chunks, f)

    def _load(self, batch_file: Path) -> list[IndexChunk]:
        """Deserialize a batch of embedded chunks from a file."""
        with open(batch_file, "rb") as f:
            return pickle.load(f)

    def _batch_files(self) -> list[Path]:
        """Return batch files sorted by numeric index."""
        return sorted(
            self._dir.glob(f"batch_*{self._EXT}"),
            key=lambda p: int(p.stem.removeprefix("batch_")),
        )

    # -- higher-level operations ---------------------------------------------

    def stream(self) -> Iterator[IndexChunk]:
        """Yield all chunks across all batch files.

        Each call returns a fresh generator, so the data can be iterated
        multiple times (e.g. once per document index).
        """
        for batch_file in self._batch_files():
            yield from self._load(batch_file)

    def scrub_failed_docs(self, failed_doc_ids: set[str]) -> None:
        """Remove chunks belonging to *failed_doc_ids* from all batch files.

        When a document fails embedding in batch N, earlier batches may
        already contain successfully embedded chunks for that document.
        This ensures the output is all-or-nothing per document.
        """
        for batch_file in self._batch_files():
            batch_chunks = self._load(batch_file)
            cleaned = [
                c for c in batch_chunks if c.source_document.id not in failed_doc_ids
            ]
            if len(cleaned) != len(batch_chunks):
                with open(batch_file, "wb") as f:
                    pickle.dump(cleaned, f)
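Reviewer note: `stream()` re-reads the pickles on each call, so one store can feed several document indices without holding every embedding in memory. A runnable sketch of the scrub guarantee (the `SimpleNamespace` chunks are stand-ins for `IndexChunk`, used only for their `source_document.id`)::

    from types import SimpleNamespace

    def _chunk(doc_id: str) -> SimpleNamespace:
        # stand-in for IndexChunk; only .source_document.id is read here
        return SimpleNamespace(source_document=SimpleNamespace(id=doc_id))

    with ChunkBatchStore() as store:
        store.save([_chunk("a"), _chunk("b")], batch_idx=0)  # doc "a" embedded fine here
        store.save([_chunk("b")], batch_idx=1)               # doc "a" failed in batch 1
        store.scrub_failed_docs({"a"})                       # strip "a" from batch 0 too
        assert [c.source_document.id for c in store.stream()] == ["b", "b"]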
@@ -1,5 +1,8 @@
from collections import defaultdict
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Iterator
from contextlib import contextmanager
from typing import Protocol

from pydantic import BaseModel

@@ -9,6 +12,7 @@ from sqlalchemy.orm import Session
from onyx.configs.app_configs import DEFAULT_CONTEXTUAL_RAG_LLM_NAME
from onyx.configs.app_configs import DEFAULT_CONTEXTUAL_RAG_LLM_PROVIDER
from onyx.configs.app_configs import ENABLE_CONTEXTUAL_RAG
+from onyx.configs.app_configs import MAX_CHUNKS_PER_DOC_BATCH
from onyx.configs.app_configs import MAX_DOCUMENT_CHARS
from onyx.configs.app_configs import MAX_TOKENS_FOR_FULL_INCLUSION
from onyx.configs.app_configs import USE_CHUNK_SUMMARY

@@ -43,10 +47,12 @@ from onyx.document_index.interfaces import DocumentMetadata
from onyx.document_index.interfaces import IndexBatchParams
from onyx.file_processing.image_summarization import summarize_image_with_error_handling
from onyx.file_store.file_store import get_default_file_store
+from onyx.indexing.chunk_batch_store import ChunkBatchStore
from onyx.indexing.chunker import Chunker
from onyx.indexing.embedder import embed_chunks_with_failure_handling
from onyx.indexing.embedder import IndexingEmbedder
from onyx.indexing.models import DocAwareChunk
from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.indexing.models import IndexingBatchAdapter
from onyx.indexing.models import UpdatableChunkData
from onyx.indexing.vector_db_insertion import write_chunks_to_vector_db_with_backoff

@@ -63,6 +69,7 @@ from onyx.natural_language_processing.utils import tokenizer_trim_middle
from onyx.prompts.contextual_retrieval import CONTEXTUAL_RAG_PROMPT1
from onyx.prompts.contextual_retrieval import CONTEXTUAL_RAG_PROMPT2
from onyx.prompts.contextual_retrieval import DOCUMENT_SUMMARY_PROMPT
+from onyx.utils.batching import batch_generator
from onyx.utils.logger import setup_logger
from onyx.utils.postgres_sanitization import sanitize_documents_for_postgres
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel

@@ -91,6 +98,20 @@ class IndexingPipelineResult(BaseModel):

    failures: list[ConnectorFailure]

    @classmethod
    def empty(cls, total_docs: int) -> "IndexingPipelineResult":
        return cls(
            new_docs=0,
            total_docs=total_docs,
            total_chunks=0,
            failures=[],
        )


class ChunkEmbeddingResult(BaseModel):
    successful_chunk_ids: list[tuple[int, str]]  # (chunk_id, document_id)
    connector_failures: list[ConnectorFailure]


class IndexingPipelineProtocol(Protocol):
    def __call__(

@@ -139,6 +160,110 @@ def _upsert_documents_in_db(
    )


def _get_failed_doc_ids(failures: list[ConnectorFailure]) -> set[str]:
    """Extract document IDs from a list of connector failures."""
    return {f.failed_document.document_id for f in failures if f.failed_document}


def _embed_chunks_to_store(
    chunks: list[DocAwareChunk],
    embedder: IndexingEmbedder,
    tenant_id: str,
    request_id: str | None,
    store: ChunkBatchStore,
) -> ChunkEmbeddingResult:
    """Embed chunks in batches, spilling each batch to *store*.

    If a document fails embedding in any batch, its chunks are excluded from
    all batches (including earlier ones already written) so that the output
    is all-or-nothing per document.
    """
    successful_chunk_ids: list[tuple[int, str]] = []
    all_embedding_failures: list[ConnectorFailure] = []
    # Track failed doc IDs across all batches so that a failure in batch N
    # causes chunks for that doc to be skipped in batch N+1 and stripped
    # from earlier batches.
    all_failed_doc_ids: set[str] = set()

    for batch_idx, chunk_batch in enumerate(
        batch_generator(chunks, MAX_CHUNKS_PER_DOC_BATCH)
    ):
        # Skip chunks belonging to documents that failed in earlier batches.
        chunk_batch = [
            c for c in chunk_batch if c.source_document.id not in all_failed_doc_ids
        ]
        if not chunk_batch:
            continue

        logger.debug(f"Embedding batch {batch_idx}: {len(chunk_batch)} chunks")

        chunks_with_embeddings, embedding_failures = embed_chunks_with_failure_handling(
            chunks=chunk_batch,
            embedder=embedder,
            tenant_id=tenant_id,
            request_id=request_id,
        )
        all_embedding_failures.extend(embedding_failures)
        all_failed_doc_ids.update(_get_failed_doc_ids(embedding_failures))

        # Only keep successfully embedded chunks for non-failed docs.
        chunks_with_embeddings = [
            c
            for c in chunks_with_embeddings
            if c.source_document.id not in all_failed_doc_ids
        ]

        successful_chunk_ids.extend(
            (c.chunk_id, c.source_document.id) for c in chunks_with_embeddings
        )

        store.save(chunks_with_embeddings, batch_idx)
        del chunks_with_embeddings

    # Scrub earlier batches for docs that failed in later batches.
    if all_failed_doc_ids:
        store.scrub_failed_docs(all_failed_doc_ids)
        successful_chunk_ids = [
            (chunk_id, doc_id)
            for chunk_id, doc_id in successful_chunk_ids
            if doc_id not in all_failed_doc_ids
        ]

    return ChunkEmbeddingResult(
        successful_chunk_ids=successful_chunk_ids,
        connector_failures=all_embedding_failures,
    )


@contextmanager
def embed_and_stream(
    chunks: list[DocAwareChunk],
    embedder: IndexingEmbedder,
    tenant_id: str,
    request_id: str | None,
) -> Generator[tuple[ChunkEmbeddingResult, ChunkBatchStore], None, None]:
    """Embed chunks to disk and yield a ``(result, store)`` pair.

    The store owns the temp directory — files are cleaned up when the context
    manager exits.

    Usage::

        with embed_and_stream(chunks, embedder, tenant_id, req_id) as (result, store):
            for chunk in store.stream():
                ...
    """
    with ChunkBatchStore() as store:
        result = _embed_chunks_to_store(
            chunks=chunks,
            embedder=embedder,
            tenant_id=tenant_id,
            request_id=request_id,
            store=store,
        )
        yield result, store
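Reviewer note: only one embedded batch is ever resident; everything else lives on disk until `stream()` replays it. The failure bookkeeping reduces to a self-contained sketch (toy data, not the onyx types)::

    from types import SimpleNamespace

    docs = [SimpleNamespace(id=str(i), ok=(i != 3)) for i in range(10)]
    failed: set[str] = set()
    spilled: list[list[str]] = []
    for start in range(0, len(docs), 4):                 # batch_generator analogue
        batch = [d for d in docs[start : start + 4] if d.id not in failed]
        failed |= {d.id for d in batch if not d.ok}      # embedding failure in this batch
        spilled.append([d.id for d in batch if d.id not in failed])
    # scrub earlier batches so the output stays all-or-nothing per document
    spilled = [[i for i in b if i not in failed] for b in spilled]
    assert all("3" not in b for b in spilled)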
def get_doc_ids_to_update(
    documents: list[Document], db_docs: list[DBDocument]
) -> list[Document]:

@@ -637,6 +762,29 @@ def add_contextual_summaries(
    return chunks


def _verify_indexing_completeness(
    insertion_records: list[DocumentInsertionRecord],
    write_failures: list[ConnectorFailure],
    embedding_failed_doc_ids: set[str],
    updatable_ids: list[str],
    document_index_name: str,
) -> None:
    """Verify that every updatable document was either indexed or reported as failed."""
    all_returned_doc_ids = (
        {r.document_id for r in insertion_records}
        | {f.failed_document.document_id for f in write_failures if f.failed_document}
        | embedding_failed_doc_ids
    )
    if all_returned_doc_ids != set(updatable_ids):
        raise RuntimeError(
            f"Some documents were not successfully indexed. "
            f"Updatable IDs: {updatable_ids}, "
            f"Returned IDs: {all_returned_doc_ids}. "
            f"This should never happen. "
            f"This occurred for document index {document_index_name}"
        )
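Reviewer note: the completeness check is plain set algebra: insertion records, write failures, and embedding failures must together cover exactly the updatable IDs. Illustration with hypothetical IDs::

    updatable = {"a", "b", "c"}
    inserted = {"a"}
    write_failed = {"b"}
    embed_failed = {"c"}
    assert inserted | write_failed | embed_failed == updatable  # nothing silently dropped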
@log_function_time(debug_only=True)
def index_doc_batch(
    *,

@@ -672,12 +820,7 @@ def index_doc_batch(
    filtered_documents = filter_fnc(document_batch)
    context = adapter.prepare(filtered_documents, ignore_time_skip)
    if not context:
-        return IndexingPipelineResult(
-            new_docs=0,
-            total_docs=len(filtered_documents),
-            total_chunks=0,
-            failures=[],
-        )
+        return IndexingPipelineResult.empty(len(filtered_documents))

    # Convert documents to IndexingDocument objects with processed section
    # logger.debug("Processing image sections")

@@ -716,117 +859,99 @@ def index_doc_batch(
    )

    logger.debug("Starting embedding")
-    chunks_with_embeddings, embedding_failures = (
-        embed_chunks_with_failure_handling(
-            chunks=chunks,
-            embedder=embedder,
-            tenant_id=tenant_id,
-            request_id=request_id,
-        )
-        if chunks
-        else ([], [])
-    )
-
-    chunk_content_scores = [1.0] * len(chunks_with_embeddings)
-
-    updatable_ids = [doc.id for doc in context.updatable_docs]
-    updatable_chunk_data = [
-        UpdatableChunkData(
-            chunk_id=chunk.chunk_id,
-            document_id=chunk.source_document.id,
-            boost_score=score,
-        )
-        for chunk, score in zip(chunks_with_embeddings, chunk_content_scores)
-    ]
-
-    # Acquires a lock on the documents so that no other process can modify them
-    # NOTE: don't need to acquire till here, since this is when the actual race condition
-    # with Vespa can occur.
-    with adapter.lock_context(context.updatable_docs):
-        # we're concerned about race conditions where multiple simultaneous indexings might result
-        # in one set of metadata overwriting another one in vespa.
-        # we still write data here for the immediate and most likely correct sync, but
-        # to resolve this, an update of the last modified field at the end of this loop
-        # always triggers a final metadata sync via the celery queue
-        result = adapter.build_metadata_aware_chunks(
-            chunks_with_embeddings=chunks_with_embeddings,
-            chunk_content_scores=chunk_content_scores,
-            tenant_id=tenant_id,
-            context=context,
-        )
-
-        short_descriptor_list = [chunk.to_short_descriptor() for chunk in result.chunks]
-        short_descriptor_log = str(short_descriptor_list)[:1024]
-        logger.debug(f"Indexing the following chunks: {short_descriptor_log}")
-
-        primary_doc_idx_insertion_records: list[DocumentInsertionRecord] | None = None
-        primary_doc_idx_vector_db_write_failures: list[ConnectorFailure] | None = None
-        for document_index in document_indices:
-            # A document will not be spread across different batches, so all the
-            # documents with chunks in this set, are fully represented by the chunks
-            # in this set
-            (
-                insertion_records,
-                vector_db_write_failures,
-            ) = write_chunks_to_vector_db_with_backoff(
-                document_index=document_index,
-                chunks=result.chunks,
-                index_batch_params=IndexBatchParams(
-                    doc_id_to_previous_chunk_cnt=result.doc_id_to_previous_chunk_cnt,
-                    doc_id_to_new_chunk_cnt=result.doc_id_to_new_chunk_cnt,
-                    tenant_id=tenant_id,
-                    large_chunks_enabled=chunker.enable_large_chunks,
-                ),
-            )
-
-            all_returned_doc_ids: set[str] = (
-                {record.document_id for record in insertion_records}
-                .union(
-                    {
-                        record.failed_document.document_id
-                        for record in vector_db_write_failures
-                        if record.failed_document
-                    }
-                )
-                .union(
-                    {
-                        record.failed_document.document_id
-                        for record in embedding_failures
-                        if record.failed_document
-                    }
-                )
-            )
-            if all_returned_doc_ids != set(updatable_ids):
-                raise RuntimeError(
-                    f"Some documents were not successfully indexed. "
-                    f"Updatable IDs: {updatable_ids}, "
-                    f"Returned IDs: {all_returned_doc_ids}. "
-                    "This should never happen."
-                    f"This occured for document index {document_index.__class__.__name__}"
-                )
-            # We treat the first document index we got as the primary one used
-            # for reporting the state of indexing.
-            if primary_doc_idx_insertion_records is None:
-                primary_doc_idx_insertion_records = insertion_records
-            if primary_doc_idx_vector_db_write_failures is None:
-                primary_doc_idx_vector_db_write_failures = vector_db_write_failures
-
-        adapter.post_index(
-            context=context,
-            updatable_chunk_data=updatable_chunk_data,
-            filtered_documents=filtered_documents,
-            result=result,
-        )
+    with embed_and_stream(chunks, embedder, tenant_id, request_id) as (
+        embedding_result,
+        chunk_store,
+    ):
+        updatable_ids = [doc.id for doc in context.updatable_docs]
+        updatable_chunk_data = [
+            UpdatableChunkData(
+                chunk_id=chunk_id,
+                document_id=document_id,
+                boost_score=1.0,
+            )
+            for chunk_id, document_id in embedding_result.successful_chunk_ids
+        ]
+
+        embedding_failed_doc_ids = _get_failed_doc_ids(
+            embedding_result.connector_failures
+        )
+
+        # Filter to only successfully embedded chunks so
+        # doc_id_to_new_chunk_cnt reflects what's actually written to Vespa.
+        embedded_chunks = [
+            c for c in chunks if c.source_document.id not in embedding_failed_doc_ids
+        ]
+
+        # Acquires a lock on the documents so that no other process can modify
+        # them. Not needed until here, since this is when the actual race
+        # condition with vector db can occur.
+        with adapter.lock_context(context.updatable_docs):
+            enricher = adapter.prepare_enrichment(
+                context=context,
+                tenant_id=tenant_id,
+                chunks=embedded_chunks,
+            )
+
+            index_batch_params = IndexBatchParams(
+                doc_id_to_previous_chunk_cnt=enricher.doc_id_to_previous_chunk_cnt,
+                doc_id_to_new_chunk_cnt=enricher.doc_id_to_new_chunk_cnt,
+                tenant_id=tenant_id,
+                large_chunks_enabled=chunker.enable_large_chunks,
+            )
+
+            primary_doc_idx_insertion_records: list[DocumentInsertionRecord] | None = (
+                None
+            )
+            primary_doc_idx_vector_db_write_failures: list[ConnectorFailure] | None = (
+                None
+            )
+
+            for document_index in document_indices:
+
+                def _enriched_stream() -> Iterator[DocMetadataAwareIndexChunk]:
+                    for chunk in chunk_store.stream():
+                        yield enricher.enrich_chunk(chunk, 1.0)
+
+                insertion_records, write_failures = (
+                    write_chunks_to_vector_db_with_backoff(
+                        document_index=document_index,
+                        make_chunks=_enriched_stream,
+                        index_batch_params=index_batch_params,
+                    )
+                )
+
+                _verify_indexing_completeness(
+                    insertion_records=insertion_records,
+                    write_failures=write_failures,
+                    embedding_failed_doc_ids=embedding_failed_doc_ids,
+                    updatable_ids=updatable_ids,
+                    document_index_name=document_index.__class__.__name__,
+                )
+                # We treat the first document index we got as the primary one used
+                # for reporting the state of indexing.
+                if primary_doc_idx_insertion_records is None:
+                    primary_doc_idx_insertion_records = insertion_records
+                if primary_doc_idx_vector_db_write_failures is None:
+                    primary_doc_idx_vector_db_write_failures = write_failures
+
+            adapter.post_index(
+                context=context,
+                updatable_chunk_data=updatable_chunk_data,
+                filtered_documents=filtered_documents,
+                enrichment=enricher,
+            )

    assert primary_doc_idx_insertion_records is not None
    assert primary_doc_idx_vector_db_write_failures is not None
    return IndexingPipelineResult(
-        new_docs=len(
-            [r for r in primary_doc_idx_insertion_records if not r.already_existed]
+        new_docs=sum(
+            1 for r in primary_doc_idx_insertion_records if not r.already_existed
        ),
        total_docs=len(filtered_documents),
-        total_chunks=len(chunks_with_embeddings),
-        failures=primary_doc_idx_vector_db_write_failures + embedding_failures,
+        total_chunks=len(embedding_result.successful_chunk_ids),
+        failures=primary_doc_idx_vector_db_write_failures
+        + embedding_result.connector_failures,
    )
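Reviewer note: `_enriched_stream` is defined inside the loop so each document index gets a fresh generator; passing a zero-arg callable (rather than an exhausted iterator) is what lets the write path restart iteration. A self-contained sketch of that pattern::

    from collections.abc import Callable, Iterator

    def make_numbers() -> Iterator[int]:
        yield from range(3)

    def consume_twice(make_items: Callable[[], Iterator[int]]) -> None:
        assert list(make_items()) == [0, 1, 2]  # bulk attempt
        assert list(make_items()) == [0, 1, 2]  # per-document retry re-iterates

    consume_twice(make_numbers)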
@@ -235,12 +235,16 @@ class UpdatableChunkData(BaseModel):
    boost_score: float


-class BuildMetadataAwareChunksResult(BaseModel):
-    chunks: list[DocMetadataAwareIndexChunk]
+class ChunkEnrichmentContext(Protocol):
+    """Returned by prepare_enrichment. Holds pre-computed metadata lookups
+    and provides per-chunk enrichment."""
+
    doc_id_to_previous_chunk_cnt: dict[str, int]
    doc_id_to_new_chunk_cnt: dict[str, int]
-    user_file_id_to_raw_text: dict[str, str]
-    user_file_id_to_token_count: dict[str, int | None]
+
+    def enrich_chunk(
+        self, chunk: IndexChunk, score: float
+    ) -> DocMetadataAwareIndexChunk: ...


class IndexingBatchAdapter(Protocol):

@@ -254,18 +258,24 @@ class IndexingBatchAdapter(Protocol):
    ) -> Generator[TransactionalContext, None, None]:
        """Provide a transaction/row-lock context for critical updates."""

-    def build_metadata_aware_chunks(
+    def prepare_enrichment(
        self,
-        chunks_with_embeddings: list[IndexChunk],
-        chunk_content_scores: list[float],
-        tenant_id: str,
        context: "DocumentBatchPrepareContext",
-    ) -> BuildMetadataAwareChunksResult: ...
+        tenant_id: str,
+        chunks: list[DocAwareChunk],
+    ) -> ChunkEnrichmentContext:
+        """Prepare per-chunk enrichment data (access, document sets, boost, etc.).
+
+        Precondition: ``chunks`` have already been through the embedding step
+        (i.e. they are ``IndexChunk`` instances with populated embeddings,
+        passed here as the base ``DocAwareChunk`` type).
+        """
+        ...

    def post_index(
        self,
        context: "DocumentBatchPrepareContext",
        updatable_chunk_data: list[UpdatableChunkData],
        filtered_documents: list[Document],
-        result: BuildMetadataAwareChunksResult,
+        enrichment: ChunkEnrichmentContext,
    ) -> None: ...
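Reviewer note: because `ChunkEnrichmentContext` is a `Protocol`, both enricher classes satisfy it structurally, with no shared base class. A toy demonstration of the same contract (not the onyx models)::

    from typing import Protocol

    class Enricher(Protocol):
        doc_id_to_new_chunk_cnt: dict[str, int]

        def enrich_chunk(self, chunk: str, score: float) -> str: ...

    class ToyEnricher:
        def __init__(self) -> None:
            self.doc_id_to_new_chunk_cnt = {"doc": 1}

        def enrich_chunk(self, chunk: str, score: float) -> str:
            return f"{chunk}@{score}"

    def run(enricher: Enricher) -> str:  # accepts any structurally matching type
        return enricher.enrich_chunk("c", 1.0)

    assert run(ToyEnricher()) == "c@1.0"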
@@ -1,6 +1,9 @@
import time
from collections import defaultdict
from collections.abc import Callable
from collections.abc import Iterable
from http import HTTPStatus
from itertools import chain
from itertools import groupby

import httpx

@@ -28,22 +31,22 @@ def _log_insufficient_storage_error(e: Exception) -> None:

def write_chunks_to_vector_db_with_backoff(
    document_index: DocumentIndex,
-    chunks: list[DocMetadataAwareIndexChunk],
+    make_chunks: Callable[[], Iterable[DocMetadataAwareIndexChunk]],
    index_batch_params: IndexBatchParams,
) -> tuple[list[DocumentInsertionRecord], list[ConnectorFailure]]:
    """Tries to insert all chunks in one large batch. If that batch fails for any reason,
    goes document by document to isolate the failure(s).

    IMPORTANT: must pass in whole documents at a time not individual chunks, since the
-    vector DB interface assumes that all chunks for a single document are present.
+    vector DB interface assumes that all chunks for a single document are present. The
+    chunks must also be in contiguous batches.
    """

    # first try to write the chunks to the vector db
    try:
        return (
            list(
                document_index.index(
-                    chunks=chunks,
+                    chunks=make_chunks(),
                    index_batch_params=index_batch_params,
                )
            ),

@@ -60,14 +63,23 @@ def write_chunks_to_vector_db_with_backoff(
    # wait a couple seconds just to give the vector db a chance to recover
    time.sleep(2)

    # try writing each doc one by one
-    chunks_for_docs: dict[str, list[DocMetadataAwareIndexChunk]] = defaultdict(list)
-    for chunk in chunks:
-        chunks_for_docs[chunk.source_document.id].append(chunk)
-
    insertion_records: list[DocumentInsertionRecord] = []
    failures: list[ConnectorFailure] = []
-    for doc_id, chunks_for_doc in chunks_for_docs.items():
+
+    def key(chunk: DocMetadataAwareIndexChunk) -> str:
+        return chunk.source_document.id
+
+    seen_doc_ids: set[str] = set()
+    for doc_id, chunks_for_doc in groupby(make_chunks(), key=key):
+        if doc_id in seen_doc_ids:
+            raise RuntimeError(
+                f"Doc chunks are not arriving in order. Current doc_id={doc_id}, seen_doc_ids={list(seen_doc_ids)}"
+            )
+        seen_doc_ids.add(doc_id)
+
+        first_chunk = next(chunks_for_doc)
+        chunks_for_doc = chain([first_chunk], chunks_for_doc)
+
        try:
            insertion_records.extend(
                document_index.index(

@@ -87,9 +99,7 @@ def write_chunks_to_vector_db_with_backoff(
                ConnectorFailure(
                    failed_document=DocumentFailure(
                        document_id=doc_id,
-                        document_link=(
-                            chunks_for_doc[0].get_link() if chunks_for_doc else None
-                        ),
+                        document_link=first_chunk.get_link(),
                    ),
                    failure_message=str(e),
                    exception=e,
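Reviewer note: `itertools.groupby` only merges adjacent equal keys, which is why the docstring now demands contiguous per-document chunks and why `seen_doc_ids` turns out-of-order input into a hard error. Quick illustration::

    from itertools import groupby

    contiguous = ["a", "a", "b"]
    scattered = ["a", "b", "a"]

    assert [k for k, _ in groupby(contiguous)] == ["a", "b"]
    # a scattered stream yields "a" twice: exactly the case the RuntimeError catches
    assert [k for k, _ in groupby(scattered)] == ["a", "b", "a"]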
@@ -77,7 +77,6 @@ from onyx.server.features.default_assistant.api import (
)
from onyx.server.features.document_set.api import router as document_set_router
from onyx.server.features.hierarchy.api import router as hierarchy_router
-from onyx.server.features.hooks.api import router as hook_router
from onyx.server.features.input_prompt.api import (
    admin_router as admin_input_prompt_router,
)

@@ -439,6 +438,7 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
            dsn=SENTRY_DSN,
            integrations=[StarletteIntegration(), FastApiIntegration()],
            traces_sample_rate=0.1,
+            release=__version__,
        )
        logger.info("Sentry initialized")
    else:

@@ -454,7 +454,6 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:

    register_onyx_exception_handlers(application)

-    include_router_with_global_prefix_prepended(application, hook_router)
    include_router_with_global_prefix_prepended(application, password_router)
    include_router_with_global_prefix_prepended(application, chat_router)
    include_router_with_global_prefix_prepended(application, query_router)
@@ -76,11 +76,18 @@ class CategorizedFiles(BaseModel):
    acceptable: list[UploadFile] = Field(default_factory=list)
    rejected: list[RejectedFile] = Field(default_factory=list)
    acceptable_file_to_token_count: dict[str, int] = Field(default_factory=dict)
+    # Filenames within `acceptable` that should be stored but not indexed.
+    skip_indexing: set[str] = Field(default_factory=set)

    # Allow FastAPI UploadFile instances
    model_config = ConfigDict(arbitrary_types_allowed=True)


+def _skip_token_threshold(extension: str) -> bool:
+    """Return True if this file extension should bypass the token limit."""
+    return extension.lower() in OnyxFileExtensions.TABULAR_EXTENSIONS


def _apply_long_side_cap(width: int, height: int, cap: int) -> tuple[int, int]:
    if max(width, height) <= cap:
        return width, height

@@ -264,7 +271,17 @@ def categorize_uploaded_files(
            token_count = count_tokens(
                text_content, tokenizer, token_limit=token_threshold
            )
-            if token_threshold is not None and token_count > token_threshold:
+            exceeds_threshold = (
+                token_threshold is not None and token_count > token_threshold
+            )
+            if exceeds_threshold and _skip_token_threshold(extension):
+                # Exempt extensions (e.g. spreadsheets) are accepted
+                # but flagged to skip indexing — only metadata is
+                # injected into the LLM context.
+                results.acceptable.append(upload)
+                results.acceptable_file_to_token_count[filename] = token_count
+                results.skip_indexing.add(filename)
+            elif exceeds_threshold:
                results.rejected.append(
                    RejectedFile(
                        filename=filename,
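Reviewer note: branch order matters here; exempt tabular extensions take the skip-indexing path before the generic rejection fires. A reduced sketch of the decision table (toy constants, not the onyx config)::

    TABULAR = {".csv", ".xlsx"}

    def decide(extension: str, token_count: int, threshold: int | None) -> str:
        exceeds = threshold is not None and token_count > threshold
        if exceeds and extension.lower() in TABULAR:
            return "accept_skip_indexing"  # stored; metadata-only in LLM context
        if exceeds:
            return "reject"
        return "accept_and_index"

    assert decide(".csv", 2000, 1000) == "accept_skip_indexing"
    assert decide(".txt", 2000, 1000) == "reject"
    assert decide(".txt", 10, 1000) == "accept_and_index"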
@@ -9,8 +9,8 @@ def mime_type_to_chat_file_type(mime_type: str | None) -> ChatFileType:
    if mime_type in OnyxMimeTypes.IMAGE_MIME_TYPES:
        return ChatFileType.IMAGE

-    if mime_type in OnyxMimeTypes.CSV_MIME_TYPES:
-        return ChatFileType.CSV
+    if mime_type in OnyxMimeTypes.TABULAR_MIME_TYPES:
+        return ChatFileType.TABULAR

    if mime_type in OnyxMimeTypes.DOCUMENT_MIME_TYPES:
        return ChatFileType.DOC
@@ -21,7 +21,6 @@ from onyx.db.notification import get_notifications
from onyx.db.notification import update_notification_last_shown
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
-from onyx.hooks.utils import HOOKS_AVAILABLE
from onyx.key_value_store.factory import get_kv_store
from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.server.features.build.utils import is_onyx_craft_enabled

@@ -38,6 +37,7 @@ from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import (
    fetch_versioned_implementation_with_fallback,
)
+from shared_configs.configs import MULTI_TENANT

logger = setup_logger()

@@ -98,7 +98,7 @@ def fetch_settings(
        needs_reindexing=needs_reindexing,
        onyx_craft_enabled=onyx_craft_enabled_for_user,
        vector_db_enabled=not DISABLE_VECTOR_DB,
-        hooks_enabled=HOOKS_AVAILABLE,
+        hooks_enabled=not MULTI_TENANT,
        version=onyx_version,
        max_allowed_upload_size_mb=MAX_ALLOWED_UPLOAD_SIZE_MB,
        default_user_file_max_upload_size_mb=min(
@@ -116,7 +116,7 @@ class UserSettings(Settings):
    # False when DISABLE_VECTOR_DB is set — connectors, RAG search, and
    # document sets are unavailable.
    vector_db_enabled: bool = True
-    # True when hooks are available: single-tenant deployment with HOOK_ENABLED=true.
+    # True when hooks are available: single-tenant EE deployments only.
    hooks_enabled: bool = False
    # Application version, read from the ONYX_VERSION env var at startup.
    version: str | None = None
@@ -1,3 +1,4 @@
+import io
import json
from typing import Any
from typing import cast

@@ -9,6 +10,7 @@ from typing_extensions import override
from onyx.chat.emitter import Emitter
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.db.engine.sql_engine import get_session_with_current_tenant
+from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_store.models import ChatFileType
from onyx.file_store.models import InMemoryChatFile
from onyx.file_store.utils import load_chat_file_by_id

@@ -169,10 +171,13 @@ class FileReaderTool(Tool[FileReaderToolOverrideKwargs]):

        chat_file = self._load_file(file_id)

-        # Only PLAIN_TEXT and CSV are guaranteed to contain actual text bytes.
+        # Only PLAIN_TEXT and TABULAR are guaranteed to contain actual text bytes.
        # DOC type in a loaded file means plaintext extraction failed and the
        # content is the original binary (e.g. raw PDF/DOCX bytes).
-        if chat_file.file_type not in (ChatFileType.PLAIN_TEXT, ChatFileType.CSV):
+        if chat_file.file_type not in (
+            ChatFileType.PLAIN_TEXT,
+            ChatFileType.TABULAR,
+        ):
            raise ToolCallException(
                message=f"File {file_id} is not a text file (type={chat_file.file_type})",
                llm_facing_message=(

@@ -181,7 +186,19 @@ class FileReaderTool(Tool[FileReaderToolOverrideKwargs]):
        )

        try:
-            full_text = chat_file.content.decode("utf-8", errors="replace")
+            if chat_file.file_type == ChatFileType.PLAIN_TEXT:
+                full_text = chat_file.content.decode("utf-8", errors="replace")
+            else:
+                full_text = (
+                    extract_file_text(
+                        file=io.BytesIO(chat_file.content),
+                        file_name=chat_file.filename or "",
+                        break_on_unprocessable=False,
+                    )
+                    or ""
+                )
+        except ToolCallException:
+            raise
        except Exception:
            raise ToolCallException(
                message=f"Failed to decode file {file_id}",
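Reviewer note: plain text is decoded straight from bytes, while tabular files now go through full text extraction so binary spreadsheet formats still yield text. A minimal sketch of the branching, with a stand-in extractor in place of `extract_file_text`::

    import io

    def extract_stub(buf: io.BytesIO) -> str:
        return buf.read().decode("utf-8", errors="replace")  # placeholder extractor

    def read_text(kind: str, content: bytes) -> str:
        if kind == "PLAIN_TEXT":
            return content.decode("utf-8", errors="replace")
        return extract_stub(io.BytesIO(content)) or ""  # tabular: extract, don't decode

    assert read_text("PLAIN_TEXT", b"a,b") == "a,b"
    assert read_text("TABULAR", b"a,b") == "a,b"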
@@ -271,7 +271,7 @@ fastapi-users-db-sqlalchemy==7.0.0
    # via onyx
fastavro==1.12.1
    # via cohere
-fastmcp==3.0.2
+fastmcp==3.2.0
    # via onyx
fastuuid==0.14.0
    # via litellm

@@ -735,7 +735,7 @@ pyee==13.0.0
    # via playwright
pygithub==2.5.0
    # via onyx
-pygments==2.19.2
+pygments==2.20.0
    # via rich
pyjwt==2.12.0
    # via

@@ -1102,6 +1102,8 @@ tzdata==2025.2
    #   tzlocal
tzlocal==5.3.1
    # via dateparser
+uncalled-for==0.2.0
+    # via fastmcp
unstructured==0.18.27
    # via onyx
unstructured-client==0.42.6

@@ -349,7 +349,7 @@ pydantic-core==2.33.2
    # via pydantic
pydantic-settings==2.12.0
    # via mcp
-pygments==2.19.2
+pygments==2.20.0
    # via
    #   ipython
    #   ipython-pygments-lexers
@@ -5,6 +5,7 @@ import asyncio
import json
import logging
+import sys
import time
from dataclasses import asdict
from dataclasses import dataclass
from pathlib import Path

@@ -27,6 +28,9 @@ INTERNAL_SEARCH_TOOL_NAME = "internal_search"
INTERNAL_SEARCH_IN_CODE_TOOL_ID = "SearchTool"
MAX_REQUEST_ATTEMPTS = 5
RETRIABLE_STATUS_CODES = {429, 500, 502, 503, 504}
+QUESTION_TIMEOUT_SECONDS = 300
+QUESTION_RETRY_PAUSE_SECONDS = 30
+MAX_QUESTION_ATTEMPTS = 3


@dataclass(frozen=True)

@@ -109,6 +113,27 @@ def normalize_api_base(api_base: str) -> str:
    return f"{normalized}/api"


def load_completed_question_ids(output_file: Path) -> set[str]:
    if not output_file.exists():
        return set()

    completed_ids: set[str] = set()
    with output_file.open("r", encoding="utf-8") as file:
        for line in file:
            stripped = line.strip()
            if not stripped:
                continue
            try:
                record = json.loads(stripped)
            except json.JSONDecodeError:
                continue
            question_id = record.get("question_id")
            if isinstance(question_id, str) and question_id:
                completed_ids.add(question_id)

    return completed_ids
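Reviewer note: resume support leans on the output being append-only JSONL, so completed IDs survive a crash even with a torn final line. Self-contained sketch::

    import json
    import tempfile
    from pathlib import Path

    out = Path(tempfile.mkdtemp()) / "answers.jsonl"
    out.write_text('{"question_id": "q1"}\n{"question_id": "q2"}\nnot-json\n')

    done: set[str] = set()
    for line in out.read_text().splitlines():
        try:
            rec = json.loads(line)
        except json.JSONDecodeError:
            continue  # tolerate a torn final line
        qid = rec.get("question_id")
        if isinstance(qid, str) and qid:
            done.add(qid)

    assert done == {"q1", "q2"}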
def load_questions(questions_file: Path) -> list[QuestionRecord]:
    if not questions_file.exists():
        raise FileNotFoundError(f"Questions file not found: {questions_file}")

@@ -348,6 +373,7 @@ async def generate_answers(
    api_base: str,
    api_key: str,
    parallelism: int,
+    skipped: int,
) -> None:
    if parallelism < 1:
        raise ValueError("`--parallelism` must be at least 1.")

@@ -382,58 +408,178 @@ async def generate_answers(
    write_lock = asyncio.Lock()
    completed = 0
    successful = 0
+    stuck_count = 0
    failed_questions: list[FailedQuestionRecord] = []
-    total = len(questions)
+    remaining_count = len(questions)
+    overall_total = remaining_count + skipped
+    question_durations: list[float] = []
+    run_start_time = time.monotonic()
+
+    def print_progress() -> None:
+        avg_time = (
+            sum(question_durations) / len(question_durations)
+            if question_durations
+            else 0.0
+        )
+        elapsed = time.monotonic() - run_start_time
+        eta = avg_time * (remaining_count - completed) / max(parallelism, 1)
+
+        done = skipped + completed
+        bar_width = 30
+        filled = (
+            int(bar_width * done / overall_total)
+            if overall_total
+            else bar_width
+        )
+        bar = "█" * filled + "░" * (bar_width - filled)
+        pct = (done / overall_total * 100) if overall_total else 100.0
+
+        parts = (
+            f"\r{bar} {pct:5.1f}% "
+            f"[{done}/{overall_total}] "
+            f"avg {avg_time:.1f}s/q "
+            f"elapsed {elapsed:.0f}s "
+            f"ETA {eta:.0f}s "
+            f"(ok:{successful} fail:{len(failed_questions)}"
+        )
+        if stuck_count:
+            parts += f" stuck:{stuck_count}"
+        if skipped:
+            parts += f" skip:{skipped}"
+        parts += ")"
+
+        sys.stderr.write(parts)
+        sys.stderr.flush()
+
+    print_progress()

    async def process_question(question_record: QuestionRecord) -> None:
        nonlocal completed
        nonlocal successful
+        nonlocal stuck_count

-        try:
-            async with semaphore:
-                result = await submit_question(
-                    session=session,
-                    api_base=api_base,
-                    headers=headers,
-                    internal_search_tool_id=internal_search_tool_id,
-                    question_record=question_record,
-                )
-        except Exception as exc:
-            async with progress_lock:
-                completed += 1
-                failed_questions.append(
-                    FailedQuestionRecord(
-                        question_id=question_record.question_id,
-                        error=str(exc),
-                    )
-                )
-                logger.exception(
-                    "Failed question %s (%s/%s)",
-                    question_record.question_id,
-                    completed,
-                    total,
-                )
-            return
-
-        async with write_lock:
-            file.write(json.dumps(asdict(result), ensure_ascii=False))
-            file.write("\n")
-            file.flush()
-
-        async with progress_lock:
-            completed += 1
-            successful += 1
-            logger.info("Processed %s/%s questions", completed, total)
+        last_error: Exception | None = None
+        for attempt in range(1, MAX_QUESTION_ATTEMPTS + 1):
+            q_start = time.monotonic()
+            try:
+                async with semaphore:
+                    result = await asyncio.wait_for(
+                        submit_question(
+                            session=session,
+                            api_base=api_base,
+                            headers=headers,
+                            internal_search_tool_id=internal_search_tool_id,
+                            question_record=question_record,
+                        ),
+                        timeout=QUESTION_TIMEOUT_SECONDS,
+                    )
+            except asyncio.TimeoutError:
+                async with progress_lock:
+                    stuck_count += 1
+                    logger.warning(
+                        "Question %s timed out after %ss (attempt %s/%s, "
+                        "total stuck: %s) — retrying in %ss",
+                        question_record.question_id,
+                        QUESTION_TIMEOUT_SECONDS,
+                        attempt,
+                        MAX_QUESTION_ATTEMPTS,
+                        stuck_count,
+                        QUESTION_RETRY_PAUSE_SECONDS,
+                    )
+                    print_progress()
+                last_error = TimeoutError(
+                    f"Timed out after {QUESTION_TIMEOUT_SECONDS}s "
+                    f"on attempt {attempt}/{MAX_QUESTION_ATTEMPTS}"
+                )
+                await asyncio.sleep(QUESTION_RETRY_PAUSE_SECONDS)
+                continue
+            except Exception as exc:
+                duration = time.monotonic() - q_start
+                async with progress_lock:
+                    completed += 1
+                    question_durations.append(duration)
+                    failed_questions.append(
+                        FailedQuestionRecord(
+                            question_id=question_record.question_id,
+                            error=str(exc),
+                        )
+                    )
+                    logger.exception(
+                        "Failed question %s (%s/%s)",
+                        question_record.question_id,
+                        completed,
+                        remaining_count,
+                    )
+                    print_progress()
+                return
+
+            duration = time.monotonic() - q_start
+
+            async with write_lock:
+                file.write(json.dumps(asdict(result), ensure_ascii=False))
+                file.write("\n")
+                file.flush()
+
+            async with progress_lock:
+                completed += 1
+                successful += 1
+                question_durations.append(duration)
+                print_progress()
+            return
+
+        # All attempts exhausted due to timeouts
+        async with progress_lock:
+            completed += 1
+            failed_questions.append(
+                FailedQuestionRecord(
+                    question_id=question_record.question_id,
+                    error=str(last_error),
+                )
+            )
+            logger.error(
+                "Question %s failed after %s timeout attempts (%s/%s)",
+                question_record.question_id,
+                MAX_QUESTION_ATTEMPTS,
+                completed,
+                remaining_count,
+            )
+            print_progress()
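Reviewer note: wrapping the awaited call in `asyncio.wait_for` inside the semaphore means a hung request frees its slot at the timeout instead of pinning a worker for the whole run. Self-contained sketch of the retry loop::

    import asyncio

    async def flaky(attempts_needed: int, state: dict) -> str:
        state["tries"] = state.get("tries", 0) + 1
        if state["tries"] < attempts_needed:
            await asyncio.sleep(10)  # simulates a hung request
        return "ok"

    async def run() -> str:
        state: dict = {}
        for _ in range(3):  # MAX_QUESTION_ATTEMPTS analogue
            try:
                return await asyncio.wait_for(flaky(2, state), timeout=0.05)
            except asyncio.TimeoutError:
                continue  # the real script pauses before retrying
        return "gave up"

    assert asyncio.run(run()) == "ok"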
    await asyncio.gather(
        *(process_question(question_record) for question_record in questions)
    )

    # Final newline after progress bar
    sys.stderr.write("\n")
    sys.stderr.flush()

    total_elapsed = time.monotonic() - run_start_time
    avg_time = (
        sum(question_durations) / len(question_durations)
        if question_durations
        else 0.0
    )
    stuck_suffix = f", {stuck_count} stuck timeouts" if stuck_count else ""
    resume_suffix = (
        f" — {skipped} previously completed, "
        f"{skipped + successful}/{overall_total} overall"
        if skipped
        else ""
    )
    logger.info(
        "Done: %s/%s successful in %.1fs (avg %.1fs/question%s)%s",
        successful,
        remaining_count,
        total_elapsed,
        avg_time,
        stuck_suffix,
        resume_suffix,
    )

    if failed_questions:
        logger.warning(
-            "Completed with %s failed questions and %s successful questions.",
+            "%s questions failed:",
            len(failed_questions),
-            successful,
        )
        for failed_question in failed_questions:
            logger.warning(

@@ -453,7 +599,30 @@ def main() -> None:
        raise ValueError("`--max-questions` must be at least 1 when provided.")
    questions = questions[: args.max_questions]

-    logger.info("Loaded %s questions from %s", len(questions), args.questions_file)
+    completed_ids = load_completed_question_ids(args.output_file)
+    logger.info(
+        "Found %s already-answered question IDs in %s",
+        len(completed_ids),
+        args.output_file,
+    )
+    total_before_filter = len(questions)
+    questions = [q for q in questions if q.question_id not in completed_ids]
+    skipped = total_before_filter - len(questions)
+
+    if skipped:
+        logger.info(
+            "Resuming: %s/%s already answered, %s remaining",
+            skipped,
+            total_before_filter,
+            len(questions),
+        )
+    else:
+        logger.info("Loaded %s questions from %s", len(questions), args.questions_file)
+
+    if not questions:
+        logger.info("All questions already answered. Nothing to do.")
+        return

    logger.info("Writing answers to %s", args.output_file)

    asyncio.run(

@@ -463,6 +632,7 @@ def main() -> None:
            api_base=api_base,
            api_key=args.api_key,
            parallelism=args.parallelism,
+            skipped=skipped,
        )
    )
@@ -1,7 +1,7 @@
"""
External dependency unit tests for UserFileIndexingAdapter metadata writing.

-Validates that build_metadata_aware_chunks produces DocMetadataAwareIndexChunk
+Validates that prepare_enrichment produces DocMetadataAwareIndexChunk
objects with both `user_project` and `personas` fields populated correctly
based on actual DB associations.

@@ -127,7 +127,7 @@ def _make_index_chunk(user_file: UserFile) -> IndexChunk:


class TestAdapterWritesBothMetadataFields:
-    """build_metadata_aware_chunks must populate user_project AND personas."""
+    """prepare_enrichment must populate user_project AND personas."""

    @patch(
        "onyx.indexing.adapters.user_file_indexing_adapter.get_default_llm",

@@ -153,15 +153,13 @@ class TestAdapterWritesBothMetadataFields:
        doc = chunk.source_document
        context = DocumentBatchPrepareContext(updatable_docs=[doc], id_to_boost_map={})

-        result = adapter.build_metadata_aware_chunks(
-            chunks_with_embeddings=[chunk],
-            chunk_content_scores=[1.0],
-            tenant_id=TEST_TENANT_ID,
+        enricher = adapter.prepare_enrichment(
            context=context,
+            tenant_id=TEST_TENANT_ID,
+            chunks=[chunk],
        )
+        aware_chunk = enricher.enrich_chunk(chunk, 1.0)

-        assert len(result.chunks) == 1
-        aware_chunk = result.chunks[0]
        assert persona.id in aware_chunk.personas
        assert aware_chunk.user_project == []

@@ -190,15 +188,13 @@ class TestAdapterWritesBothMetadataFields:
            updatable_docs=[chunk.source_document], id_to_boost_map={}
        )

-        result = adapter.build_metadata_aware_chunks(
-            chunks_with_embeddings=[chunk],
-            chunk_content_scores=[1.0],
-            tenant_id=TEST_TENANT_ID,
+        enricher = adapter.prepare_enrichment(
            context=context,
+            tenant_id=TEST_TENANT_ID,
+            chunks=[chunk],
        )
+        aware_chunk = enricher.enrich_chunk(chunk, 1.0)

-        assert len(result.chunks) == 1
-        aware_chunk = result.chunks[0]
        assert project.id in aware_chunk.user_project
        assert aware_chunk.personas == []

@@ -229,14 +225,13 @@ class TestAdapterWritesBothMetadataFields:
            updatable_docs=[chunk.source_document], id_to_boost_map={}
        )

-        result = adapter.build_metadata_aware_chunks(
-            chunks_with_embeddings=[chunk],
-            chunk_content_scores=[1.0],
-            tenant_id=TEST_TENANT_ID,
+        enricher = adapter.prepare_enrichment(
            context=context,
+            tenant_id=TEST_TENANT_ID,
+            chunks=[chunk],
        )
+        aware_chunk = enricher.enrich_chunk(chunk, 1.0)

-        aware_chunk = result.chunks[0]
        assert persona.id in aware_chunk.personas
        assert project.id in aware_chunk.user_project

@@ -261,14 +256,13 @@ class TestAdapterWritesBothMetadataFields:
            updatable_docs=[chunk.source_document], id_to_boost_map={}
        )

-        result = adapter.build_metadata_aware_chunks(
-            chunks_with_embeddings=[chunk],
-            chunk_content_scores=[1.0],
-            tenant_id=TEST_TENANT_ID,
+        enricher = adapter.prepare_enrichment(
            context=context,
+            tenant_id=TEST_TENANT_ID,
+            chunks=[chunk],
        )
+        aware_chunk = enricher.enrich_chunk(chunk, 1.0)

-        aware_chunk = result.chunks[0]
        assert aware_chunk.personas == []
        assert aware_chunk.user_project == []

@@ -300,12 +294,11 @@ class TestAdapterWritesBothMetadataFields:
            updatable_docs=[chunk.source_document], id_to_boost_map={}
        )

-        result = adapter.build_metadata_aware_chunks(
-            chunks_with_embeddings=[chunk],
-            chunk_content_scores=[1.0],
-            tenant_id=TEST_TENANT_ID,
+        enricher = adapter.prepare_enrichment(
            context=context,
+            tenant_id=TEST_TENANT_ID,
+            chunks=[chunk],
        )
+        aware_chunk = enricher.enrich_chunk(chunk, 1.0)

-        aware_chunk = result.chunks[0]
        assert set(aware_chunk.personas) == {persona_a.id, persona_b.id}
@@ -1175,7 +1175,7 @@ def test_code_interpreter_receives_chat_files(

    file_descriptor: FileDescriptor = {
        "id": user_file.file_id,
-        "type": ChatFileType.CSV,
+        "type": ChatFileType.TABULAR,
        "name": "data.csv",
        "user_file_id": str(user_file.id),
    }
@@ -1,3 +1,9 @@
+import mimetypes
+from typing import Any
+
+import requests
+
 from tests.integration.common_utils.constants import API_SERVER_URL
 from tests.integration.common_utils.managers.chat import ChatSessionManager
 from tests.integration.common_utils.managers.file import FileManager
 from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
@@ -79,3 +85,90 @@ def test_send_message_with_text_file_attachment(admin_user: DATestUser) -> None:
     assert (
         "third line" in response.full_message.lower()
     ), "Chat response should contain the contents of the file"
+
+
+def _set_token_threshold(admin_user: DATestUser, threshold_k: int) -> None:
+    """Set the file token count threshold via admin settings API."""
+    response = requests.put(
+        f"{API_SERVER_URL}/admin/settings",
+        json={"file_token_count_threshold_k": threshold_k},
+        headers=admin_user.headers,
+    )
+    response.raise_for_status()
+
+
+def _upload_raw(
+    filename: str,
+    content: bytes,
+    user: DATestUser,
+) -> dict[str, Any]:
+    """Upload a file and return the full JSON response (user_files + rejected_files)."""
+    mime_type, _ = mimetypes.guess_type(filename)
+    headers = user.headers.copy()
+    headers.pop("Content-Type", None)
+
+    response = requests.post(
+        f"{API_SERVER_URL}/user/projects/file/upload",
+        files=[("files", (filename, content, mime_type or "application/octet-stream"))],
+        headers=headers,
+    )
+    response.raise_for_status()
+    return response.json()
+
+
+def test_csv_over_token_threshold_uploaded_not_indexed(
+    admin_user: DATestUser,
+) -> None:
+    """CSV exceeding token threshold is uploaded (accepted) but skips indexing."""
+    _set_token_threshold(admin_user, threshold_k=1)
+    try:
+        # ~2000 tokens with default tokenizer, well over 1K threshold
+        content = ("x " * 100 + "\n") * 20
+        result = _upload_raw("large.csv", content.encode(), admin_user)
+
+        assert len(result["user_files"]) == 1, "CSV should be accepted"
+        assert len(result["rejected_files"]) == 0, "CSV should not be rejected"
+        assert (
+            result["user_files"][0]["status"] == "SKIPPED"
+        ), "CSV over threshold should be SKIPPED (uploaded but not indexed)"
+        assert (
+            result["user_files"][0]["chunk_count"] is None
+        ), "Skipped file should have no chunks"
+    finally:
+        _set_token_threshold(admin_user, threshold_k=200)
+
+
+def test_csv_under_token_threshold_uploaded_and_indexed(
+    admin_user: DATestUser,
+) -> None:
+    """CSV under token threshold is uploaded and queued for indexing."""
+    _set_token_threshold(admin_user, threshold_k=200)
+    try:
+        content = "col1,col2\na,b\n"
+        result = _upload_raw("small.csv", content.encode(), admin_user)
+
+        assert len(result["user_files"]) == 1, "CSV should be accepted"
+        assert len(result["rejected_files"]) == 0, "CSV should not be rejected"
+        assert (
+            result["user_files"][0]["status"] == "PROCESSING"
+        ), "CSV under threshold should be PROCESSING (queued for indexing)"
+    finally:
+        _set_token_threshold(admin_user, threshold_k=200)
+
+
+def test_txt_over_token_threshold_rejected(
+    admin_user: DATestUser,
+) -> None:
+    """Non-exempt file exceeding token threshold is rejected entirely."""
+    _set_token_threshold(admin_user, threshold_k=1)
+    try:
+        # ~2000 tokens, well over 1K threshold. Unlike CSV, .txt is not
+        # exempt from the threshold so the file should be rejected.
+        content = ("x " * 100 + "\n") * 20
+        result = _upload_raw("big.txt", content.encode(), admin_user)
+
+        assert len(result["user_files"]) == 0, "File should not be accepted"
+        assert len(result["rejected_files"]) == 1, "File should be rejected"
+        assert "token limit" in result["rejected_files"][0]["reason"].lower()
+    finally:
+        _set_token_threshold(admin_user, threshold_k=200)
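The three threshold tests above pin down a single acceptance rule: under the limit a file is accepted and indexed; over the limit, exempt (tabular) types are accepted but skip indexing, while everything else is rejected outright. A minimal sketch of that rule, assuming hypothetical names (FileStatus, TOKEN_EXEMPT_TYPES, classify_upload are illustrations, not the actual Onyx API):

# Hedged sketch of the acceptance rule the tests above encode; all names here
# are hypothetical stand-ins, not the real server-side implementation.
from enum import Enum


class FileStatus(str, Enum):
    PROCESSING = "PROCESSING"  # accepted and queued for indexing
    SKIPPED = "SKIPPED"  # accepted but not indexed


TOKEN_EXEMPT_TYPES = {".csv", ".xlsx"}  # assumed: tabular types bypass rejection


def classify_upload(filename: str, token_count: int, threshold_k: int) -> FileStatus | None:
    """Return a status for accepted files, or None for a rejected upload."""
    if token_count <= threshold_k * 1000:
        return FileStatus.PROCESSING
    # Over threshold: exempt (tabular) types are kept but skip indexing.
    ext = "." + filename.rsplit(".", 1)[-1].lower()
    if ext in TOKEN_EXEMPT_TYPES:
        return FileStatus.SKIPPED
    return None  # non-exempt files over the limit are rejected


assert classify_upload("large.csv", 2000, threshold_k=1) is FileStatus.SKIPPED
assert classify_upload("small.csv", 5, threshold_k=200) is FileStatus.PROCESSING
assert classify_upload("big.txt", 2000, threshold_k=1) is None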
0  backend/tests/unit/ee/onyx/hooks/__init__.py  Normal file
@@ -9,11 +9,11 @@ import httpx
 import pytest
 from pydantic import BaseModel

+from ee.onyx.hooks.executor import _execute_hook_impl as execute_hook
 from onyx.db.enums import HookFailStrategy
 from onyx.db.enums import HookPoint
 from onyx.error_handling.error_codes import OnyxErrorCode
 from onyx.error_handling.exceptions import OnyxError
-from onyx.hooks.executor import execute_hook
 from onyx.hooks.executor import HookSkipped
 from onyx.hooks.executor import HookSoftFailed
 from onyx.hooks.points.query_processing import QueryProcessingResponse
@@ -118,28 +118,30 @@ def db_session() -> MagicMock:


 @pytest.mark.parametrize(
-    "hooks_available,hook",
+    "multi_tenant,hook",
     [
-        # HOOKS_AVAILABLE=False exits before the DB lookup — hook is irrelevant.
-        pytest.param(False, None, id="hooks_not_available"),
-        pytest.param(True, None, id="hook_not_found"),
-        pytest.param(True, _make_hook(is_active=False), id="hook_inactive"),
-        pytest.param(True, _make_hook(endpoint_url=None), id="no_endpoint_url"),
+        # MULTI_TENANT=True exits before the DB lookup — hook is irrelevant.
+        pytest.param(True, None, id="multi_tenant"),
+        pytest.param(False, None, id="hook_not_found"),
+        pytest.param(False, _make_hook(is_active=False), id="hook_inactive"),
+        pytest.param(False, _make_hook(endpoint_url=None), id="no_endpoint_url"),
     ],
 )
 def test_early_exit_returns_skipped_with_no_db_writes(
     db_session: MagicMock,
-    hooks_available: bool,
+    multi_tenant: bool,
     hook: MagicMock | None,
 ) -> None:
     with (
-        patch("onyx.hooks.executor.HOOKS_AVAILABLE", hooks_available),
+        patch("ee.onyx.hooks.executor.MULTI_TENANT", multi_tenant),
         patch(
-            "onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
+            "ee.onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
             return_value=hook,
         ),
-        patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
-        patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
+        patch("ee.onyx.hooks.executor.update_hook__no_commit") as mock_update,
+        patch(
+            "ee.onyx.hooks.executor.create_hook_execution_log__no_commit"
+        ) as mock_log,
     ):
         result = execute_hook(
             db_session=db_session,
@@ -164,14 +166,16 @@ def test_success_returns_validated_model_and_sets_reachable(
     hook = _make_hook()

     with (
-        patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
+        patch("ee.onyx.hooks.executor.MULTI_TENANT", False),
         patch(
-            "onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
+            "ee.onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
             return_value=hook,
         ),
-        patch("onyx.hooks.executor.get_session_with_current_tenant"),
-        patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
-        patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
+        patch("ee.onyx.hooks.executor.get_session_with_current_tenant"),
+        patch("ee.onyx.hooks.executor.update_hook__no_commit") as mock_update,
+        patch(
+            "ee.onyx.hooks.executor.create_hook_execution_log__no_commit"
+        ) as mock_log,
         patch("httpx.Client") as mock_client_cls,
     ):
         _setup_client(mock_client_cls, response=_make_response())
@@ -195,14 +199,14 @@ def test_success_skips_reachable_write_when_already_true(db_session: MagicMock)
     hook = _make_hook(is_reachable=True)

     with (
-        patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
+        patch("ee.onyx.hooks.executor.MULTI_TENANT", False),
        patch(
-            "onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
+            "ee.onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
             return_value=hook,
         ),
-        patch("onyx.hooks.executor.get_session_with_current_tenant"),
-        patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
-        patch("onyx.hooks.executor.create_hook_execution_log__no_commit"),
+        patch("ee.onyx.hooks.executor.get_session_with_current_tenant"),
+        patch("ee.onyx.hooks.executor.update_hook__no_commit") as mock_update,
+        patch("ee.onyx.hooks.executor.create_hook_execution_log__no_commit"),
         patch("httpx.Client") as mock_client_cls,
     ):
         _setup_client(mock_client_cls, response=_make_response())
@@ -224,14 +228,16 @@ def test_non_dict_json_response_is_a_failure(db_session: MagicMock) -> None:
     hook = _make_hook(fail_strategy=HookFailStrategy.SOFT)

     with (
-        patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
+        patch("ee.onyx.hooks.executor.MULTI_TENANT", False),
         patch(
-            "onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
+            "ee.onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
             return_value=hook,
         ),
-        patch("onyx.hooks.executor.get_session_with_current_tenant"),
-        patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
-        patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
+        patch("ee.onyx.hooks.executor.get_session_with_current_tenant"),
+        patch("ee.onyx.hooks.executor.update_hook__no_commit") as mock_update,
+        patch(
+            "ee.onyx.hooks.executor.create_hook_execution_log__no_commit"
+        ) as mock_log,
         patch("httpx.Client") as mock_client_cls,
     ):
         _setup_client(
@@ -258,14 +264,16 @@ def test_json_decode_failure_is_a_failure(db_session: MagicMock) -> None:
     hook = _make_hook(fail_strategy=HookFailStrategy.SOFT)

     with (
-        patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
+        patch("ee.onyx.hooks.executor.MULTI_TENANT", False),
         patch(
-            "onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
+            "ee.onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
             return_value=hook,
         ),
-        patch("onyx.hooks.executor.get_session_with_current_tenant"),
-        patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
-        patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
+        patch("ee.onyx.hooks.executor.get_session_with_current_tenant"),
+        patch("ee.onyx.hooks.executor.update_hook__no_commit") as mock_update,
+        patch(
+            "ee.onyx.hooks.executor.create_hook_execution_log__no_commit"
+        ) as mock_log,
         patch("httpx.Client") as mock_client_cls,
     ):
         _setup_client(
@@ -384,14 +392,14 @@ def test_http_failure_paths(
     hook = _make_hook(fail_strategy=fail_strategy)

     with (
-        patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
+        patch("ee.onyx.hooks.executor.MULTI_TENANT", False),
         patch(
-            "onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
+            "ee.onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
             return_value=hook,
         ),
-        patch("onyx.hooks.executor.get_session_with_current_tenant"),
-        patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
-        patch("onyx.hooks.executor.create_hook_execution_log__no_commit"),
+        patch("ee.onyx.hooks.executor.get_session_with_current_tenant"),
+        patch("ee.onyx.hooks.executor.update_hook__no_commit") as mock_update,
+        patch("ee.onyx.hooks.executor.create_hook_execution_log__no_commit"),
         patch("httpx.Client") as mock_client_cls,
     ):
         _setup_client(mock_client_cls, side_effect=exception)
@@ -443,14 +451,14 @@ def test_authorization_header(
     hook = _make_hook(api_key=api_key)

     with (
-        patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
+        patch("ee.onyx.hooks.executor.MULTI_TENANT", False),
         patch(
-            "onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
+            "ee.onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
             return_value=hook,
         ),
-        patch("onyx.hooks.executor.get_session_with_current_tenant"),
-        patch("onyx.hooks.executor.update_hook__no_commit"),
-        patch("onyx.hooks.executor.create_hook_execution_log__no_commit"),
+        patch("ee.onyx.hooks.executor.get_session_with_current_tenant"),
+        patch("ee.onyx.hooks.executor.update_hook__no_commit"),
+        patch("ee.onyx.hooks.executor.create_hook_execution_log__no_commit"),
         patch("httpx.Client") as mock_client_cls,
     ):
         mock_client = _setup_client(mock_client_cls, response=_make_response())
@@ -489,13 +497,13 @@ def test_persist_session_failure_is_swallowed(
     hook = _make_hook(fail_strategy=HookFailStrategy.HARD)

     with (
-        patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
+        patch("ee.onyx.hooks.executor.MULTI_TENANT", False),
         patch(
-            "onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
+            "ee.onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
             return_value=hook,
         ),
         patch(
-            "onyx.hooks.executor.get_session_with_current_tenant",
+            "ee.onyx.hooks.executor.get_session_with_current_tenant",
             side_effect=RuntimeError("DB unavailable"),
         ),
         patch("httpx.Client") as mock_client_cls,
@@ -556,14 +564,16 @@ def test_response_validation_failure_respects_fail_strategy(
     hook = _make_hook(fail_strategy=fail_strategy)

     with (
-        patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
+        patch("ee.onyx.hooks.executor.MULTI_TENANT", False),
         patch(
-            "onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
+            "ee.onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
             return_value=hook,
         ),
-        patch("onyx.hooks.executor.get_session_with_current_tenant"),
-        patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
-        patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
+        patch("ee.onyx.hooks.executor.get_session_with_current_tenant"),
+        patch("ee.onyx.hooks.executor.update_hook__no_commit") as mock_update,
+        patch(
+            "ee.onyx.hooks.executor.create_hook_execution_log__no_commit"
+        ) as mock_log,
         patch("httpx.Client") as mock_client_cls,
     ):
         # Response payload is missing required_field → ValidationError
@@ -619,13 +629,13 @@ def test_unexpected_exception_in_inner_respects_fail_strategy(
     hook = _make_hook(fail_strategy=fail_strategy)

     with (
-        patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
+        patch("ee.onyx.hooks.executor.MULTI_TENANT", False),
         patch(
-            "onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
+            "ee.onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
             return_value=hook,
         ),
         patch(
-            "onyx.hooks.executor._execute_hook_inner",
+            "ee.onyx.hooks.executor._execute_hook_inner",
             side_effect=ValueError("unexpected bug"),
         ),
     ):
@@ -658,17 +668,19 @@ def test_is_reachable_failure_does_not_prevent_log(db_session: MagicMock) -> Non
     hook = _make_hook(fail_strategy=HookFailStrategy.SOFT)

     with (
-        patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
+        patch("ee.onyx.hooks.executor.MULTI_TENANT", False),
         patch(
-            "onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
+            "ee.onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
             return_value=hook,
         ),
-        patch("onyx.hooks.executor.get_session_with_current_tenant"),
+        patch("ee.onyx.hooks.executor.get_session_with_current_tenant"),
         patch(
-            "onyx.hooks.executor.update_hook__no_commit",
+            "ee.onyx.hooks.executor.update_hook__no_commit",
             side_effect=OnyxError(OnyxErrorCode.NOT_FOUND, "hook deleted"),
         ),
-        patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
+        patch(
+            "ee.onyx.hooks.executor.create_hook_execution_log__no_commit"
+        ) as mock_log,
         patch("httpx.Client") as mock_client_cls,
     ):
         _setup_client(mock_client_cls, side_effect=httpx.ConnectError("refused"))
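Every change in the hunk above is the same mechanical rewrite: the executor moved from onyx.hooks.executor to ee.onyx.hooks.executor, and each patch target had to follow, because unittest.mock.patch replaces a name in the module that looks it up, not in the module that defines it. A small self-contained demo of that rule (the helpers/executor modules below are toy stand-ins, not Onyx code):

# Demo of "patch where the name is looked up": toy modules only.
import sys
import types
from unittest.mock import patch

# Module that defines the name.
helpers = types.ModuleType("helpers")
helpers.get_value = lambda: "real"
sys.modules["helpers"] = helpers

# Module that bound the name at import time (like the executor under test).
executor = types.ModuleType("executor")
executor.get_value = helpers.get_value  # what "from helpers import get_value" does
executor.run = lambda: executor.get_value()
sys.modules["executor"] = executor

# Patching the defining module does not affect executor.run() ...
with patch("helpers.get_value", return_value="mocked"):
    assert executor.run() == "real"

# ... patching the name inside the consuming module is what works.
with patch("executor.get_value", return_value="mocked"):
    assert executor.run() == "mocked"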
0  backend/tests/unit/ee/onyx/server/__init__.py  Normal file
@@ -1,4 +1,4 @@
-"""Unit tests for onyx.server.features.hooks.api helpers.
+"""Unit tests for ee.onyx.server.features.hooks.api helpers.

 Covers:
 - _check_ssrf_safety: scheme enforcement and private-IP blocklist
@@ -16,13 +16,13 @@ from unittest.mock import patch
 import httpx
 import pytest

+from ee.onyx.server.features.hooks.api import _check_ssrf_safety
+from ee.onyx.server.features.hooks.api import _raise_for_validation_failure
+from ee.onyx.server.features.hooks.api import _validate_endpoint
 from onyx.error_handling.error_codes import OnyxErrorCode
 from onyx.error_handling.exceptions import OnyxError
 from onyx.hooks.models import HookValidateResponse
 from onyx.hooks.models import HookValidateStatus
-from onyx.server.features.hooks.api import _check_ssrf_safety
-from onyx.server.features.hooks.api import _raise_for_validation_failure
-from onyx.server.features.hooks.api import _validate_endpoint

 # ---------------------------------------------------------------------------
 # Helpers
@@ -117,28 +117,28 @@ class TestCheckSsrfSafety:
 class TestValidateEndpoint:
     def _call(self, *, api_key: str | None = _API_KEY) -> HookValidateResponse:
         # Bypass SSRF check — tested separately in TestCheckSsrfSafety.
-        with patch("onyx.server.features.hooks.api._check_ssrf_safety"):
+        with patch("ee.onyx.server.features.hooks.api._check_ssrf_safety"):
             return _validate_endpoint(
                 endpoint_url=_URL,
                 api_key=api_key,
                 timeout_seconds=_TIMEOUT,
             )

-    @patch("onyx.server.features.hooks.api.httpx.Client")
+    @patch("ee.onyx.server.features.hooks.api.httpx.Client")
     def test_2xx_returns_passed(self, mock_client_cls: MagicMock) -> None:
         mock_client_cls.return_value.__enter__.return_value.post.return_value = (
             _mock_response(200)
         )
         assert self._call().status == HookValidateStatus.passed

-    @patch("onyx.server.features.hooks.api.httpx.Client")
+    @patch("ee.onyx.server.features.hooks.api.httpx.Client")
     def test_5xx_returns_passed(self, mock_client_cls: MagicMock) -> None:
         mock_client_cls.return_value.__enter__.return_value.post.return_value = (
             _mock_response(500)
         )
         assert self._call().status == HookValidateStatus.passed

-    @patch("onyx.server.features.hooks.api.httpx.Client")
+    @patch("ee.onyx.server.features.hooks.api.httpx.Client")
     @pytest.mark.parametrize("status_code", [401, 403])
     def test_401_403_returns_auth_failed(
         self, mock_client_cls: MagicMock, status_code: int
@@ -150,21 +150,21 @@ class TestValidateEndpoint:
         assert result.status == HookValidateStatus.auth_failed
         assert str(status_code) in (result.error_message or "")

-    @patch("onyx.server.features.hooks.api.httpx.Client")
+    @patch("ee.onyx.server.features.hooks.api.httpx.Client")
     def test_4xx_non_auth_returns_passed(self, mock_client_cls: MagicMock) -> None:
         mock_client_cls.return_value.__enter__.return_value.post.return_value = (
             _mock_response(422)
         )
         assert self._call().status == HookValidateStatus.passed

-    @patch("onyx.server.features.hooks.api.httpx.Client")
+    @patch("ee.onyx.server.features.hooks.api.httpx.Client")
     def test_connect_timeout_returns_timeout(self, mock_client_cls: MagicMock) -> None:
         mock_client_cls.return_value.__enter__.return_value.post.side_effect = (
             httpx.ConnectTimeout("timed out")
         )
         assert self._call().status == HookValidateStatus.timeout

-    @patch("onyx.server.features.hooks.api.httpx.Client")
+    @patch("ee.onyx.server.features.hooks.api.httpx.Client")
     @pytest.mark.parametrize(
         "exc",
         [
@@ -179,7 +179,7 @@ class TestValidateEndpoint:
         mock_client_cls.return_value.__enter__.return_value.post.side_effect = exc
         assert self._call().status == HookValidateStatus.timeout

-    @patch("onyx.server.features.hooks.api.httpx.Client")
+    @patch("ee.onyx.server.features.hooks.api.httpx.Client")
     def test_connect_error_returns_cannot_connect(
         self, mock_client_cls: MagicMock
     ) -> None:
@@ -189,7 +189,7 @@ class TestValidateEndpoint:
         )
         assert self._call().status == HookValidateStatus.cannot_connect

-    @patch("onyx.server.features.hooks.api.httpx.Client")
+    @patch("ee.onyx.server.features.hooks.api.httpx.Client")
     def test_arbitrary_exception_returns_cannot_connect(
         self, mock_client_cls: MagicMock
     ) -> None:
@@ -198,7 +198,7 @@ class TestValidateEndpoint:
         )
         assert self._call().status == HookValidateStatus.cannot_connect

-    @patch("onyx.server.features.hooks.api.httpx.Client")
+    @patch("ee.onyx.server.features.hooks.api.httpx.Client")
     def test_api_key_sent_as_bearer(self, mock_client_cls: MagicMock) -> None:
         mock_post = mock_client_cls.return_value.__enter__.return_value.post
         mock_post.return_value = _mock_response(200)
@@ -206,7 +206,7 @@ class TestValidateEndpoint:
         _, kwargs = mock_post.call_args
         assert kwargs["headers"]["Authorization"] == "Bearer mykey"

-    @patch("onyx.server.features.hooks.api.httpx.Client")
+    @patch("ee.onyx.server.features.hooks.api.httpx.Client")
     def test_no_api_key_omits_auth_header(self, mock_client_cls: MagicMock) -> None:
         mock_post = mock_client_cls.return_value.__enter__.return_value.post
         mock_post.return_value = _mock_response(200)
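The SSRF helper itself is bypassed in these tests (patched out in _call), but the docstring names what it enforces: scheme enforcement plus a private-IP blocklist. A hedged sketch of that kind of guard follows; check_ssrf_safety here is an illustration under those stated assumptions, not the actual Onyx _check_ssrf_safety implementation:

# Illustrative scheme + private-IP SSRF guard; not the real Onyx code.
import ipaddress
import socket
from urllib.parse import urlparse


def check_ssrf_safety(url: str) -> None:
    parsed = urlparse(url)
    if parsed.scheme != "https":
        raise ValueError(f"endpoint must use https, got {parsed.scheme!r}")
    if not parsed.hostname:
        raise ValueError("endpoint URL has no hostname")
    # Resolve the host and reject private, loopback, and link-local addresses.
    for info in socket.getaddrinfo(parsed.hostname, None):
        ip = ipaddress.ip_address(info[4][0])
        if ip.is_private or ip.is_loopback or ip.is_link_local:
            raise ValueError(f"endpoint resolves to blocked address {ip}")


# check_ssrf_safety("http://example.com")   # would raise: not https
# check_ssrf_safety("https://127.0.0.1/x")  # would raise: loopback address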
@@ -300,6 +300,66 @@ class TestExtractContextFiles:
         assert result.file_texts == []
         assert result.total_token_count == 50

+    @patch("onyx.chat.process_message.load_in_memory_chat_files")
+    def test_tool_metadata_file_id_matches_chat_history_file_id(
+        self, mock_load: MagicMock
+    ) -> None:
+        """The file_id in tool metadata (from extract_context_files) and the
+        file_id in chat history messages (from build_file_context) must
+        agree, otherwise the LLM sees different IDs for the same file across
+        turns.
+
+        In production, UserFile.id (UUID PK) differs from UserFile.file_id
+        (file-store path). Both pathways should produce the same file_id
+        (UserFile.id) for FileReaderTool."""
+        from onyx.chat.chat_utils import build_file_context
+
+        user_file_uuid = uuid4()
+        file_store_path = f"user_files/{user_file_uuid}/data.csv"
+
+        uf = UserFile(
+            id=user_file_uuid,
+            file_id=file_store_path,
+            name="data.csv",
+            token_count=100,
+            file_type="text/csv",
+        )
+
+        in_memory = InMemoryChatFile(
+            file_id=file_store_path,
+            content=b"col1,col2\na,b",
+            file_type=ChatFileType.TABULAR,
+            filename="data.csv",
+        )
+
+        mock_load.return_value = [in_memory]
+
+        # Pathway 1: extract_context_files (project/persona context)
+        result = extract_context_files(
+            user_files=[uf],
+            llm_max_context_window=10000,
+            reserved_token_count=0,
+            db_session=MagicMock(),
+        )
+        assert len(result.file_metadata_for_tool) == 1
+        tool_metadata_file_id = result.file_metadata_for_tool[0].file_id
+
+        # Pathway 2: build_file_context (chat history path)
+        # In convert_chat_history, tool_file_id comes from
+        # file_descriptor["user_file_id"], which is str(UserFile.id)
+        ctx = build_file_context(
+            tool_file_id=str(user_file_uuid),
+            filename="data.csv",
+            file_type=ChatFileType.TABULAR,
+        )
+        chat_history_file_id = ctx.tool_metadata.file_id
+
+        # Both pathways must produce the same ID for the LLM
+        assert tool_metadata_file_id == chat_history_file_id, (
+            f"File ID mismatch: extract_context_files uses '{tool_metadata_file_id}' "
+            f"but build_file_context uses '{chat_history_file_id}'."
+        )
+
     @patch("onyx.chat.process_message.DISABLE_VECTOR_DB", True)
     def test_overflow_with_vector_db_disabled_provides_tool_metadata(self) -> None:
         """When vector DB is disabled, overflow produces FileToolMetadata."""
@@ -316,6 +376,128 @@ class TestExtractContextFiles:
         assert len(result.file_metadata_for_tool) == 1
         assert result.file_metadata_for_tool[0].filename == "bigfile.txt"

+    @patch("onyx.chat.process_message.load_in_memory_chat_files")
+    def test_metadata_only_files_not_counted_in_aggregate_tokens(
+        self, mock_load: MagicMock
+    ) -> None:
+        """Metadata-only files (TABULAR) should not count toward the token budget."""
+        text_file_id = str(uuid4())
+        text_uf = _make_user_file(token_count=100, file_id=text_file_id)
+        # TABULAR file with large token count — should be excluded from aggregate
+        tabular_uf = _make_user_file(
+            token_count=50000, name="huge.xlsx", file_id=str(uuid4())
+        )
+        tabular_uf.file_type = (
+            "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
+        )
+
+        mock_load.return_value = [
+            _make_in_memory_file(file_id=text_file_id, content="text content"),
+            InMemoryChatFile(
+                file_id=str(tabular_uf.id),
+                content=b"binary xlsx",
+                file_type=ChatFileType.TABULAR,
+                filename="huge.xlsx",
+            ),
+        ]
+
+        result = extract_context_files(
+            user_files=[text_uf, tabular_uf],
+            llm_max_context_window=10000,
+            reserved_token_count=0,
+            db_session=MagicMock(),
+        )
+
+        # Text file fits (100 < 6000), so files should be loaded
+        assert result.file_texts == ["text content"]
+        # TABULAR file should appear as tool metadata, not in file_texts
+        assert len(result.file_metadata_for_tool) == 1
+        assert result.file_metadata_for_tool[0].filename == "huge.xlsx"
+
+    @patch("onyx.chat.process_message.load_in_memory_chat_files")
+    def test_metadata_only_files_loaded_as_tool_metadata(
+        self, mock_load: MagicMock
+    ) -> None:
+        """When files fit, metadata-only files appear in file_metadata_for_tool."""
+        text_file_id = str(uuid4())
+        tabular_file_id = str(uuid4())
+        text_uf = _make_user_file(token_count=100, file_id=text_file_id)
+        tabular_uf = _make_user_file(
+            token_count=500, name="data.csv", file_id=tabular_file_id
+        )
+        tabular_uf.file_type = "text/csv"
+
+        mock_load.return_value = [
+            _make_in_memory_file(file_id=text_file_id, content="hello"),
+            InMemoryChatFile(
+                file_id=tabular_file_id,
+                content=b"col1,col2\na,b",
+                file_type=ChatFileType.TABULAR,
+                filename="data.csv",
+            ),
+        ]
+
+        result = extract_context_files(
+            user_files=[text_uf, tabular_uf],
+            llm_max_context_window=10000,
+            reserved_token_count=0,
+            db_session=MagicMock(),
+        )
+
+        assert result.file_texts == ["hello"]
+        assert len(result.file_metadata_for_tool) == 1
+        assert result.file_metadata_for_tool[0].filename == "data.csv"
+        # TABULAR should not appear in file_metadata (that's for citation)
+        assert all(m.filename != "data.csv" for m in result.file_metadata)
+
+    def test_overflow_with_vector_db_preserves_metadata_only_tool_metadata(
+        self,
+    ) -> None:
+        """When text files overflow with vector DB enabled, metadata-only files
+        should still be exposed via file_metadata_for_tool since they aren't
+        in the vector DB and would otherwise be inaccessible."""
+        text_uf = _make_user_file(token_count=7000, name="bigfile.txt")
+        tabular_uf = _make_user_file(token_count=500, name="data.xlsx")
+        tabular_uf.file_type = (
+            "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
+        )
+
+        result = extract_context_files(
+            user_files=[text_uf, tabular_uf],
+            llm_max_context_window=10000,
+            reserved_token_count=0,
+            db_session=MagicMock(),
+        )
+
+        # Text files overflow → search filter enabled
+        assert result.use_as_search_filter is True
+        assert result.file_texts == []
+        # TABULAR file should still be in tool metadata
+        assert len(result.file_metadata_for_tool) == 1
+        assert result.file_metadata_for_tool[0].filename == "data.xlsx"
+
+    @patch("onyx.chat.process_message.DISABLE_VECTOR_DB", True)
+    def test_overflow_no_vector_db_includes_all_files_in_tool_metadata(self) -> None:
+        """When vector DB is disabled and files overflow, all files
+        (both text and metadata-only) appear in file_metadata_for_tool."""
+        text_uf = _make_user_file(token_count=7000, name="bigfile.txt")
+        tabular_uf = _make_user_file(token_count=500, name="data.xlsx")
+        tabular_uf.file_type = (
+            "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
+        )
+
+        result = extract_context_files(
+            user_files=[text_uf, tabular_uf],
+            llm_max_context_window=10000,
+            reserved_token_count=0,
+            db_session=MagicMock(),
+        )
+
+        assert result.use_as_search_filter is False
+        assert len(result.file_metadata_for_tool) == 2
+        filenames = {m.filename for m in result.file_metadata_for_tool}
+        assert filenames == {"bigfile.txt", "data.xlsx"}
+

 # ===========================================================================
 # Search filter + search_usage determination
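All of the new TestExtractContextFiles cases pin down the same budgeting rule: tabular, metadata-only files are routed to file_metadata_for_tool and never count against the text-token budget, while text-file overflow flips the search-filter fallback. A minimal sketch of that partitioning, with hypothetical stand-ins (UserFileStub and split_for_context are illustrations, not Onyx internals):

# Hedged sketch of the token-budget partitioning these tests assert.
from dataclasses import dataclass

TABULAR_TYPES = {
    "text/csv",
    "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
}


@dataclass
class UserFileStub:  # hypothetical stand-in for UserFile
    name: str
    token_count: int
    file_type: str


def split_for_context(
    files: list[UserFileStub], budget: int
) -> tuple[list[str], list[str], bool]:
    """Return (text file names, tool-metadata file names, use_as_search_filter)."""
    tool_metadata = [f.name for f in files if f.file_type in TABULAR_TYPES]
    text_files = [f for f in files if f.file_type not in TABULAR_TYPES]
    # Only non-tabular files count toward the aggregate token budget.
    if sum(f.token_count for f in text_files) > budget:
        return [], tool_metadata, True  # overflow -> fall back to search filter
    return [f.name for f in text_files], tool_metadata, False


files = [
    UserFileStub("notes.txt", 100, "text/plain"),
    UserFileStub(
        "huge.xlsx",
        50_000,
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
    ),
]
texts, tool_meta, overflow = split_for_context(files, budget=6000)
assert texts == ["notes.txt"] and tool_meta == ["huge.xlsx"] and not overflow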
@@ -644,6 +644,92 @@ class TestConstructMessageHistory:
         assert "Project file 0 content" in project_message.message
         assert "Project file 1 content" in project_message.message

+    def test_file_metadata_for_tool_produces_message(self) -> None:
+        """When context_files has file_metadata_for_tool, a metadata listing
+        message should be injected into the history."""
+        system_prompt = create_message("System", MessageType.SYSTEM, 10)
+        user_msg = create_message("Analyze the spreadsheet", MessageType.USER, 5)
+
+        context_files = ExtractedContextFiles(
+            file_texts=[],
+            image_files=[],
+            use_as_search_filter=False,
+            total_token_count=0,
+            file_metadata=[],
+            uncapped_token_count=0,
+            file_metadata_for_tool=[
+                FileToolMetadata(
+                    file_id="xlsx-1",
+                    filename="report.xlsx",
+                    approx_char_count=100000,
+                ),
+            ],
+        )
+
+        result = construct_message_history(
+            system_prompt=system_prompt,
+            custom_agent_prompt=None,
+            simple_chat_history=[user_msg],
+            reminder_message=None,
+            context_files=context_files,
+            available_tokens=1000,
+            token_counter=_simple_token_counter,
+        )
+
+        # Should have: system, tool_metadata_message, user
+        assert len(result) == 3
+        metadata_msg = result[1]
+        assert metadata_msg.message_type == MessageType.USER
+        assert "report.xlsx" in metadata_msg.message
+        assert "xlsx-1" in metadata_msg.message
+
+    def test_metadata_only_and_text_files_both_present(self) -> None:
+        """When both text content and tool metadata are present, both messages
+        should appear in the history."""
+        system_prompt = create_message("System", MessageType.SYSTEM, 10)
+        user_msg = create_message("Summarize everything", MessageType.USER, 5)
+
+        context_files = ExtractedContextFiles(
+            file_texts=["Text file content here"],
+            image_files=[],
+            use_as_search_filter=False,
+            total_token_count=100,
+            file_metadata=[
+                ContextFileMetadata(
+                    file_id="txt-1",
+                    filename="notes.txt",
+                    file_content="Text file content here",
+                ),
+            ],
+            uncapped_token_count=100,
+            file_metadata_for_tool=[
+                FileToolMetadata(
+                    file_id="xlsx-1",
+                    filename="data.xlsx",
+                    approx_char_count=50000,
+                ),
+            ],
+        )
+
+        result = construct_message_history(
+            system_prompt=system_prompt,
+            custom_agent_prompt=None,
+            simple_chat_history=[user_msg],
+            reminder_message=None,
+            context_files=context_files,
+            available_tokens=2000,
+            token_counter=_simple_token_counter,
+        )
+
+        # Should have: system, context_files_message, tool_metadata_message, user
+        assert len(result) == 4
+        # Context files message (text content)
+        assert "documents" in result[1].message
+        assert "Text file content here" in result[1].message
+        # Tool metadata message
+        assert "data.xlsx" in result[2].message
+        assert result[3] == user_msg
+

 def _simple_token_counter(text: str) -> int:
     """Approximate token counter for tests (~4 chars per token)."""
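The hunk ends before the body of _simple_token_counter; per its docstring ("~4 chars per token"), a plausible stand-in would be something like the following, offered only as an assumption rather than the actual helper:

# Hypothetical body matching the docstring; the real one is not shown in the diff.
def _simple_token_counter(text: str) -> int:
    """Approximate token counter for tests (~4 chars per token)."""
    return max(1, len(text) // 4)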
@@ -139,7 +139,7 @@ def test_csv_file_type() -> None:
     result = _extract_referenced_file_descriptors([tool_call], message)

     assert len(result) == 1
-    assert result[0]["type"] == ChatFileType.CSV
+    assert result[0]["type"] == ChatFileType.TABULAR


 def test_unknown_extension_defaults_to_plain_text() -> None:
@@ -1,15 +1,23 @@
-"""Tests for Canvas connector — client (PR1)."""
+"""Tests for Canvas connector — client, credentials, conversion."""

+from datetime import datetime
+from datetime import timezone
 from typing import Any
 from unittest.mock import MagicMock
 from unittest.mock import patch

 import pytest

+from onyx.configs.constants import DocumentSource
 from onyx.connectors.canvas.client import CanvasApiClient
+from onyx.connectors.canvas.connector import CanvasConnector
+from onyx.connectors.exceptions import ConnectorValidationError
+from onyx.connectors.exceptions import CredentialExpiredError
+from onyx.connectors.exceptions import InsufficientPermissionsError
+from onyx.connectors.exceptions import UnexpectedValidationError
+from onyx.connectors.models import ConnectorMissingCredentialError
 from onyx.error_handling.exceptions import OnyxError


 # ---------------------------------------------------------------------------
 # Helpers
 # ---------------------------------------------------------------------------
@@ -18,6 +26,77 @@ FAKE_BASE_URL = "https://myschool.instructure.com"
 FAKE_TOKEN = "fake-canvas-token"


+def _mock_course(
+    course_id: int = 1,
+    name: str = "Intro to CS",
+    course_code: str = "CS101",
+) -> dict[str, Any]:
+    return {
+        "id": course_id,
+        "name": name,
+        "course_code": course_code,
+        "created_at": "2025-01-01T00:00:00Z",
+        "workflow_state": "available",
+    }
+
+
+def _build_connector(base_url: str = FAKE_BASE_URL) -> CanvasConnector:
+    """Build a connector with mocked credential validation."""
+    with patch("onyx.connectors.canvas.client.rl_requests") as mock_req:
+        mock_req.get.return_value = _mock_response(json_data=[_mock_course()])
+        connector = CanvasConnector(canvas_base_url=base_url)
+        connector.load_credentials({"canvas_access_token": FAKE_TOKEN})
+    return connector
+
+
+def _mock_page(
+    page_id: int = 10,
+    title: str = "Syllabus",
+    updated_at: str = "2025-06-01T12:00:00Z",
+) -> dict[str, Any]:
+    return {
+        "page_id": page_id,
+        "url": "syllabus",
+        "title": title,
+        "body": "<p>Welcome to the course</p>",
+        "created_at": "2025-01-15T00:00:00Z",
+        "updated_at": updated_at,
+    }
+
+
+def _mock_assignment(
+    assignment_id: int = 20,
+    name: str = "Homework 1",
+    course_id: int = 1,
+    updated_at: str = "2025-06-01T12:00:00Z",
+) -> dict[str, Any]:
+    return {
+        "id": assignment_id,
+        "name": name,
+        "description": "<p>Solve these problems</p>",
+        "html_url": f"{FAKE_BASE_URL}/courses/{course_id}/assignments/{assignment_id}",
+        "course_id": course_id,
+        "created_at": "2025-01-20T00:00:00Z",
+        "updated_at": updated_at,
+        "due_at": "2025-02-01T23:59:00Z",
+    }
+
+
+def _mock_announcement(
+    announcement_id: int = 30,
+    title: str = "Class Cancelled",
+    course_id: int = 1,
+    posted_at: str = "2025-06-01T12:00:00Z",
+) -> dict[str, Any]:
+    return {
+        "id": announcement_id,
+        "title": title,
+        "message": "<p>No class today</p>",
+        "html_url": f"{FAKE_BASE_URL}/courses/{course_id}/discussion_topics/{announcement_id}",
+        "posted_at": posted_at,
+    }
+
+
 def _mock_response(
     status_code: int = 200,
     json_data: Any = None,
@@ -325,6 +404,57 @@ class TestGet:
         assert result == expected


+# ---------------------------------------------------------------------------
+# CanvasApiClient.paginate tests
+# ---------------------------------------------------------------------------
+
+
+class TestPaginate:
+    @patch("onyx.connectors.canvas.client.rl_requests")
+    def test_single_page(self, mock_requests: MagicMock) -> None:
+        mock_requests.get.return_value = _mock_response(
+            json_data=[{"id": 1}, {"id": 2}]
+        )
+        client = CanvasApiClient(
+            bearer_token=FAKE_TOKEN,
+            canvas_base_url=FAKE_BASE_URL,
+        )
+
+        pages = list(client.paginate("courses"))
+
+        assert len(pages) == 1
+        assert pages[0] == [{"id": 1}, {"id": 2}]
+
+    @patch("onyx.connectors.canvas.client.rl_requests")
+    def test_two_pages(self, mock_requests: MagicMock) -> None:
+        next_link = f'<{FAKE_BASE_URL}/api/v1/courses?page=2>; rel="next"'
+        page1 = _mock_response(json_data=[{"id": 1}], link_header=next_link)
+        page2 = _mock_response(json_data=[{"id": 2}])
+        mock_requests.get.side_effect = [page1, page2]
+        client = CanvasApiClient(
+            bearer_token=FAKE_TOKEN,
+            canvas_base_url=FAKE_BASE_URL,
+        )
+
+        pages = list(client.paginate("courses"))
+
+        assert len(pages) == 2
+        assert pages[0] == [{"id": 1}]
+        assert pages[1] == [{"id": 2}]
+
+    @patch("onyx.connectors.canvas.client.rl_requests")
+    def test_empty_response(self, mock_requests: MagicMock) -> None:
+        mock_requests.get.return_value = _mock_response(json_data=[])
+        client = CanvasApiClient(
+            bearer_token=FAKE_TOKEN,
+            canvas_base_url=FAKE_BASE_URL,
+        )
+
+        pages = list(client.paginate("courses"))
+
+        assert pages == []
+
+
 # ---------------------------------------------------------------------------
 # CanvasApiClient._parse_next_link tests
 # ---------------------------------------------------------------------------
@@ -379,3 +509,368 @@ class TestParseNextLink:

         with pytest.raises(OnyxError, match="must use https"):
             self.client._parse_next_link(header)
+
+
+# ---------------------------------------------------------------------------
+# CanvasConnector — credential loading
+# ---------------------------------------------------------------------------
+
+
+class TestLoadCredentials:
+    def _assert_load_credentials_raises(
+        self,
+        status_code: int,
+        expected_error: type[Exception],
+        mock_requests: MagicMock,
+    ) -> None:
+        """Helper: assert load_credentials raises expected_error for a given status."""
+        mock_requests.get.return_value = _mock_response(status_code, {})
+        connector = CanvasConnector(canvas_base_url=FAKE_BASE_URL)
+        with pytest.raises(expected_error):
+            connector.load_credentials({"canvas_access_token": FAKE_TOKEN})
+
+    @patch("onyx.connectors.canvas.client.rl_requests")
+    def test_load_credentials_success(self, mock_requests: MagicMock) -> None:
+        mock_requests.get.return_value = _mock_response(json_data=[_mock_course()])
+        connector = CanvasConnector(canvas_base_url=FAKE_BASE_URL)
+
+        result = connector.load_credentials({"canvas_access_token": FAKE_TOKEN})
+
+        assert result is None
+        assert connector._canvas_client is not None
+
+    def test_canvas_client_raises_without_credentials(self) -> None:
+        connector = CanvasConnector(canvas_base_url=FAKE_BASE_URL)
+
+        with pytest.raises(ConnectorMissingCredentialError):
+            _ = connector.canvas_client
+
+    @patch("onyx.connectors.canvas.client.rl_requests")
+    def test_load_credentials_invalid_token(self, mock_requests: MagicMock) -> None:
+        self._assert_load_credentials_raises(401, CredentialExpiredError, mock_requests)
+
+    @patch("onyx.connectors.canvas.client.rl_requests")
+    def test_load_credentials_insufficient_permissions(
+        self, mock_requests: MagicMock
+    ) -> None:
+        self._assert_load_credentials_raises(
+            403, InsufficientPermissionsError, mock_requests
+        )
+
+
+# ---------------------------------------------------------------------------
+# CanvasConnector — URL normalization
+# ---------------------------------------------------------------------------
+
+
+class TestConnectorUrlNormalization:
+    def test_strips_api_v1_suffix(self) -> None:
+        connector = _build_connector(base_url=f"{FAKE_BASE_URL}/api/v1")
+
+        result = connector.canvas_base_url
+        expected = FAKE_BASE_URL
+
+        assert result == expected
+
+    def test_strips_trailing_slash(self) -> None:
+        connector = _build_connector(base_url=f"{FAKE_BASE_URL}/")
+
+        result = connector.canvas_base_url
+        expected = FAKE_BASE_URL
+
+        assert result == expected
+
+    def test_no_change_for_clean_url(self) -> None:
+        connector = _build_connector(base_url=FAKE_BASE_URL)
+
+        result = connector.canvas_base_url
+        expected = FAKE_BASE_URL
+
+        assert result == expected
+
+
+# ---------------------------------------------------------------------------
+# CanvasConnector — document conversion
+# ---------------------------------------------------------------------------
+
+
+class TestDocumentConversion:
+    def setup_method(self) -> None:
+        self.connector = _build_connector()
+
+    def test_convert_page_to_document(self) -> None:
+        from onyx.connectors.canvas.connector import CanvasPage
+
+        page = CanvasPage(
+            page_id=10,
+            url="syllabus",
+            title="Syllabus",
+            body="<p>Welcome</p>",
+            created_at="2025-01-15T00:00:00Z",
+            updated_at="2025-06-01T12:00:00Z",
+            course_id=1,
+        )
+
+        doc = self.connector._convert_page_to_document(page)
+
+        expected_id = "canvas-page-1-10"
+        expected_metadata = {"course_id": "1", "type": "page"}
+        expected_updated_at = datetime(2025, 6, 1, 12, 0, tzinfo=timezone.utc)
+
+        assert doc.id == expected_id
+        assert doc.source == DocumentSource.CANVAS
+        assert doc.semantic_identifier == "Syllabus"
+        assert doc.metadata == expected_metadata
+        assert doc.sections[0].link is not None
+        assert f"{FAKE_BASE_URL}/courses/1/pages/syllabus" in doc.sections[0].link
+        assert doc.doc_updated_at == expected_updated_at
+
+    def test_convert_page_without_body(self) -> None:
+        from onyx.connectors.canvas.connector import CanvasPage
+
+        page = CanvasPage(
+            page_id=11,
+            url="empty-page",
+            title="Empty Page",
+            body=None,
+            created_at="2025-01-15T00:00:00Z",
+            updated_at="2025-06-01T12:00:00Z",
+            course_id=1,
+        )
+
+        doc = self.connector._convert_page_to_document(page)
+        section_text = doc.sections[0].text
+        assert section_text is not None
+
+        assert "Empty Page" in section_text
+        assert "<p>" not in section_text
+
+    def test_convert_assignment_to_document(self) -> None:
+        from onyx.connectors.canvas.connector import CanvasAssignment
+
+        assignment = CanvasAssignment(
+            id=20,
+            name="Homework 1",
+            description="<p>Solve these</p>",
+            html_url=f"{FAKE_BASE_URL}/courses/1/assignments/20",
+            course_id=1,
+            created_at="2025-01-20T00:00:00Z",
+            updated_at="2025-06-01T12:00:00Z",
+            due_at="2025-02-01T23:59:00Z",
+        )
+
+        doc = self.connector._convert_assignment_to_document(assignment)
+
+        expected_id = "canvas-assignment-1-20"
+        expected_due_text = "Due: February 01, 2025 23:59 UTC"
+
+        assert doc.id == expected_id
+        assert doc.source == DocumentSource.CANVAS
+        assert doc.semantic_identifier == "Homework 1"
+        assert doc.sections[0].text is not None
+        assert expected_due_text in doc.sections[0].text
+
+    def test_convert_assignment_without_description(self) -> None:
+        from onyx.connectors.canvas.connector import CanvasAssignment
+
+        assignment = CanvasAssignment(
+            id=21,
+            name="Quiz 1",
+            description=None,
+            html_url=f"{FAKE_BASE_URL}/courses/1/assignments/21",
+            course_id=1,
+            created_at="2025-01-20T00:00:00Z",
+            updated_at="2025-06-01T12:00:00Z",
+            due_at=None,
+        )
+
+        doc = self.connector._convert_assignment_to_document(assignment)
+        section_text = doc.sections[0].text
+        assert section_text is not None
+
+        assert "Quiz 1" in section_text
+        assert "Due:" not in section_text
+
+    def test_convert_announcement_to_document(self) -> None:
+        from onyx.connectors.canvas.connector import CanvasAnnouncement
+
+        announcement = CanvasAnnouncement(
+            id=30,
+            title="Class Cancelled",
+            message="<p>No class today</p>",
+            html_url=f"{FAKE_BASE_URL}/courses/1/discussion_topics/30",
+            posted_at="2025-06-01T12:00:00Z",
+            course_id=1,
+        )
+
+        doc = self.connector._convert_announcement_to_document(announcement)
+
+        expected_id = "canvas-announcement-1-30"
+        expected_updated_at = datetime(2025, 6, 1, 12, 0, tzinfo=timezone.utc)
+
+        assert doc.id == expected_id
+        assert doc.source == DocumentSource.CANVAS
+        assert doc.semantic_identifier == "Class Cancelled"
+        assert doc.doc_updated_at == expected_updated_at
+
+    def test_convert_announcement_without_posted_at(self) -> None:
+        from onyx.connectors.canvas.connector import CanvasAnnouncement
+
+        announcement = CanvasAnnouncement(
+            id=31,
+            title="TBD Announcement",
+            message=None,
+            html_url=f"{FAKE_BASE_URL}/courses/1/discussion_topics/31",
+            posted_at=None,
+            course_id=1,
+        )
+
+        doc = self.connector._convert_announcement_to_document(announcement)
+
+        assert doc.doc_updated_at is None
+
+
+# ---------------------------------------------------------------------------
+# CanvasConnector — validate_connector_settings
+# ---------------------------------------------------------------------------
+
+
+class TestValidateConnectorSettings:
+    def _assert_validate_raises(
+        self,
+        status_code: int,
+        expected_error: type[Exception],
+        mock_requests: MagicMock,
+    ) -> None:
+        """Helper: assert validate_connector_settings raises expected_error."""
+        success_resp = _mock_response(json_data=[_mock_course()])
+        fail_resp = _mock_response(status_code, {})
+        mock_requests.get.side_effect = [success_resp, fail_resp]
+        connector = CanvasConnector(canvas_base_url=FAKE_BASE_URL)
+        connector.load_credentials({"canvas_access_token": FAKE_TOKEN})
+        with pytest.raises(expected_error):
+            connector.validate_connector_settings()
+
+    @patch("onyx.connectors.canvas.client.rl_requests")
+    def test_validate_success(self, mock_requests: MagicMock) -> None:
+        mock_requests.get.return_value = _mock_response(json_data=[_mock_course()])
+        connector = _build_connector()
+
+        connector.validate_connector_settings()  # should not raise
+
+    @patch("onyx.connectors.canvas.client.rl_requests")
+    def test_validate_expired_credential(self, mock_requests: MagicMock) -> None:
+        self._assert_validate_raises(401, CredentialExpiredError, mock_requests)
+
+    @patch("onyx.connectors.canvas.client.rl_requests")
+    def test_validate_insufficient_permissions(self, mock_requests: MagicMock) -> None:
+        self._assert_validate_raises(403, InsufficientPermissionsError, mock_requests)
+
+    @patch("onyx.connectors.canvas.client.rl_requests")
+    def test_validate_rate_limited(self, mock_requests: MagicMock) -> None:
+        self._assert_validate_raises(429, ConnectorValidationError, mock_requests)
+
+    @patch("onyx.connectors.canvas.client.rl_requests")
+    def test_validate_unexpected_error(self, mock_requests: MagicMock) -> None:
+        self._assert_validate_raises(500, UnexpectedValidationError, mock_requests)
+
+
+# ---------------------------------------------------------------------------
+# _list_* pagination tests
+# ---------------------------------------------------------------------------
+
+
+class TestListCourses:
+    @patch("onyx.connectors.canvas.client.rl_requests")
+    def test_single_page(self, mock_requests: MagicMock) -> None:
+        mock_requests.get.return_value = _mock_response(
+            json_data=[_mock_course(1), _mock_course(2, "CS201", "Data Structures")]
+        )
+        connector = _build_connector()
+
+        result = connector._list_courses()
+
+        assert len(result) == 2
+        assert result[0].id == 1
+        assert result[1].id == 2
+
+    @patch("onyx.connectors.canvas.client.rl_requests")
+    def test_empty_response(self, mock_requests: MagicMock) -> None:
+        mock_requests.get.return_value = _mock_response(json_data=[])
+        connector = _build_connector()
+
+        result = connector._list_courses()
+
+        assert result == []
+
+
+class TestListPages:
+    @patch("onyx.connectors.canvas.client.rl_requests")
+    def test_single_page(self, mock_requests: MagicMock) -> None:
+        mock_requests.get.return_value = _mock_response(
+            json_data=[_mock_page(10), _mock_page(11, "Notes")]
+        )
+        connector = _build_connector()
+
+        result = connector._list_pages(course_id=1)
+
+        assert len(result) == 2
+        assert result[0].page_id == 10
+        assert result[1].page_id == 11
+
+    @patch("onyx.connectors.canvas.client.rl_requests")
+    def test_empty_response(self, mock_requests: MagicMock) -> None:
+        mock_requests.get.return_value = _mock_response(json_data=[])
+        connector = _build_connector()
+
+        result = connector._list_pages(course_id=1)
+
+        assert result == []
+
+
+class TestListAssignments:
+    @patch("onyx.connectors.canvas.client.rl_requests")
+    def test_single_page(self, mock_requests: MagicMock) -> None:
+        mock_requests.get.return_value = _mock_response(
+            json_data=[_mock_assignment(20), _mock_assignment(21, "Quiz 1")]
+        )
+        connector = _build_connector()
+
+        result = connector._list_assignments(course_id=1)
+
+        assert len(result) == 2
+        assert result[0].id == 20
+        assert result[1].id == 21
+
+    @patch("onyx.connectors.canvas.client.rl_requests")
+    def test_empty_response(self, mock_requests: MagicMock) -> None:
+        mock_requests.get.return_value = _mock_response(json_data=[])
+        connector = _build_connector()
+
+        result = connector._list_assignments(course_id=1)
+
+        assert result == []
+
+
+class TestListAnnouncements:
+    @patch("onyx.connectors.canvas.client.rl_requests")
+    def test_single_page(self, mock_requests: MagicMock) -> None:
+        mock_requests.get.return_value = _mock_response(
+            json_data=[_mock_announcement(30), _mock_announcement(31, "Update")]
+        )
+        connector = _build_connector()
+
+        result = connector._list_announcements(course_id=1)
+
+        assert len(result) == 2
+        assert result[0].id == 30
+        assert result[1].id == 31
+
+    @patch("onyx.connectors.canvas.client.rl_requests")
+    def test_empty_response(self, mock_requests: MagicMock) -> None:
+        mock_requests.get.return_value = _mock_response(json_data=[])
+        connector = _build_connector()
+
+        result = connector._list_announcements(course_id=1)
+
+        assert result == []
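TestPaginate and TestParseNextLink above both revolve around Canvas's RFC 5988 Link pagination header: each page's response carries a `rel="next"` link until the last page. A standalone sketch of that extraction follows; parse_next_link here is illustrative and is not the Onyx _parse_next_link implementation:

# Standalone rel="next" Link-header parsing, for illustration only.
import re


def parse_next_link(link_header: str | None) -> str | None:
    """Return the URL tagged rel="next", or None when on the last page."""
    if not link_header:
        return None
    for part in link_header.split(","):
        match = re.match(r'\s*<([^>]+)>;\s*rel="next"', part)
        if match:
            return match.group(1)
    return None


header = '<https://myschool.instructure.com/api/v1/courses?page=2>; rel="next"'
assert parse_next_link(header) == "https://myschool.instructure.com/api/v1/courses?page=2"
assert parse_next_link(None) is None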
@@ -0,0 +1,45 @@
from unittest.mock import AsyncMock
from unittest.mock import patch

import pytest
from discord.errors import LoginFailure

from onyx.connectors.discord.connector import DiscordConnector
from onyx.connectors.exceptions import CredentialInvalidError


def _build_connector(token: str = "fake-bot-token") -> DiscordConnector:
    connector = DiscordConnector()
    connector.load_credentials({"discord_bot_token": token})
    return connector


@patch("onyx.connectors.discord.connector.Client.close", new_callable=AsyncMock)
@patch("onyx.connectors.discord.connector.Client.login", new_callable=AsyncMock)
def test_validate_success(
    mock_login: AsyncMock,
    mock_close: AsyncMock,
) -> None:
    connector = _build_connector()
    connector.validate_connector_settings()

    mock_login.assert_awaited_once_with("fake-bot-token")
    mock_close.assert_awaited_once()


@patch("onyx.connectors.discord.connector.Client.close", new_callable=AsyncMock)
@patch(
    "onyx.connectors.discord.connector.Client.login",
    new_callable=AsyncMock,
    side_effect=LoginFailure("Improper token has been passed."),
)
def test_validate_invalid_token(
    mock_login: AsyncMock,  # noqa: ARG001
    mock_close: AsyncMock,
) -> None:
    connector = _build_connector(token="bad-token")

    with pytest.raises(CredentialInvalidError, match="Invalid Discord bot token"):
        connector.validate_connector_settings()

    mock_close.assert_awaited_once()
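Both Discord tests mock async client methods (AsyncMock for Client.login/Client.close) behind a synchronous validate_connector_settings call, and both assert that close() is awaited even when login fails. A hedged sketch of that sync-over-async pattern, with hypothetical names (FakeClient, validate_token, and _probe are illustrations, not the connector's real internals):

# Sketch of the pattern the Discord test exercises, using toy stand-ins.
import asyncio


class FakeClient:
    async def login(self, token: str) -> None:
        if token == "bad-token":
            raise RuntimeError("Improper token has been passed.")

    async def close(self) -> None:
        pass


def validate_token(token: str) -> None:
    async def _probe() -> None:
        client = FakeClient()
        try:
            await client.login(token)
        finally:
            await client.close()  # close even when login fails, as the test asserts

    asyncio.run(_probe())


validate_token("fake-bot-token")  # ok; "bad-token" would raise through asyncio.run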
225  backend/tests/unit/onyx/db/test_chat_sessions.py  Normal file
@@ -0,0 +1,225 @@
"""Tests for get_chat_sessions_by_user filtering behavior.

Verifies that failed chat sessions (those with only SYSTEM messages) are
correctly filtered out while preserving recently created sessions, matching
the behavior specified in PR #7233.
"""

from datetime import datetime
from datetime import timedelta
from datetime import timezone
from unittest.mock import MagicMock
from uuid import UUID
from uuid import uuid4

import pytest
from sqlalchemy.orm import Session

from onyx.db.chat import get_chat_sessions_by_user
from onyx.db.models import ChatSession


def _make_session(
    user_id: UUID,
    time_created: datetime | None = None,
    time_updated: datetime | None = None,
    description: str = "",
) -> MagicMock:
    """Create a mock ChatSession with the given attributes."""
    session = MagicMock(spec=ChatSession)
    session.id = uuid4()
    session.user_id = user_id
    session.time_created = time_created or datetime.now(timezone.utc)
    session.time_updated = time_updated or session.time_created
    session.description = description
    session.deleted = False
    session.onyxbot_flow = False
    session.project_id = None
    return session


@pytest.fixture
def user_id() -> UUID:
    return uuid4()


@pytest.fixture
def old_time() -> datetime:
    """A timestamp well outside the 5-minute leeway window."""
    return datetime.now(timezone.utc) - timedelta(hours=1)


@pytest.fixture
def recent_time() -> datetime:
    """A timestamp within the 5-minute leeway window."""
    return datetime.now(timezone.utc) - timedelta(minutes=2)


class TestGetChatSessionsByUser:
    """Tests for the failed chat filtering logic in get_chat_sessions_by_user."""

    def test_filters_out_failed_sessions(
        self, user_id: UUID, old_time: datetime
    ) -> None:
        """Sessions with only SYSTEM messages should be excluded."""
        valid_session = _make_session(user_id, time_created=old_time)
        failed_session = _make_session(user_id, time_created=old_time)

        db_session = MagicMock(spec=Session)

        # First execute: returns all sessions
        # Second execute: returns only the valid session's ID (has non-system msgs)
        mock_result_1 = MagicMock()
        mock_result_1.scalars.return_value.all.return_value = [
            valid_session,
            failed_session,
        ]

        mock_result_2 = MagicMock()
        mock_result_2.scalars.return_value.all.return_value = [valid_session.id]

        db_session.execute.side_effect = [mock_result_1, mock_result_2]

        result = get_chat_sessions_by_user(
            user_id=user_id,
            deleted=False,
            db_session=db_session,
            include_failed_chats=False,
        )

        assert len(result) == 1
        assert result[0].id == valid_session.id

    def test_keeps_recent_sessions_without_messages(
        self, user_id: UUID, recent_time: datetime
    ) -> None:
        """Recently created sessions should be kept even without messages."""
        recent_session = _make_session(user_id, time_created=recent_time)

        db_session = MagicMock(spec=Session)

        mock_result_1 = MagicMock()
        mock_result_1.scalars.return_value.all.return_value = [recent_session]

        db_session.execute.side_effect = [mock_result_1]

        result = get_chat_sessions_by_user(
            user_id=user_id,
            deleted=False,
            db_session=db_session,
            include_failed_chats=False,
        )

        assert len(result) == 1
        assert result[0].id == recent_session.id
        # Should only have been called once — no second query needed
        # because the recent session is within the leeway window
        assert db_session.execute.call_count == 1

    def test_include_failed_chats_skips_filtering(
        self, user_id: UUID, old_time: datetime
    ) -> None:
        """When include_failed_chats=True, no filtering should occur."""
        session_a = _make_session(user_id, time_created=old_time)
        session_b = _make_session(user_id, time_created=old_time)

        db_session = MagicMock(spec=Session)

        mock_result = MagicMock()
        mock_result.scalars.return_value.all.return_value = [session_a, session_b]

        db_session.execute.side_effect = [mock_result]

        result = get_chat_sessions_by_user(
            user_id=user_id,
            deleted=False,
            db_session=db_session,
            include_failed_chats=True,
        )

        assert len(result) == 2
        # Only one DB call — no second query for message validation
        assert db_session.execute.call_count == 1

    def test_limit_applied_after_filtering(
        self, user_id: UUID, old_time: datetime
    ) -> None:
        """Limit should be applied after filtering, not before."""
        sessions = [_make_session(user_id, time_created=old_time) for _ in range(5)]
        valid_ids = [s.id for s in sessions[:3]]

        db_session = MagicMock(spec=Session)

        mock_result_1 = MagicMock()
        mock_result_1.scalars.return_value.all.return_value = sessions

        mock_result_2 = MagicMock()
        mock_result_2.scalars.return_value.all.return_value = valid_ids

        db_session.execute.side_effect = [mock_result_1, mock_result_2]

        result = get_chat_sessions_by_user(
            user_id=user_id,
            deleted=False,
            db_session=db_session,
            include_failed_chats=False,
            limit=2,
        )

        assert len(result) == 2
        # Should be the first 2 valid sessions (order preserved)
        assert result[0].id == sessions[0].id
        assert result[1].id == sessions[1].id

    def test_mixed_recent_and_old_sessions(
        self, user_id: UUID, old_time: datetime, recent_time: datetime
    ) -> None:
        """Mix of recent and old sessions should filter correctly."""
        old_valid = _make_session(user_id, time_created=old_time)
        old_failed = _make_session(user_id, time_created=old_time)
        recent_no_msgs = _make_session(user_id, time_created=recent_time)

        db_session = MagicMock(spec=Session)

        mock_result_1 = MagicMock()
        mock_result_1.scalars.return_value.all.return_value = [
            old_valid,
            old_failed,
            recent_no_msgs,
        ]

        mock_result_2 = MagicMock()
        mock_result_2.scalars.return_value.all.return_value = [old_valid.id]

        db_session.execute.side_effect = [mock_result_1, mock_result_2]

        result = get_chat_sessions_by_user(
            user_id=user_id,
            deleted=False,
            db_session=db_session,
            include_failed_chats=False,
        )

        result_ids = {cs.id for cs in result}
        assert old_valid.id in result_ids
        assert recent_no_msgs.id in result_ids
        assert old_failed.id not in result_ids

    def test_empty_result(self, user_id: UUID) -> None:
        """No sessions should return empty list without errors."""
        db_session = MagicMock(spec=Session)

        mock_result = MagicMock()
        mock_result.scalars.return_value.all.return_value = []

        db_session.execute.side_effect = [mock_result]

        result = get_chat_sessions_by_user(
            user_id=user_id,
            deleted=False,
            db_session=db_session,
            include_failed_chats=False,
        )

        assert result == []
        assert db_session.execute.call_count == 1
@@ -40,6 +40,8 @@ def test_send_task_includes_expires(
        user_files=user_files,
        rejected_files=[],
        id_to_temp_id={},
        skip_indexing_filenames=set(),
        indexable_files=user_files,
    )

    mock_user = MagicMock()
@@ -0,0 +1,223 @@
from unittest.mock import MagicMock
from unittest.mock import patch

from onyx.access.models import DocumentAccess
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.document_index.interfaces_new import IndexingMetadata
from onyx.document_index.interfaces_new import TenantState
from onyx.document_index.opensearch.opensearch_document_index import (
    OpenSearchDocumentIndex,
)
from onyx.indexing.models import ChunkEmbedding
from onyx.indexing.models import DocMetadataAwareIndexChunk


def _make_chunk(
    doc_id: str,
    chunk_id: int,
) -> DocMetadataAwareIndexChunk:
    """Creates a minimal DocMetadataAwareIndexChunk for testing."""
    doc = Document(
        id=doc_id,
        sections=[TextSection(text="test", link="http://test.com")],
        source=DocumentSource.FILE,
        semantic_identifier="test_doc",
        metadata={},
    )
    access = DocumentAccess.build(
        user_emails=[],
        user_groups=[],
        external_user_emails=[],
        external_user_group_ids=[],
        is_public=True,
    )
    return DocMetadataAwareIndexChunk(
        chunk_id=chunk_id,
        blurb="test",
        content="test content",
        source_links={0: "http://test.com"},
        image_file_id=None,
        section_continuation=False,
        source_document=doc,
        title_prefix="",
        metadata_suffix_semantic="",
        metadata_suffix_keyword="",
        mini_chunk_texts=None,
        large_chunk_id=None,
        doc_summary="",
        chunk_context="",
        contextual_rag_reserved_tokens=0,
        embeddings=ChunkEmbedding(full_embedding=[0.1] * 10, mini_chunk_embeddings=[]),
        title_embedding=[0.1] * 10,
        tenant_id="test_tenant",
        access=access,
        document_sets=set(),
        user_project=[],
        personas=[],
        boost=0,
        aggregated_chunk_boost_factor=1.0,
        ancestor_hierarchy_node_ids=[],
    )


def _make_index() -> tuple[OpenSearchDocumentIndex, MagicMock]:
    """Creates an OpenSearchDocumentIndex with a mocked client.
    Returns the index and the mock for bulk_index_documents."""
    mock_client = MagicMock()
    mock_bulk = MagicMock()
    mock_client.bulk_index_documents = mock_bulk

    tenant_state = TenantState(tenant_id="test_tenant", multitenant=False)

    index = OpenSearchDocumentIndex.__new__(OpenSearchDocumentIndex)
    index._index_name = "test_index"
    index._client = mock_client
    index._tenant_state = tenant_state

    return index, mock_bulk


def _make_metadata(doc_id: str, chunk_count: int) -> IndexingMetadata:
    return IndexingMetadata(
        doc_id_to_chunk_cnt_diff={
            doc_id: IndexingMetadata.ChunkCounts(
                old_chunk_cnt=0,
                new_chunk_cnt=chunk_count,
            ),
        },
    )


@patch(
    "onyx.document_index.opensearch.opensearch_document_index.MAX_CHUNKS_PER_DOC_BATCH",
    100,
)
def test_single_doc_under_batch_limit_flushes_once() -> None:
    """A document with fewer chunks than MAX_CHUNKS_PER_DOC_BATCH should flush once."""
    index, mock_bulk = _make_index()
    doc_id = "doc_1"
    num_chunks = 50
    chunks = [_make_chunk(doc_id, i) for i in range(num_chunks)]
    metadata = _make_metadata(doc_id, num_chunks)

    with patch.object(index, "delete", return_value=0):
        index.index(chunks, metadata)

    assert mock_bulk.call_count == 1
    batch_arg = mock_bulk.call_args_list[0]
    assert len(batch_arg.kwargs["documents"]) == num_chunks


@patch(
    "onyx.document_index.opensearch.opensearch_document_index.MAX_CHUNKS_PER_DOC_BATCH",
    100,
)
def test_single_doc_over_batch_limit_flushes_multiple_times() -> None:
    """A document with more chunks than MAX_CHUNKS_PER_DOC_BATCH should flush multiple times."""
    index, mock_bulk = _make_index()
    doc_id = "doc_1"
    num_chunks = 250
    chunks = [_make_chunk(doc_id, i) for i in range(num_chunks)]
    metadata = _make_metadata(doc_id, num_chunks)

    with patch.object(index, "delete", return_value=0):
        index.index(chunks, metadata)

    # 250 chunks / 100 per batch = 3 flushes (100 + 100 + 50)
    assert mock_bulk.call_count == 3
    batch_sizes = [len(call.kwargs["documents"]) for call in mock_bulk.call_args_list]
    assert batch_sizes == [100, 100, 50]


@patch(
    "onyx.document_index.opensearch.opensearch_document_index.MAX_CHUNKS_PER_DOC_BATCH",
    100,
)
def test_single_doc_exactly_at_batch_limit() -> None:
    """A document with exactly MAX_CHUNKS_PER_DOC_BATCH chunks should flush once
    (the flush happens on the next chunk, not at the boundary)."""
    index, mock_bulk = _make_index()
    doc_id = "doc_1"
    num_chunks = 100
    chunks = [_make_chunk(doc_id, i) for i in range(num_chunks)]
    metadata = _make_metadata(doc_id, num_chunks)

    with patch.object(index, "delete", return_value=0):
        index.index(chunks, metadata)

    # The batch-size check only fires when a new chunk arrives while the
    # buffer already holds MAX_CHUNKS_PER_DOC_BATCH chunks. With exactly 100
    # chunks there is no 101st arrival, so the single final flush handles all 100.
    assert mock_bulk.call_count == 1


@patch(
    "onyx.document_index.opensearch.opensearch_document_index.MAX_CHUNKS_PER_DOC_BATCH",
    100,
)
def test_single_doc_one_over_batch_limit() -> None:
    """101 chunks for one doc: first 100 flushed when the 101st arrives, then
    the 101st is flushed at the end."""
    index, mock_bulk = _make_index()
    doc_id = "doc_1"
    num_chunks = 101
    chunks = [_make_chunk(doc_id, i) for i in range(num_chunks)]
    metadata = _make_metadata(doc_id, num_chunks)

    with patch.object(index, "delete", return_value=0):
        index.index(chunks, metadata)

    assert mock_bulk.call_count == 2
    batch_sizes = [len(call.kwargs["documents"]) for call in mock_bulk.call_args_list]
    assert batch_sizes == [100, 1]


@patch(
    "onyx.document_index.opensearch.opensearch_document_index.MAX_CHUNKS_PER_DOC_BATCH",
    100,
)
def test_multiple_docs_each_under_limit_flush_per_doc() -> None:
    """Multiple documents each under the batch limit should flush once per document."""
    index, mock_bulk = _make_index()
    chunks = []
    for doc_idx in range(3):
        doc_id = f"doc_{doc_idx}"
        for chunk_idx in range(50):
            chunks.append(_make_chunk(doc_id, chunk_idx))

    metadata = IndexingMetadata(
        doc_id_to_chunk_cnt_diff={
            f"doc_{i}": IndexingMetadata.ChunkCounts(old_chunk_cnt=0, new_chunk_cnt=50)
            for i in range(3)
        },
    )

    with patch.object(index, "delete", return_value=0):
        index.index(chunks, metadata)

    # 3 documents = 3 flushes (one per doc boundary + final)
    assert mock_bulk.call_count == 3


@patch(
    "onyx.document_index.opensearch.opensearch_document_index.MAX_CHUNKS_PER_DOC_BATCH",
    100,
)
def test_delete_called_once_per_document() -> None:
    """Even with multiple flushes for a single document, delete should only be
    called once per document."""
    index, _mock_bulk = _make_index()
    doc_id = "doc_1"
    num_chunks = 250
    chunks = [_make_chunk(doc_id, i) for i in range(num_chunks)]
    metadata = _make_metadata(doc_id, num_chunks)

    with patch.object(index, "delete", return_value=0) as mock_delete:
        index.index(chunks, metadata)

    mock_delete.assert_called_once_with(doc_id, None)
@@ -0,0 +1,152 @@
"""Unit tests for VespaDocumentIndex.index().

These tests mock all external I/O (HTTP calls, thread pools) and verify
the streaming logic, ID cleaning/mapping, and DocumentInsertionRecord
construction.
"""

from unittest.mock import MagicMock
from unittest.mock import patch

from onyx.access.models import DocumentAccess
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.document_index.interfaces import EnrichedDocumentIndexingInfo
from onyx.document_index.interfaces_new import IndexingMetadata
from onyx.document_index.interfaces_new import TenantState
from onyx.document_index.vespa.vespa_document_index import VespaDocumentIndex
from onyx.indexing.models import ChunkEmbedding
from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.indexing.models import IndexChunk


def _make_chunk(
    doc_id: str,
    chunk_id: int = 0,
    content: str = "test content",
) -> DocMetadataAwareIndexChunk:
    doc = Document(
        id=doc_id,
        semantic_identifier="test_doc",
        sections=[TextSection(text=content, link=None)],
        source=DocumentSource.NOT_APPLICABLE,
        metadata={},
    )
    index_chunk = IndexChunk(
        chunk_id=chunk_id,
        blurb=content[:50],
        content=content,
        source_links=None,
        image_file_id=None,
        section_continuation=False,
        source_document=doc,
        title_prefix="",
        metadata_suffix_semantic="",
        metadata_suffix_keyword="",
        contextual_rag_reserved_tokens=0,
        doc_summary="",
        chunk_context="",
        mini_chunk_texts=None,
        large_chunk_id=None,
        embeddings=ChunkEmbedding(
            full_embedding=[0.1] * 10,
            mini_chunk_embeddings=[],
        ),
        title_embedding=None,
    )
    access = DocumentAccess.build(
        user_emails=[],
        user_groups=[],
        external_user_emails=[],
        external_user_group_ids=[],
        is_public=True,
    )
    return DocMetadataAwareIndexChunk.from_index_chunk(
        index_chunk=index_chunk,
        access=access,
        document_sets=set(),
        user_project=[],
        personas=[],
        boost=0,
        aggregated_chunk_boost_factor=1.0,
        tenant_id="test_tenant",
    )


def _make_indexing_metadata(
    doc_ids: list[str],
    old_counts: list[int],
    new_counts: list[int],
) -> IndexingMetadata:
    return IndexingMetadata(
        doc_id_to_chunk_cnt_diff={
            doc_id: IndexingMetadata.ChunkCounts(
                old_chunk_cnt=old,
                new_chunk_cnt=new,
            )
            for doc_id, old, new in zip(doc_ids, old_counts, new_counts)
        }
    )


def _stub_enrich(
    doc_id: str,
    old_chunk_cnt: int,
) -> EnrichedDocumentIndexingInfo:
    """Build an EnrichedDocumentIndexingInfo that says 'no chunks to delete'
    when old_chunk_cnt == 0, or 'has existing chunks' otherwise."""
    return EnrichedDocumentIndexingInfo(
        doc_id=doc_id,
        chunk_start_index=0,
        old_version=False,
        chunk_end_index=old_chunk_cnt,
    )


@patch("onyx.document_index.vespa.vespa_document_index.batch_index_vespa_chunks")
@patch("onyx.document_index.vespa.vespa_document_index.delete_vespa_chunks")
@patch(
    "onyx.document_index.vespa.vespa_document_index.get_document_chunk_ids",
    return_value=[],
)
@patch("onyx.document_index.vespa.vespa_document_index._enrich_basic_chunk_info")
@patch(
    "onyx.document_index.vespa.vespa_document_index.BATCH_SIZE",
    3,
)
def test_index_respects_batch_size(
    mock_enrich: MagicMock,
    mock_get_chunk_ids: MagicMock,  # noqa: ARG001
    mock_delete: MagicMock,  # noqa: ARG001
    mock_batch_index: MagicMock,
) -> None:
    """When chunks exceed BATCH_SIZE, batch_index_vespa_chunks is called
    multiple times with correctly sized batches."""
    mock_enrich.return_value = _stub_enrich("doc1", old_chunk_cnt=0)

    index = VespaDocumentIndex(
        index_name="test_index",
        tenant_state=TenantState(tenant_id="test_tenant", multitenant=False),
        large_chunks_enabled=False,
        httpx_client=MagicMock(),
    )

    chunks = [_make_chunk("doc1", chunk_id=i) for i in range(7)]
    metadata = _make_indexing_metadata(["doc1"], old_counts=[0], new_counts=[7])

    results = index.index(chunks=chunks, indexing_metadata=metadata)

    assert len(results) == 1

    # With BATCH_SIZE=3 and 7 chunks: batches of 3, 3, 1
    assert mock_batch_index.call_count == 3
    batch_sizes = [len(c.kwargs["chunks"]) for c in mock_batch_index.call_args_list]
    assert batch_sizes == [3, 3, 1]

    # Verify all chunks are accounted for and in order
    all_indexed = [
        chunk for c in mock_batch_index.call_args_list for chunk in c.kwargs["chunks"]
    ]
    assert len(all_indexed) == 7
    assert [c.chunk_id for c in all_indexed] == list(range(7))
@@ -11,30 +11,13 @@ from onyx.hooks.api_dependencies import require_hook_enabled

class TestRequireHookEnabled:
    def test_raises_when_multi_tenant(self) -> None:
        with (
            patch("onyx.hooks.api_dependencies.MULTI_TENANT", True),
            patch("onyx.hooks.api_dependencies.HOOK_ENABLED", True),
        ):
        with patch("onyx.hooks.api_dependencies.MULTI_TENANT", True):
            with pytest.raises(OnyxError) as exc_info:
                require_hook_enabled()
            assert exc_info.value.error_code is OnyxErrorCode.SINGLE_TENANT_ONLY
            assert exc_info.value.status_code == 403
            assert "multi-tenant" in exc_info.value.detail

    def test_raises_when_flag_disabled(self) -> None:
        with (
            patch("onyx.hooks.api_dependencies.MULTI_TENANT", False),
            patch("onyx.hooks.api_dependencies.HOOK_ENABLED", False),
        ):
            with pytest.raises(OnyxError) as exc_info:
                require_hook_enabled()
            assert exc_info.value.error_code is OnyxErrorCode.ENV_VAR_GATED
            assert exc_info.value.status_code == 403
            assert "HOOK_ENABLED" in exc_info.value.detail

    def test_passes_when_enabled_single_tenant(self) -> None:
        with (
            patch("onyx.hooks.api_dependencies.MULTI_TENANT", False),
            patch("onyx.hooks.api_dependencies.HOOK_ENABLED", True),
        ):
    def test_passes_when_single_tenant(self) -> None:
        with patch("onyx.hooks.api_dependencies.MULTI_TENANT", False):
            require_hook_enabled()  # must not raise

391
backend/tests/unit/onyx/indexing/test_embed_chunks_in_batches.py
Normal file
@@ -0,0 +1,391 @@
"""Unit tests for _embed_chunks_to_store.

Tests cover:
- Single batch, no failures
- Multiple batches, no failures
- Failure in a single batch
- Cross-batch document failure scrubbing
- Later batches skip already-failed docs
- Empty input
- All chunks fail
"""

from collections.abc import Callable
from unittest.mock import MagicMock
from unittest.mock import patch

from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import DocumentSource
from onyx.connectors.models import TextSection
from onyx.indexing.chunk_batch_store import ChunkBatchStore
from onyx.indexing.indexing_pipeline import _embed_chunks_to_store
from onyx.indexing.models import ChunkEmbedding
from onyx.indexing.models import DocAwareChunk
from onyx.indexing.models import IndexChunk


def _make_doc(doc_id: str) -> Document:
    return Document(
        id=doc_id,
        semantic_identifier="test",
        source=DocumentSource.FILE,
        sections=[TextSection(text="test", link=None)],
        metadata={},
    )


def _make_chunk(doc_id: str, chunk_id: int) -> DocAwareChunk:
    return DocAwareChunk(
        chunk_id=chunk_id,
        blurb="test",
        content="test content",
        source_links=None,
        image_file_id=None,
        section_continuation=False,
        source_document=_make_doc(doc_id),
        title_prefix="",
        metadata_suffix_semantic="",
        metadata_suffix_keyword="",
        mini_chunk_texts=None,
        large_chunk_id=None,
        doc_summary="",
        chunk_context="",
        contextual_rag_reserved_tokens=0,
    )


def _make_index_chunk(doc_id: str, chunk_id: int) -> IndexChunk:
    """Create an IndexChunk (a DocAwareChunk with embeddings)."""
    return IndexChunk(
        chunk_id=chunk_id,
        blurb="test",
        content="test content",
        source_links=None,
        image_file_id=None,
        section_continuation=False,
        source_document=_make_doc(doc_id),
        title_prefix="",
        metadata_suffix_semantic="",
        metadata_suffix_keyword="",
        mini_chunk_texts=None,
        large_chunk_id=None,
        doc_summary="",
        chunk_context="",
        contextual_rag_reserved_tokens=0,
        embeddings=ChunkEmbedding(
            full_embedding=[0.1] * 10,
            mini_chunk_embeddings=[],
        ),
        title_embedding=None,
    )


def _make_failure(doc_id: str) -> ConnectorFailure:
    return ConnectorFailure(
        failed_document=DocumentFailure(document_id=doc_id, document_link=None),
        failure_message="embedding failed",
        exception=RuntimeError("embedding failed"),
    )


def _mock_embed_success(
    chunks: list[DocAwareChunk], **_kwargs: object
) -> tuple[list[IndexChunk], list[ConnectorFailure]]:
    """Simulate successful embedding of all chunks."""
    return (
        [_make_index_chunk(c.source_document.id, c.chunk_id) for c in chunks],
        [],
    )


def _mock_embed_fail_doc(
    fail_doc_id: str,
) -> Callable[..., tuple[list[IndexChunk], list[ConnectorFailure]]]:
    """Return an embed mock that fails all chunks for a specific doc."""

    def _embed(
        chunks: list[DocAwareChunk], **_kwargs: object
    ) -> tuple[list[IndexChunk], list[ConnectorFailure]]:
        successes = [
            _make_index_chunk(c.source_document.id, c.chunk_id)
            for c in chunks
            if c.source_document.id != fail_doc_id
        ]
        failures = (
            [_make_failure(fail_doc_id)]
            if any(c.source_document.id == fail_doc_id for c in chunks)
            else []
        )
        return successes, failures

    return _embed


class TestEmbedChunksInBatches:
    @patch(
        "onyx.indexing.indexing_pipeline.embed_chunks_with_failure_handling",
    )
    @patch("onyx.indexing.indexing_pipeline.MAX_CHUNKS_PER_DOC_BATCH", 100)
    def test_single_batch_no_failures(self, mock_embed: MagicMock) -> None:
        """All chunks fit in one batch and embed successfully."""
        mock_embed.side_effect = _mock_embed_success

        with ChunkBatchStore() as store:
            chunks = [_make_chunk("doc1", i) for i in range(3)]
            result = _embed_chunks_to_store(
                chunks=chunks,
                embedder=MagicMock(),
                tenant_id="test",
                request_id=None,
                store=store,
            )

            assert len(result.successful_chunk_ids) == 3
            assert len(result.connector_failures) == 0

            # Verify stored contents
            assert len(store._batch_files()) == 1
            stored = list(store.stream())
            assert len(stored) == 3

    @patch(
        "onyx.indexing.indexing_pipeline.embed_chunks_with_failure_handling",
    )
    @patch("onyx.indexing.indexing_pipeline.MAX_CHUNKS_PER_DOC_BATCH", 3)
    def test_multiple_batches_no_failures(self, mock_embed: MagicMock) -> None:
        """Chunks are split across multiple batches, all succeed."""
        mock_embed.side_effect = _mock_embed_success

        with ChunkBatchStore() as store:
            chunks = [_make_chunk("doc1", i) for i in range(7)]
            result = _embed_chunks_to_store(
                chunks=chunks,
                embedder=MagicMock(),
                tenant_id="test",
                request_id=None,
                store=store,
            )

            assert len(result.successful_chunk_ids) == 7
            assert len(result.connector_failures) == 0
            assert len(store._batch_files()) == 3  # 3 + 3 + 1

    @patch(
        "onyx.indexing.indexing_pipeline.embed_chunks_with_failure_handling",
    )
    @patch("onyx.indexing.indexing_pipeline.MAX_CHUNKS_PER_DOC_BATCH", 100)
    def test_single_batch_with_failure(self, mock_embed: MagicMock) -> None:
        """One doc fails embedding, its chunks are excluded from results."""
        mock_embed.side_effect = _mock_embed_fail_doc("doc2")

        with ChunkBatchStore() as store:
            chunks = [
                _make_chunk("doc1", 0),
                _make_chunk("doc2", 1),
                _make_chunk("doc1", 2),
            ]
            result = _embed_chunks_to_store(
                chunks=chunks,
                embedder=MagicMock(),
                tenant_id="test",
                request_id=None,
                store=store,
            )

            assert len(result.connector_failures) == 1
            successful_doc_ids = {doc_id for _, doc_id in result.successful_chunk_ids}
            assert "doc2" not in successful_doc_ids
            assert "doc1" in successful_doc_ids

    @patch(
        "onyx.indexing.indexing_pipeline.embed_chunks_with_failure_handling",
    )
    @patch("onyx.indexing.indexing_pipeline.MAX_CHUNKS_PER_DOC_BATCH", 3)
    def test_cross_batch_failure_scrubs_earlier_batch(
        self, mock_embed: MagicMock
    ) -> None:
        """Doc A spans batches 0 and 1. It succeeds in batch 0 but fails in
        batch 1. Its chunks should be scrubbed from batch 0's batch file."""
        call_count = 0

        def _embed(
            chunks: list[DocAwareChunk], **_kwargs: object
        ) -> tuple[list[IndexChunk], list[ConnectorFailure]]:
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                return _mock_embed_success(chunks)
            else:
                return _mock_embed_fail_doc("docA")(chunks)

        mock_embed.side_effect = _embed

        with ChunkBatchStore() as store:
            chunks = [
                _make_chunk("docA", 0),
                _make_chunk("docA", 1),
                _make_chunk("docA", 2),
                _make_chunk("docA", 3),
                _make_chunk("docB", 0),
                _make_chunk("docB", 1),
            ]
            result = _embed_chunks_to_store(
                chunks=chunks,
                embedder=MagicMock(),
                tenant_id="test",
                request_id=None,
                store=store,
            )

            # docA should be fully excluded from results
            successful_doc_ids = {doc_id for _, doc_id in result.successful_chunk_ids}
            assert "docA" not in successful_doc_ids
            assert "docB" in successful_doc_ids
            assert len(result.connector_failures) == 1

            # Verify batch 0 was scrubbed of docA chunks
            all_stored = list(store.stream())
            stored_doc_ids = {c.source_document.id for c in all_stored}
            assert "docA" not in stored_doc_ids
            assert "docB" in stored_doc_ids

    @patch(
        "onyx.indexing.indexing_pipeline.embed_chunks_with_failure_handling",
    )
    @patch("onyx.indexing.indexing_pipeline.MAX_CHUNKS_PER_DOC_BATCH", 3)
    def test_later_batch_skips_already_failed_doc(self, mock_embed: MagicMock) -> None:
        """If docA fails in batch 0, its chunks in batch 1 are skipped
        entirely (never sent to the embedder)."""
        embedded_doc_ids: list[str] = []

        def _embed(
            chunks: list[DocAwareChunk], **_kwargs: object
        ) -> tuple[list[IndexChunk], list[ConnectorFailure]]:
            for c in chunks:
                embedded_doc_ids.append(c.source_document.id)
            return _mock_embed_fail_doc("docA")(chunks)

        mock_embed.side_effect = _embed

        with ChunkBatchStore() as store:
            chunks = [
                _make_chunk("docA", 0),
                _make_chunk("docA", 1),
                _make_chunk("docA", 2),
                _make_chunk("docA", 3),
                _make_chunk("docB", 0),
                _make_chunk("docB", 1),
            ]
            _embed_chunks_to_store(
                chunks=chunks,
                embedder=MagicMock(),
                tenant_id="test",
                request_id=None,
                store=store,
            )

            # docA should only appear in batch 0, not batch 1
            batch_1_doc_ids = embedded_doc_ids[3:]
            assert "docA" not in batch_1_doc_ids

    @patch(
        "onyx.indexing.indexing_pipeline.embed_chunks_with_failure_handling",
    )
    @patch("onyx.indexing.indexing_pipeline.MAX_CHUNKS_PER_DOC_BATCH", 3)
    def test_failed_doc_skipped_in_later_batch_while_other_doc_succeeds(
        self, mock_embed: MagicMock
    ) -> None:
        """doc1 spans batches 0 and 1, doc2 only in batch 1. Batch 0 fails
        doc1. In batch 1, doc1 chunks should be skipped but doc2 chunks
        should still be embedded successfully."""
        embedded_chunks: list[list[str]] = []

        def _embed(
            chunks: list[DocAwareChunk], **_kwargs: object
        ) -> tuple[list[IndexChunk], list[ConnectorFailure]]:
            embedded_chunks.append([c.source_document.id for c in chunks])
            return _mock_embed_fail_doc("doc1")(chunks)

        mock_embed.side_effect = _embed

        with ChunkBatchStore() as store:
            chunks = [
                _make_chunk("doc1", 0),
                _make_chunk("doc1", 1),
                _make_chunk("doc1", 2),
                _make_chunk("doc1", 3),
                _make_chunk("doc2", 0),
                _make_chunk("doc2", 1),
            ]
            result = _embed_chunks_to_store(
                chunks=chunks,
                embedder=MagicMock(),
                tenant_id="test",
                request_id=None,
                store=store,
            )

            # doc1 should be fully excluded, doc2 fully included
            successful_doc_ids = {doc_id for _, doc_id in result.successful_chunk_ids}
            assert "doc1" not in successful_doc_ids
            assert "doc2" in successful_doc_ids
            assert len(result.successful_chunk_ids) == 2  # doc2's 2 chunks

            # Batch 1 should only contain doc2 (doc1 was filtered before embedding)
            assert len(embedded_chunks) == 2
            assert "doc1" not in embedded_chunks[1]
            assert embedded_chunks[1] == ["doc2", "doc2"]

            # Verify on-disk state has no doc1 chunks
            all_stored = list(store.stream())
            assert all(c.source_document.id == "doc2" for c in all_stored)

    @patch(
        "onyx.indexing.indexing_pipeline.embed_chunks_with_failure_handling",
    )
    def test_empty_input(self, mock_embed: MagicMock) -> None:
        """Empty chunk list produces empty results."""
        mock_embed.side_effect = _mock_embed_success

        with ChunkBatchStore() as store:
            result = _embed_chunks_to_store(
                chunks=[],
                embedder=MagicMock(),
                tenant_id="test",
                request_id=None,
                store=store,
            )

            assert len(result.successful_chunk_ids) == 0
            assert len(result.connector_failures) == 0
            mock_embed.assert_not_called()

    @patch(
        "onyx.indexing.indexing_pipeline.embed_chunks_with_failure_handling",
    )
    @patch("onyx.indexing.indexing_pipeline.MAX_CHUNKS_PER_DOC_BATCH", 100)
    def test_all_chunks_fail(self, mock_embed: MagicMock) -> None:
        """When all documents fail, results have no successful chunks."""

        def _fail_all(
            chunks: list[DocAwareChunk], **_kwargs: object
        ) -> tuple[list[IndexChunk], list[ConnectorFailure]]:
            doc_ids = {c.source_document.id for c in chunks}
            return [], [_make_failure(doc_id) for doc_id in doc_ids]

        mock_embed.side_effect = _fail_all

        with ChunkBatchStore() as store:
            chunks = [_make_chunk("doc1", 0), _make_chunk("doc2", 1)]
            result = _embed_chunks_to_store(
                chunks=chunks,
                embedder=MagicMock(),
                tenant_id="test",
                request_id=None,
                store=store,
            )

            assert len(result.successful_chunk_ids) == 0
            assert len(result.connector_failures) == 2
@@ -116,7 +116,7 @@ def _run_adapter_build(
    project_ids_map: dict[str, list[int]],
    persona_ids_map: dict[str, list[int]],
) -> list[DocMetadataAwareIndexChunk]:
    """Helper that runs UserFileIndexingAdapter.build_metadata_aware_chunks
    """Helper that runs UserFileIndexingAdapter.prepare_enrichment + enrich_chunk
    with all external dependencies mocked."""
    from onyx.indexing.adapters.user_file_indexing_adapter import (
        UserFileIndexingAdapter,
@@ -155,18 +155,16 @@ def _run_adapter_build(
            side_effect=Exception("no LLM in tests"),
        ),
    ):
        result = adapter.build_metadata_aware_chunks(
            chunks_with_embeddings=[chunk],
            chunk_content_scores=[1.0],
            tenant_id="test_tenant",
        enricher = adapter.prepare_enrichment(
            context=context,
            tenant_id="test_tenant",
            chunks=[chunk],
        )

    return result.chunks
    return [enricher.enrich_chunk(chunk, 1.0)]


def test_build_metadata_aware_chunks_includes_persona_ids() -> None:
    """UserFileIndexingAdapter.build_metadata_aware_chunks writes persona IDs
def test_prepare_enrichment_includes_persona_ids() -> None:
    """UserFileIndexingAdapter.prepare_enrichment writes persona IDs
    fetched from the DB into each chunk's metadata."""
    file_id = str(uuid4())
    persona_ids = [5, 12]
@@ -183,7 +181,7 @@ def test_build_metadata_aware_chunks_includes_persona_ids() -> None:
    assert chunks[0].user_project == project_ids


def test_build_metadata_aware_chunks_missing_file_defaults_to_empty() -> None:
def test_prepare_enrichment_missing_file_defaults_to_empty() -> None:
    """When a file has no persona or project associations in the DB, the
    adapter should default to empty lists (not KeyError or None)."""
    file_id = str(uuid4())

0
backend/tests/unit/onyx/server/features/__init__.py
Normal file
@@ -417,3 +417,57 @@ def test_categorize_text_under_token_limit_accepted(

    assert len(result.acceptable) == 1
    assert result.acceptable_file_to_token_count["ok.txt"] == 500


# --- skip-indexing vs rejection by file type ---


def test_csv_over_token_threshold_accepted_skip_indexing(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """CSV exceeding token threshold is uploaded but flagged to skip indexing."""
    _patch_common_dependencies(monkeypatch, upload_size_mb=1000, token_threshold_k=1)
    text = "x" * 2000  # 2000 tokens > 1000 threshold
    monkeypatch.setattr(utils, "extract_file_text", lambda **_kwargs: text)

    upload = _make_upload("large.csv", size=2000, content=text.encode())
    result = utils.categorize_uploaded_files([upload], MagicMock())

    assert len(result.acceptable) == 1
    assert result.acceptable[0].filename == "large.csv"
    assert "large.csv" in result.skip_indexing
    assert len(result.rejected) == 0


def test_csv_under_token_threshold_accepted_and_indexed(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """CSV under token threshold is uploaded and indexed normally."""
    _patch_common_dependencies(monkeypatch, upload_size_mb=1000, token_threshold_k=1)
    text = "x" * 500  # 500 tokens < 1000 threshold
    monkeypatch.setattr(utils, "extract_file_text", lambda **_kwargs: text)

    upload = _make_upload("small.csv", size=500, content=text.encode())
    result = utils.categorize_uploaded_files([upload], MagicMock())

    assert len(result.acceptable) == 1
    assert result.acceptable[0].filename == "small.csv"
    assert "small.csv" not in result.skip_indexing
    assert len(result.rejected) == 0


def test_pdf_over_token_threshold_rejected(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """PDF exceeding token threshold is rejected entirely (not uploaded)."""
    _patch_common_dependencies(monkeypatch, upload_size_mb=1000, token_threshold_k=1)
    text = "x" * 2000  # 2000 tokens > 1000 threshold
    monkeypatch.setattr(utils, "extract_file_text", lambda **_kwargs: text)

    upload = _make_upload("big.pdf", size=2000, content=text.encode())
    result = utils.categorize_uploaded_files([upload], MagicMock())

    assert len(result.rejected) == 1
    assert result.rejected[0].filename == "big.pdf"
    assert "1K token limit" in result.rejected[0].reason
    assert len(result.acceptable) == 0

@@ -82,7 +82,7 @@ class TestChatFileConversion:
            ChatLoadedFile(
                file_id="file-2",
                content=b"csv,data\n1,2",
                file_type=ChatFileType.CSV,
                file_type=ChatFileType.TABULAR,
                filename="data.csv",
                content_text="csv,data\n1,2",
                token_count=5,

@@ -1,184 +0,0 @@
# Engineering Principles, Style, and Correctness Guide

## Principles and collaboration

- **Use 1-way vs 2-way doors.** For 2-way doors, move faster and iterate. For 1-way doors, be more deliberate.
- **Consistency > being “right.”** Prefer consistent patterns across the codebase. If something is truly bad, fix it everywhere.
- **Fix what you touch (selectively).**
  - Don’t feel obligated to fix every best-practice issue you notice.
  - Don’t introduce new bad practices.
  - If your change touches code that violates best practices, fix it as part of the change.
- **Don’t tack features on.** When adding functionality, restructure logically as needed to avoid muddying interfaces and accumulating tech debt.

---

## Style and maintainability

### Comments and readability

Add clear comments:
- At logical boundaries (e.g., interfaces) so the reader doesn’t need to dig 10 layers deeper.
- Wherever assumptions are made or something non-obvious/unexpected is done.
- For complicated flows/functions.
- Wherever it saves time (e.g., nontrivial regex patterns).

### Errors and exceptions

- **Fail loudly** rather than silently skipping work.
  - Example: raise and let exceptions propagate instead of silently dropping a document.
- **Don’t overuse `try/except`.**
  - Put `try/except` at the correct logical level.
  - Do not mask exceptions unless it is clearly appropriate (see the sketch below).

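A minimal sketch of these two rules together, using hypothetical parsing helpers (not Onyx code): the inner function lets errors propagate instead of swallowing them, and the batch loop is the one logical level that decides what a failure means.

```python
import logging

from pydantic import BaseModel

logger = logging.getLogger(__name__)


class ParsedDoc(BaseModel):
    doc_id: str
    text: str


def parse_document(raw: dict) -> ParsedDoc:
    # No try/except here: a malformed payload should raise, not be
    # silently dropped.
    return ParsedDoc(doc_id=raw["id"], text=raw["text"])


def parse_batch(raw_docs: list[dict]) -> list[ParsedDoc]:
    parsed: list[ParsedDoc] = []
    for raw in raw_docs:
        # The batch loop is the correct logical level for handling the
        # failure: log loudly and re-raise instead of masking it.
        try:
            parsed.append(parse_document(raw))
        except KeyError:
            logger.exception("Malformed document payload; failing the batch")
            raise
    return parsed
```
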
### Typing

- Everything should be **as strictly typed as possible**.
  - Use `cast` for annoying/loose-typed interfaces (e.g., results of `run_functions_tuples_in_parallel`).
  - Only `cast` when the type checker sees `Any` or types are too loose.
- Prefer types that are easy to read.
  - Avoid dense types like `dict[tuple[str, str], list[list[float]]]`.
  - Prefer domain models (see the sketch below), e.g.:
    - `EmbeddingModel(provider_name, model_name)` as a Pydantic model
    - `dict[EmbeddingModel, list[EmbeddingVector]]`

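A sketch of the domain-model preference, assuming Pydantic v2 (`frozen=True` makes instances immutable and hashable, so the model can replace an anonymous tuple as a dict key):

```python
from pydantic import BaseModel, ConfigDict

EmbeddingVector = list[float]


class EmbeddingModel(BaseModel):
    # frozen=True makes the model immutable and hashable, so it can be
    # used as a dict key in place of tuple[str, str].
    model_config = ConfigDict(frozen=True)

    provider_name: str
    model_name: str


# Instead of the dense dict[tuple[str, str], list[list[float]]]:
vectors_by_model: dict[EmbeddingModel, list[EmbeddingVector]] = {
    EmbeddingModel(provider_name="openai", model_name="text-embedding-3-small"): [
        [0.1, 0.2, 0.3],
    ],
}
```
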
### State, objects, and boundaries

- Keep **clear logical boundaries** for state containers and objects.
  - A **config** object should never contain things like a `db_session` (see the sketch below).
- Avoid state containers that are:
  - overly nested, or
  - huge + flat (use judgment).
- Prefer **composition and functional style** over inheritance/OOP.
- Prefer **no mutation** unless there’s a strong reason.
  - State objects should be **intentional and explicit**, ideally nonmutating.
- Use interfaces/objects to create clear separation of responsibility.
- Prefer simplicity when there’s no clear gain.
  - Avoid overcomplicated mechanisms like semaphores.
- Prefer **hash maps (dicts)** over tree structures unless there’s a strong reason.

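A small illustration of the config boundary, with hypothetical names: the config object stays pure data, and the live resource is passed separately at the call boundary.

```python
from pydantic import BaseModel
from sqlalchemy.orm import Session


class IndexingConfig(BaseModel):
    # Pure, explicit, nonmutating configuration: no sessions, clients,
    # or other live resources hiding inside.
    batch_size: int = 100
    fail_fast: bool = True


def run_indexing(config: IndexingConfig, db_session: Session) -> None:
    # The db_session is a separate argument, so the caller controls its
    # lifetime and the config remains trivially serializable.
    ...
```
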
### Naming

- Name variables carefully and intentionally.
  - Prefer long, explicit names when undecided.
  - Avoid single-character variables except for small, self-contained utilities (or not at all).
- Keep the same object/name consistent through the call stack and within functions when reasonable.
  - Good: `for token in tokens:`
  - Bad: `for msg in tokens:` (if iterating tokens)
- Function names should bias toward **long + descriptive** for codebase search.
  - IntelliSense can miss call sites; search works best with unique names.
  - “Fetch versioned implementation” is an example of why this matters.

### Correctness by construction

- Prefer self-contained correctness.
  - Don’t rely on callers to “use it right” if you can make misuse hard.
- Avoid redundancies:
  - If a function takes an arg, it shouldn’t also take a state object that contains that same arg (see the sketch below).
- No dead code (unless there’s a very good reason).
- No commented-out code in main or feature branches (unless there’s a very good reason).
- No duplicate logic:
  - Don’t copy/paste into branches when shared logic can live above the conditional.
  - If you’re afraid to touch the original, you don’t understand it well enough.
  - LLMs often create subtle duplicate logic—review carefully and remove it.
- Avoid “nearly identical” objects that confuse when to use which.
- Avoid extremely long functions with chained logic:
  - Encapsulate steps into helpers for readability, even if not reused.
  - “Pythonic” multi-step expressions are OK in moderation; don’t trade clarity for cleverness.

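One concrete instance of the redundancy rule above, with hypothetical names: a function should not take both a value and a state object that already carries that value.

```python
from pydantic import BaseModel


class SearchRequest(BaseModel):
    query: str
    max_hits: int = 10


# Redundant: `query` and `request.query` can silently disagree.
def search_redundant(query: str, request: SearchRequest) -> list[str]:
    ...


# Correct by construction: a single source of truth makes misuse hard.
def search(request: SearchRequest) -> list[str]:
    ...
```
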
---

## Performance and correctness

- Avoid holding resources for extended periods:
  - DB sessions
  - locks/semaphores
- Validate objects:
  - on creation, and
  - right before use.
- Connector code (data → Onyx documents):
  - Any in-memory structure that can grow without bound based on input must be periodically size-checked (see the sketch below).
  - If a connector is OOMing (often shows up as “missing celery tasks”), this is a top thing to check retroactively.
- Async and event loops:
  - Never introduce new async/event loop Python code, and try to make existing async code synchronous when possible if it makes sense.
  - Writing async code without 100% understanding the code and having a concrete reason to do so is likely to introduce bugs and not add any meaningful performance gains.

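A sketch of the periodic size check for connector code; the batch limit and names are illustrative, not an Onyx API. The buffer is flushed as it grows instead of expanding until the worker is OOM-killed.

```python
from collections.abc import Iterator

MAX_BUFFERED_ROWS = 10_000


def stream_rows_in_batches(source_rows: Iterator[dict]) -> Iterator[list[dict]]:
    buffered: list[dict] = []
    for row in source_rows:
        buffered.append(row)
        # The buffer grows with the input, so its size is checked on every
        # append and flushed before it can exhaust memory.
        if len(buffered) >= MAX_BUFFERED_ROWS:
            yield buffered
            buffered = []
    if buffered:
        yield buffered
```
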
---

## Repository conventions: where code lives

- Pydantic + data models: `models.py` files.
- DB interface functions (excluding lazy loading): `db/` directory.
- LLM prompts: `prompts/` directory, roughly mirroring the code layout that uses them.
- API routes: `server/` directory.

---

## Pydantic and modeling rules

- Prefer **Pydantic** over dataclasses.
- If absolutely required, use `allow_arbitrary_types` (see the sketch below).

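A sketch of that escape hatch, assuming Pydantic v2, where the config key is spelled `arbitrary_types_allowed`:

```python
from pydantic import BaseModel, ConfigDict


class ThirdPartyHandle:
    """Stand-in for a non-Pydantic type from an external library."""


class HandleWrapper(BaseModel):
    # Only use this when a field genuinely cannot be a Pydantic model or
    # builtin; Pydantic skips validation for the arbitrary-typed field.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    handle: ThirdPartyHandle
```
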
---

## Data conventions

- Prefer explicit `None` over sentinel empty strings (usually; depends on intent).
- Prefer explicit identifiers:
  - Use string enums instead of integer codes (see the sketch below).
- Avoid magic numbers (co-location is good when necessary). **Always avoid magic strings.**

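The identifier rules as a sketch (`StrEnum` requires Python 3.11+):

```python
from enum import StrEnum


class IndexingStatus(StrEnum):
    # Explicit, searchable values instead of integer codes or magic strings.
    NOT_STARTED = "not_started"
    IN_PROGRESS = "in_progress"
    SUCCESS = "success"


def is_finished(status: IndexingStatus) -> bool:
    return status is IndexingStatus.SUCCESS
```
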
---

## Logging

- Log messages where they are created.
- Don’t propagate log messages around just to log them elsewhere.

---

## Encapsulation

- Don’t use private attributes/methods/properties from other classes/modules.
- “Private” is private—respect that boundary.

---

## SQLAlchemy guidance

- Lazy loading is often bad at scale, especially across multiple list relationships (see the sketch below).
- Be careful when accessing SQLAlchemy object attributes:
  - It can help avoid redundant DB queries,
  - but it can also fail if accessed outside an active session,
  - and lazy loading can add hidden DB dependencies to otherwise “simple” functions.
- Reference: https://www.reddit.com/r/SQLAlchemy/comments/138f248/joinedload_vs_selectinload/

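A sketch of eager loading with `selectinload` (a real SQLAlchemy loader option well suited to list relationships); the `User.chat_sessions` relationship here is an assumption for illustration:

```python
from sqlalchemy import select
from sqlalchemy.orm import Session, selectinload

from onyx.db.models import User  # assumes a User.chat_sessions relationship


def fetch_users_with_sessions(db_session: Session) -> list[User]:
    # One extra IN-query loads the relationship up front, avoiding a lazy
    # query per row and detached-session failures later.
    stmt = select(User).options(selectinload(User.chat_sessions))
    return list(db_session.scalars(stmt).all())
```
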
---

## Trunk-based development and feature flags

- **PRs should contain no more than 500 lines of real change.**
- **Merge to main frequently.** Avoid long-lived feature branches—they create merge conflicts and integration pain.
- **Use feature flags for incremental rollout.**
  - Large features should be merged in small, shippable increments behind a flag.
  - This allows continuous integration without exposing incomplete functionality.
- **Keep flags short-lived.** Once a feature is fully rolled out, remove the flag and dead code paths promptly.
- **Flag at the right level.** Prefer flagging at API/UI entry points rather than deep in business logic (see the sketch below).
- **Test both flag states.** Ensure the codebase works correctly with the flag on and off.

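A sketch of flagging at the API entry point; the flag name and route are hypothetical:

```python
import os

from fastapi import APIRouter, HTTPException

router = APIRouter()

# Short-lived flag: remove it and the gated branch once fully rolled out.
NEW_SEARCH_ENABLED = os.environ.get("NEW_SEARCH_ENABLED", "").lower() == "true"


@router.get("/new-search")
def new_search() -> dict[str, str]:
    if not NEW_SEARCH_ENABLED:
        # Gate at the entry point so the business logic stays flag-free.
        raise HTTPException(status_code=404)
    return {"status": "ok"}
```
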
---

## Misc

- Any TODOs you add in the code must be accompanied by either the name/username of the owner of that TODO, or an issue number for an issue referencing that piece of work.
- Avoid module-level logic that runs on import, which leads to import-time side effects. Essentially every piece of meaningful logic should exist within some function that has to be explicitly invoked. Acceptable exceptions to this may include loading environment variables or setting up loggers. (See the sketch below.)
  - If you find yourself needing something like this, you may want that logic to exist in a file dedicated for manual execution (contains `if __name__ == "__main__":`) which should not be imported by anything else.
- Related to the above, do not conflate Python scripts you intend to run from the command line (contains `if __name__ == "__main__":`) with modules you intend to import from elsewhere. If for some unlikely reason they have to be the same file, any logic specific to executing the file (including imports) should be contained in the `if __name__ == "__main__":` block.
  - Generally these executable files exist in `backend/scripts/`.

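The import-side-effect rules as a sketch of the file shape they imply (a hypothetical script):

```python
"""Example shape for a file in backend/scripts/: importing it runs nothing."""


def reindex_all_documents() -> None:
    # All meaningful logic lives in a function that must be invoked
    # explicitly, so importing this module has no side effects.
    ...


if __name__ == "__main__":
    # Execution-specific logic (argument parsing, setup, the call itself)
    # stays inside the main guard.
    reindex_all_documents()
```
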
@@ -1,36 +0,0 @@
## Some additional notes for Mac Users

The base instructions to set up the development environment are located in [CONTRIBUTING.md](https://github.com/onyx-dot-app/onyx/blob/main/CONTRIBUTING.md).

### Setting up Python

Ensure [Homebrew](https://brew.sh/) is already set up.

Then install python 3.11.

```bash
brew install python@3.11
```

Add python 3.11 to your path: add the following line to ~/.zshrc

```
export PATH="$(brew --prefix)/opt/python@3.11/libexec/bin:$PATH"
```

> **Note:**
> You will need to open a new terminal for the path change above to take effect.

### Setting up Docker

On macOS, you will need to install [Docker Desktop](https://www.docker.com/products/docker-desktop/) and ensure it is running before continuing with the docker commands.

### Formatting and Linting

MacOS will likely require you to remove some quarantine attributes on some of the hooks for them to execute properly.
After installing pre-commit, run the following command:

```bash
sudo xattr -r -d com.apple.quarantine ~/.cache/pre-commit
```
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user